Compare commits
83 Commits
Author | SHA1 | Date |
---|---|---|
pharmapsychotic | bc07ce62c1 | 1 year ago |
pharmapsychotic | 1e47e5149d | 1 year ago |
pharmapsychotic | da516f48ad | 1 year ago |
Harry Wang | 2cf03aaf6e | 2 years ago |
pharmapsychotic | f4429b4c9d | 2 years ago |
pharmapsychotic | 3385e538ee | 2 years ago |
pharmapsychotic | ce9d271aa1 | 2 years ago |
pharmapsychotic | ac74904908 | 2 years ago |
dnwalkup | d2c6e072e4 | 2 years ago |
pharmapsychotic | 9204a33786 | 2 years ago |
pharmapsychotic | 571ba9844c | 2 years ago |
starovoitovs | a0f278c4a9 | 2 years ago |
pharmapsychotic | 80d97f1f96 | 2 years ago |
pharmapsychotic | 08546eae22 | 2 years ago |
pharmapsychotic | 384e234ba2 | 2 years ago |
pharmapsychotic | eecf1864a1 | 2 years ago |
pharmapsychotic | c4e16359a7 | 2 years ago |
pharmapsychotic | fd93edc572 | 2 years ago |
pharmapsychotic | 6a62ce73e8 | 2 years ago |
pharmapsychotic | bf7404d7fa | 2 years ago |
pharmapsychotic | 9f04c8550c | 2 years ago |
pharmapsychotic | ae88b07a65 | 2 years ago |
pharmapsychotic | bcf1833ae0 | 2 years ago |
pharmapsychotic | 93db86fa70 | 2 years ago |
pharmapsychotic | 78287e17e1 | 2 years ago |
pharmapsychotic | 290a63b51e | 2 years ago |
pharmapsychotic | 71c77633a6 | 2 years ago |
pharmapsychotic | 42b3cf4d9e | 2 years ago |
pharmapsychotic | 99c8d45e86 | 2 years ago |
pharmapsychotic | 65c560ffac | 2 years ago |
pharmapsychotic | 55fe80c74c | 2 years ago |
pharmapsychotic | a8ecf52a38 | 2 years ago |
pharmapsychotic | 02576df0ce | 2 years ago |
pharmapsychotic | 1ec6cd9d45 | 2 years ago |
pharmapsychotic | 180cbc4f7b | 2 years ago |
pharmapsychotic | 6f17fb09af | 2 years ago |
pharmapsychotic | e22b005ba5 | 2 years ago |
pharmapsychotic | 152d5f551f | 2 years ago |
pharmapsychotic | 0a0b3968d1 | 2 years ago |
pharmapsychotic | 8b689592aa | 2 years ago |
pharmapsychotic | 2ffcd80b4e | 2 years ago |
pharmapsychotic | 8123696883 | 2 years ago |
pharmapsychotic | 884aab1a26 | 2 years ago |
pharmapsychotic | abbb326f93 | 2 years ago |
pharmapsychotic | faa56c8ef9 | 2 years ago |
pharmapsychotic | 979fca878e | 2 years ago |
pharmapsychotic | 5c4872d1f7 | 2 years ago |
pharmapsychotic | f22be02819 | 2 years ago |
pharmapsychotic | 5aed16b011 | 2 years ago |
pharmapsychotic | e3c1a4df84 | 2 years ago |
pharmapsychotic | 917b7c6c15 | 2 years ago |
pharmapsychotic | 19586d3d0d | 2 years ago |
pharmapsychotic | efee3fe0d7 | 2 years ago |
pharmapsychotic | 1221871c1b | 2 years ago |
pharmapsychotic | ad01cadbef | 2 years ago |
pharmapsychotic | 4bef32e69b | 2 years ago |
pharmapsychotic | 8d2de646b6 | 2 years ago |
pharmapsychotic | 429d490901 | 2 years ago |
pharmapsychotic | 55c922a48a | 2 years ago |
pharmapsychotic | 55b1770386 | 2 years ago |
pharmapsychotic | 8c521c12a0 | 2 years ago |
pharmapsychotic | 6953206901 | 2 years ago |
pharmapsychotic | 31b1d22e82 | 2 years ago |
pharmapsychotic | 8f5ddce2b3 | 2 years ago |
pharmapsychotic | f4abcdfd0c | 2 years ago |
pharmapsychotic | 6139576f88 | 2 years ago |
pharmapsychotic | b62cca2097 | 2 years ago |
pharmapsychotic | 9ce0f68ab3 | 2 years ago |
pharmapsychotic | 7a2ac9aa57 | 2 years ago |
pharmapsychotic | 27b9915dfa | 2 years ago |
pharmapsychotic | 0cdd1e1cb2 | 2 years ago |
pharmapsychotic | 1b5f9437bd | 2 years ago |
pharmapsychotic | b423f7b090 | 2 years ago |
pharmapsychotic | c0a088f9f8 | 2 years ago |
pharmapsychotic | 33fcdc57e3 | 2 years ago |
pharmapsychotic | 803dd389c0 | 2 years ago |
pharmapsychotic | a732d39e20 | 2 years ago |
pharmapsychotic | 30de2e7d85 | 2 years ago |
maple | d45a9c35d7 | 2 years ago |
Chenxi | 11a0087004 | 2 years ago |
pharmapsychotic | 2486589f24 | 2 years ago |
pharmapsychotic | 6a0f6d457d | 2 years ago |
amrrs | b09fcc8b35 | 2 years ago |
20 changed files with 102481 additions and 1112 deletions
@ -0,0 +1,39 @@
|
||||
# This workflow will upload a Python Package using Twine when a release is created |
||||
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries |
||||
|
||||
# This workflow uses actions that are not certified by GitHub. |
||||
# They are provided by a third-party and are governed by |
||||
# separate terms of service, privacy policy, and support |
||||
# documentation. |
||||
|
||||
name: Upload Python Package |
||||
|
||||
on: |
||||
release: |
||||
types: [published] |
||||
|
||||
permissions: |
||||
contents: read |
||||
|
||||
jobs: |
||||
deploy: |
||||
|
||||
runs-on: ubuntu-latest |
||||
|
||||
steps: |
||||
- uses: actions/checkout@v3 |
||||
- name: Set up Python |
||||
uses: actions/setup-python@v3 |
||||
with: |
||||
python-version: '3.x' |
||||
- name: Install dependencies |
||||
run: | |
||||
python -m pip install --upgrade pip |
||||
pip install build |
||||
- name: Build package |
||||
run: python -m build |
||||
- name: Publish package |
||||
uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 |
||||
with: |
||||
user: pharmapsychotic |
||||
password: ${{ secrets.PYPI_API_TOKEN }} |
@ -0,0 +1,10 @@
|
||||
*.pyc |
||||
.cog/ |
||||
.vscode/ |
||||
bench/ |
||||
cache/ |
||||
ci_env/ |
||||
clip-interrogator/ |
||||
clip_interrogator.egg-info/ |
||||
dist/ |
||||
venv/ |
@ -0,0 +1,6 @@
|
||||
include clip_interrogator/data/artists.txt |
||||
include clip_interrogator/data/flavors.txt |
||||
include clip_interrogator/data/mediums.txt |
||||
include clip_interrogator/data/movements.txt |
||||
include clip_interrogator/data/negative.txt |
||||
include requirements.txt |
@ -1,6 +1,86 @@
|
||||
# clip-interrogator |
||||
|
||||
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pharmapsychotic/clip-interrogator/blob/main/clip_interrogator.ipynb) |
||||
*Want to figure out what a good prompt might be to create new images like an existing one? The **CLIP Interrogator** is here to get you answers!* |
||||
|
||||
The CLIP Interrogator uses the OpenAI CLIP models to test a given image against a variety of artists, mediums, and styles to study how the different models see the content of the image. It also combines the results with BLIP caption to suggest a text prompt to create more images similar to what was given. |
||||
## Run it! |
||||
|
||||
🆕 Now available as a [Stable Diffusion Web UI Extension](https://github.com/pharmapsychotic/clip-interrogator-ext)! 🆕 |
||||
|
||||
<br> |
||||
|
||||
Run Version 2 on Colab, HuggingFace, and Replicate! |
||||
|
||||
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pharmapsychotic/clip-interrogator/blob/main/clip_interrogator.ipynb) [![Generic badge](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue.svg)](https://huggingface.co/spaces/pharma/CLIP-Interrogator) [![Replicate](https://replicate.com/pharmapsychotic/clip-interrogator/badge)](https://replicate.com/pharmapsychotic/clip-interrogator) [![Lambda](https://img.shields.io/badge/%CE%BB-Lambda-blue)](https://cloud.lambdalabs.com/demos/ml/CLIP-Interrogator) |
||||
|
||||
<br> |
||||
|
||||
|
||||
Version 1 still available in Colab for comparing different CLIP models |
||||
|
||||
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pharmapsychotic/clip-interrogator/blob/v1/clip_interrogator.ipynb) |
||||
|
||||
|
||||
## About |
||||
|
||||
The **CLIP Interrogator** is a prompt engineering tool that combines OpenAI's [CLIP](https://openai.com/blog/clip/) and Salesforce's [BLIP](https://blog.salesforceairesearch.com/blip-bootstrapping-language-image-pretraining/) to optimize text prompts to match a given image. Use the resulting prompts with text-to-image models like [Stable Diffusion](https://github.com/CompVis/stable-diffusion) on [DreamStudio](https://beta.dreamstudio.ai/) to create cool art! |
||||
|
||||
|
||||
## Using as a library |
||||
|
||||
Create and activate a Python virtual environment |
||||
```bash |
||||
python3 -m venv ci_env |
||||
(for linux ) source ci_env/bin/activate |
||||
(for windows) .\ci_env\Scripts\activate |
||||
``` |
||||
|
||||
Install with PIP |
||||
``` |
||||
# install torch with GPU support for example: |
||||
pip3 install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu117 |
||||
|
||||
# install clip-interrogator |
||||
pip install clip-interrogator==0.5.4 |
||||
|
||||
# or for very latest WIP with BLIP2 support |
||||
#pip install clip-interrogator==0.6.0 |
||||
``` |
||||
|
||||
You can then use it in your script |
||||
```python |
||||
from PIL import Image |
||||
from clip_interrogator import Config, Interrogator |
||||
image = Image.open(image_path).convert('RGB') |
||||
ci = Interrogator(Config(clip_model_name="ViT-L-14/openai")) |
||||
print(ci.interrogate(image)) |
||||
``` |
||||
|
||||
CLIP Interrogator uses OpenCLIP which supports many different pretrained CLIP models. For the best prompts for |
||||
Stable Diffusion 1.X use `ViT-L-14/openai` for clip_model_name. For Stable Diffusion 2.0 use `ViT-H-14/laion2b_s32b_b79k` |
||||
|
||||
## Configuration |
||||
|
||||
The `Config` object lets you configure CLIP Interrogator's processing. |
||||
* `clip_model_name`: which of the OpenCLIP pretrained CLIP models to use |
||||
* `cache_path`: path where to save precomputed text embeddings |
||||
* `download_cache`: when True will download the precomputed embeddings from huggingface |
||||
* `chunk_size`: batch size for CLIP, use smaller for lower VRAM |
||||
* `quiet`: when True no progress bars or text output will be displayed |
||||
|
||||
On systems with low VRAM you can call `config.apply_low_vram_defaults()` to reduce the amount of VRAM needed (at the cost of some speed and quality). The default settings use about 6.3GB of VRAM and the low VRAM settings use about 2.7GB. |
||||
|
||||
See the [run_cli.py](https://github.com/pharmapsychotic/clip-interrogator/blob/main/run_cli.py) and [run_gradio.py](https://github.com/pharmapsychotic/clip-interrogator/blob/main/run_gradio.py) for more examples on using Config and Interrogator classes. |
||||
|
||||
|
||||
## Ranking against your own list of terms (requires version 0.6.0) |
||||
|
||||
```python |
||||
from clip_interrogator import Config, Interrogator, LabelTable, load_list |
||||
from PIL import Image |
||||
|
||||
ci = Interrogator(Config(blip_model_type=None)) |
||||
image = Image.open(image_path).convert('RGB') |
||||
table = LabelTable(load_list('terms.txt'), 'terms', ci) |
||||
best_match = table.rank(ci.image_to_features(image), top_count=1)[0] |
||||
print(best_match) |
||||
``` |
||||
|
File diff suppressed because one or more lines are too long
@ -0,0 +1,4 @@
|
||||
from .clip_interrogator import Config, Interrogator, LabelTable, list_caption_models, list_clip_models, load_list |
||||
|
||||
__version__ = '0.6.0' |
||||
__author__ = 'pharmapsychotic' |
@ -0,0 +1,450 @@
|
||||
import hashlib |
||||
import math |
||||
import numpy as np |
||||
import open_clip |
||||
import os |
||||
import requests |
||||
import time |
||||
import torch |
||||
|
||||
from dataclasses import dataclass |
||||
from PIL import Image |
||||
from transformers import AutoProcessor, AutoModelForCausalLM, BlipForConditionalGeneration, Blip2ForConditionalGeneration |
||||
from tqdm import tqdm |
||||
from typing import List, Optional |
||||
|
||||
from safetensors.numpy import load_file, save_file |
||||
|
||||
CAPTION_MODELS = { |
||||
'blip-base': 'Salesforce/blip-image-captioning-base', # 990MB |
||||
'blip-large': 'Salesforce/blip-image-captioning-large', # 1.9GB |
||||
'blip2-2.7b': 'Salesforce/blip2-opt-2.7b', # 15.5GB |
||||
'blip2-flan-t5-xl': 'Salesforce/blip2-flan-t5-xl', # 15.77GB |
||||
'git-large-coco': 'microsoft/git-large-coco', # 1.58GB |
||||
} |
||||
|
||||
CACHE_URL_BASE = 'https://huggingface.co/pharmapsychotic/ci-preprocess/resolve/main/' |
||||
|
||||
|
||||
@dataclass |
||||
class Config: |
||||
# models can optionally be passed in directly |
||||
caption_model = None |
||||
caption_processor = None |
||||
clip_model = None |
||||
clip_preprocess = None |
||||
|
||||
# blip settings |
||||
caption_max_length: int = 32 |
||||
caption_model_name: Optional[str] = 'blip-large' # use a key from CAPTION_MODELS or None |
||||
caption_offload: bool = False |
||||
|
||||
# clip settings |
||||
clip_model_name: str = 'ViT-L-14/openai' |
||||
clip_model_path: Optional[str] = None |
||||
clip_offload: bool = False |
||||
|
||||
# interrogator settings |
||||
cache_path: str = 'cache' # path to store cached text embeddings |
||||
download_cache: bool = True # when true, cached embeds are downloaded from huggingface |
||||
chunk_size: int = 2048 # batch size for CLIP, use smaller for lower VRAM |
||||
data_path: str = os.path.join(os.path.dirname(__file__), 'data') |
||||
device: str = ("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu") |
||||
flavor_intermediate_count: int = 2048 |
||||
quiet: bool = False # when quiet progress bars are not shown |
||||
|
||||
def apply_low_vram_defaults(self): |
||||
self.caption_model_name = 'blip-base' |
||||
self.caption_offload = True |
||||
self.clip_offload = True |
||||
self.chunk_size = 1024 |
||||
self.flavor_intermediate_count = 1024 |
||||
|
||||
class Interrogator(): |
||||
def __init__(self, config: Config): |
||||
self.config = config |
||||
self.device = config.device |
||||
self.dtype = torch.float16 if self.device == 'cuda' else torch.float32 |
||||
self.caption_offloaded = True |
||||
self.clip_offloaded = True |
||||
self.load_caption_model() |
||||
self.load_clip_model() |
||||
|
||||
def load_caption_model(self): |
||||
if self.config.caption_model is None and self.config.caption_model_name: |
||||
if not self.config.quiet: |
||||
print(f"Loading caption model {self.config.caption_model_name}...") |
||||
|
||||
model_path = CAPTION_MODELS[self.config.caption_model_name] |
||||
if self.config.caption_model_name.startswith('git-'): |
||||
caption_model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float32) |
||||
elif self.config.caption_model_name.startswith('blip2-'): |
||||
caption_model = Blip2ForConditionalGeneration.from_pretrained(model_path, torch_dtype=self.dtype) |
||||
else: |
||||
caption_model = BlipForConditionalGeneration.from_pretrained(model_path, torch_dtype=self.dtype) |
||||
self.caption_processor = AutoProcessor.from_pretrained(model_path) |
||||
|
||||
caption_model.eval() |
||||
if not self.config.caption_offload: |
||||
caption_model = caption_model.to(self.config.device) |
||||
self.caption_model = caption_model |
||||
else: |
||||
self.caption_model = self.config.caption_model |
||||
self.caption_processor = self.config.caption_processor |
||||
|
||||
def load_clip_model(self): |
||||
start_time = time.time() |
||||
config = self.config |
||||
|
||||
clip_model_name, clip_model_pretrained_name = config.clip_model_name.split('/', 2) |
||||
|
||||
if config.clip_model is None: |
||||
if not config.quiet: |
||||
print(f"Loading CLIP model {config.clip_model_name}...") |
||||
|
||||
self.clip_model, _, self.clip_preprocess = open_clip.create_model_and_transforms( |
||||
clip_model_name, |
||||
pretrained=clip_model_pretrained_name, |
||||
precision='fp16' if config.device == 'cuda' else 'fp32', |
||||
device=config.device, |
||||
jit=False, |
||||
cache_dir=config.clip_model_path |
||||
) |
||||
self.clip_model.eval() |
||||
else: |
||||
self.clip_model = config.clip_model |
||||
self.clip_preprocess = config.clip_preprocess |
||||
self.tokenize = open_clip.get_tokenizer(clip_model_name) |
||||
|
||||
sites = ['Artstation', 'behance', 'cg society', 'cgsociety', 'deviantart', 'dribbble', |
||||
'flickr', 'instagram', 'pexels', 'pinterest', 'pixabay', 'pixiv', 'polycount', |
||||
'reddit', 'shutterstock', 'tumblr', 'unsplash', 'zbrush central'] |
||||
trending_list = [site for site in sites] |
||||
trending_list.extend(["trending on "+site for site in sites]) |
||||
trending_list.extend(["featured on "+site for site in sites]) |
||||
trending_list.extend([site+" contest winner" for site in sites]) |
||||
|
||||
raw_artists = load_list(config.data_path, 'artists.txt') |
||||
artists = [f"by {a}" for a in raw_artists] |
||||
artists.extend([f"inspired by {a}" for a in raw_artists]) |
||||
|
||||
self._prepare_clip() |
||||
self.artists = LabelTable(artists, "artists", self) |
||||
self.flavors = LabelTable(load_list(config.data_path, 'flavors.txt'), "flavors", self) |
||||
self.mediums = LabelTable(load_list(config.data_path, 'mediums.txt'), "mediums", self) |
||||
self.movements = LabelTable(load_list(config.data_path, 'movements.txt'), "movements", self) |
||||
self.trendings = LabelTable(trending_list, "trendings", self) |
||||
self.negative = LabelTable(load_list(config.data_path, 'negative.txt'), "negative", self) |
||||
|
||||
end_time = time.time() |
||||
if not config.quiet: |
||||
print(f"Loaded CLIP model and data in {end_time-start_time:.2f} seconds.") |
||||
|
||||
def chain( |
||||
self, |
||||
image_features: torch.Tensor, |
||||
phrases: List[str], |
||||
best_prompt: str="", |
||||
best_sim: float=0, |
||||
min_count: int=8, |
||||
max_count: int=32, |
||||
desc="Chaining", |
||||
reverse: bool=False |
||||
) -> str: |
||||
self._prepare_clip() |
||||
|
||||
phrases = set(phrases) |
||||
if not best_prompt: |
||||
best_prompt = self.rank_top(image_features, [f for f in phrases], reverse=reverse) |
||||
best_sim = self.similarity(image_features, best_prompt) |
||||
phrases.remove(best_prompt) |
||||
curr_prompt, curr_sim = best_prompt, best_sim |
||||
|
||||
def check(addition: str, idx: int) -> bool: |
||||
nonlocal best_prompt, best_sim, curr_prompt, curr_sim |
||||
prompt = curr_prompt + ", " + addition |
||||
sim = self.similarity(image_features, prompt) |
||||
if reverse: |
||||
sim = -sim |
||||
|
||||
if sim > best_sim: |
||||
best_prompt, best_sim = prompt, sim |
||||
if sim > curr_sim or idx < min_count: |
||||
curr_prompt, curr_sim = prompt, sim |
||||
return True |
||||
return False |
||||
|
||||
for idx in tqdm(range(max_count), desc=desc, disable=self.config.quiet): |
||||
best = self.rank_top(image_features, [f"{curr_prompt}, {f}" for f in phrases], reverse=reverse) |
||||
flave = best[len(curr_prompt)+2:] |
||||
if not check(flave, idx): |
||||
break |
||||
if _prompt_at_max_len(curr_prompt, self.tokenize): |
||||
break |
||||
phrases.remove(flave) |
||||
|
||||
return best_prompt |
||||
|
||||
def generate_caption(self, pil_image: Image) -> str: |
||||
assert self.caption_model is not None, "No caption model loaded." |
||||
self._prepare_caption() |
||||
inputs = self.caption_processor(images=pil_image, return_tensors="pt").to(self.device) |
||||
if not self.config.caption_model_name.startswith('git-'): |
||||
inputs = inputs.to(self.dtype) |
||||
tokens = self.caption_model.generate(**inputs, max_new_tokens=self.config.caption_max_length) |
||||
return self.caption_processor.batch_decode(tokens, skip_special_tokens=True)[0].strip() |
||||
|
||||
def image_to_features(self, image: Image) -> torch.Tensor: |
||||
self._prepare_clip() |
||||
images = self.clip_preprocess(image).unsqueeze(0).to(self.device) |
||||
with torch.no_grad(), torch.cuda.amp.autocast(): |
||||
image_features = self.clip_model.encode_image(images) |
||||
image_features /= image_features.norm(dim=-1, keepdim=True) |
||||
return image_features |
||||
|
||||
def interrogate_classic(self, image: Image, max_flavors: int=3, caption: Optional[str]=None) -> str: |
||||
"""Classic mode creates a prompt in a standard format first describing the image, |
||||
then listing the artist, trending, movement, and flavor text modifiers.""" |
||||
caption = caption or self.generate_caption(image) |
||||
image_features = self.image_to_features(image) |
||||
|
||||
medium = self.mediums.rank(image_features, 1)[0] |
||||
artist = self.artists.rank(image_features, 1)[0] |
||||
trending = self.trendings.rank(image_features, 1)[0] |
||||
movement = self.movements.rank(image_features, 1)[0] |
||||
flaves = ", ".join(self.flavors.rank(image_features, max_flavors)) |
||||
|
||||
if caption.startswith(medium): |
||||
prompt = f"{caption} {artist}, {trending}, {movement}, {flaves}" |
||||
else: |
||||
prompt = f"{caption}, {medium} {artist}, {trending}, {movement}, {flaves}" |
||||
|
||||
return _truncate_to_fit(prompt, self.tokenize) |
||||
|
||||
def interrogate_fast(self, image: Image, max_flavors: int=32, caption: Optional[str]=None) -> str: |
||||
"""Fast mode simply adds the top ranked terms after a caption. It generally results in |
||||
better similarity between generated prompt and image than classic mode, but the prompts |
||||
are less readable.""" |
||||
caption = caption or self.generate_caption(image) |
||||
image_features = self.image_to_features(image) |
||||
merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self) |
||||
tops = merged.rank(image_features, max_flavors) |
||||
return _truncate_to_fit(caption + ", " + ", ".join(tops), self.tokenize) |
||||
|
||||
def interrogate_negative(self, image: Image, max_flavors: int = 32) -> str: |
||||
"""Negative mode chains together the most dissimilar terms to the image. It can be used |
||||
to help build a negative prompt to pair with the regular positive prompt and often |
||||
improve the results of generated images particularly with Stable Diffusion 2.""" |
||||
image_features = self.image_to_features(image) |
||||
flaves = self.flavors.rank(image_features, self.config.flavor_intermediate_count, reverse=True) |
||||
flaves = flaves + self.negative.labels |
||||
return self.chain(image_features, flaves, max_count=max_flavors, reverse=True, desc="Negative chain") |
||||
|
||||
def interrogate(self, image: Image, min_flavors: int=8, max_flavors: int=32, caption: Optional[str]=None) -> str: |
||||
caption = caption or self.generate_caption(image) |
||||
image_features = self.image_to_features(image) |
||||
|
||||
merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self) |
||||
flaves = merged.rank(image_features, self.config.flavor_intermediate_count) |
||||
best_prompt, best_sim = caption, self.similarity(image_features, caption) |
||||
best_prompt = self.chain(image_features, flaves, best_prompt, best_sim, min_count=min_flavors, max_count=max_flavors, desc="Flavor chain") |
||||
|
||||
fast_prompt = self.interrogate_fast(image, max_flavors, caption=caption) |
||||
classic_prompt = self.interrogate_classic(image, max_flavors, caption=caption) |
||||
candidates = [caption, classic_prompt, fast_prompt, best_prompt] |
||||
return candidates[np.argmax(self.similarities(image_features, candidates))] |
||||
|
||||
def rank_top(self, image_features: torch.Tensor, text_array: List[str], reverse: bool=False) -> str: |
||||
self._prepare_clip() |
||||
text_tokens = self.tokenize([text for text in text_array]).to(self.device) |
||||
with torch.no_grad(), torch.cuda.amp.autocast(): |
||||
text_features = self.clip_model.encode_text(text_tokens) |
||||
text_features /= text_features.norm(dim=-1, keepdim=True) |
||||
similarity = text_features @ image_features.T |
||||
if reverse: |
||||
similarity = -similarity |
||||
return text_array[similarity.argmax().item()] |
||||
|
||||
def similarity(self, image_features: torch.Tensor, text: str) -> float: |
||||
self._prepare_clip() |
||||
text_tokens = self.tokenize([text]).to(self.device) |
||||
with torch.no_grad(), torch.cuda.amp.autocast(): |
||||
text_features = self.clip_model.encode_text(text_tokens) |
||||
text_features /= text_features.norm(dim=-1, keepdim=True) |
||||
similarity = text_features @ image_features.T |
||||
return similarity[0][0].item() |
||||
|
||||
def similarities(self, image_features: torch.Tensor, text_array: List[str]) -> List[float]: |
||||
self._prepare_clip() |
||||
text_tokens = self.tokenize([text for text in text_array]).to(self.device) |
||||
with torch.no_grad(), torch.cuda.amp.autocast(): |
||||
text_features = self.clip_model.encode_text(text_tokens) |
||||
text_features /= text_features.norm(dim=-1, keepdim=True) |
||||
similarity = text_features @ image_features.T |
||||
return similarity.T[0].tolist() |
||||
|
||||
def _prepare_caption(self): |
||||
if self.config.clip_offload and not self.clip_offloaded: |
||||
self.clip_model = self.clip_model.to('cpu') |
||||
self.clip_offloaded = True |
||||
if self.caption_offloaded: |
||||
self.caption_model = self.caption_model.to(self.device) |
||||
self.caption_offloaded = False |
||||
|
||||
def _prepare_clip(self): |
||||
if self.config.caption_offload and not self.caption_offloaded: |
||||
self.caption_model = self.caption_model.to('cpu') |
||||
self.caption_offloaded = True |
||||
if self.clip_offloaded: |
||||
self.clip_model = self.clip_model.to(self.device) |
||||
self.clip_offloaded = False |
||||
|
||||
|
||||
class LabelTable(): |
||||
def __init__(self, labels:List[str], desc:str, ci: Interrogator): |
||||
clip_model, config = ci.clip_model, ci.config |
||||
self.chunk_size = config.chunk_size |
||||
self.config = config |
||||
self.device = config.device |
||||
self.embeds = [] |
||||
self.labels = labels |
||||
self.tokenize = ci.tokenize |
||||
|
||||
hash = hashlib.sha256(",".join(labels).encode()).hexdigest() |
||||
sanitized_name = self.config.clip_model_name.replace('/', '_').replace('@', '_') |
||||
self._load_cached(desc, hash, sanitized_name) |
||||
|
||||
if len(self.labels) != len(self.embeds): |
||||
self.embeds = [] |
||||
chunks = np.array_split(self.labels, max(1, len(self.labels)/config.chunk_size)) |
||||
for chunk in tqdm(chunks, desc=f"Preprocessing {desc}" if desc else None, disable=self.config.quiet): |
||||
text_tokens = self.tokenize(chunk).to(self.device) |
||||
with torch.no_grad(), torch.cuda.amp.autocast(): |
||||
text_features = clip_model.encode_text(text_tokens) |
||||
text_features /= text_features.norm(dim=-1, keepdim=True) |
||||
text_features = text_features.half().cpu().numpy() |
||||
for i in range(text_features.shape[0]): |
||||
self.embeds.append(text_features[i]) |
||||
|
||||
if desc and self.config.cache_path: |
||||
os.makedirs(self.config.cache_path, exist_ok=True) |
||||
cache_filepath = os.path.join(self.config.cache_path, f"{sanitized_name}_{desc}.safetensors") |
||||
tensors = { |
||||
"embeds": np.stack(self.embeds), |
||||
"hash": np.array([ord(c) for c in hash], dtype=np.int8) |
||||
} |
||||
save_file(tensors, cache_filepath) |
||||
|
||||
if self.device == 'cpu' or self.device == torch.device('cpu'): |
||||
self.embeds = [e.astype(np.float32) for e in self.embeds] |
||||
|
||||
def _load_cached(self, desc:str, hash:str, sanitized_name:str) -> bool: |
||||
if self.config.cache_path is None or desc is None: |
||||
return False |
||||
|
||||
cached_safetensors = os.path.join(self.config.cache_path, f"{sanitized_name}_{desc}.safetensors") |
||||
|
||||
if self.config.download_cache and not os.path.exists(cached_safetensors): |
||||
download_url = CACHE_URL_BASE + f"{sanitized_name}_{desc}.safetensors" |
||||
try: |
||||
os.makedirs(self.config.cache_path, exist_ok=True) |
||||
_download_file(download_url, cached_safetensors, quiet=self.config.quiet) |
||||
except Exception as e: |
||||
print(f"Failed to download {download_url}") |
||||
print(e) |
||||
return False |
||||
|
||||
if os.path.exists(cached_safetensors): |
||||
try: |
||||
tensors = load_file(cached_safetensors) |
||||
except Exception as e: |
||||
print(f"Failed to load {cached_safetensors}") |
||||
print(e) |
||||
return False |
||||
if 'hash' in tensors and 'embeds' in tensors: |
||||
if np.array_equal(tensors['hash'], np.array([ord(c) for c in hash], dtype=np.int8)): |
||||
self.embeds = tensors['embeds'] |
||||
if len(self.embeds.shape) == 2: |
||||
self.embeds = [self.embeds[i] for i in range(self.embeds.shape[0])] |
||||
return True |
||||
|
||||
return False |
||||
|
||||
def _rank(self, image_features: torch.Tensor, text_embeds: torch.Tensor, top_count: int=1, reverse: bool=False) -> str: |
||||
top_count = min(top_count, len(text_embeds)) |
||||
text_embeds = torch.stack([torch.from_numpy(t) for t in text_embeds]).to(self.device) |
||||
with torch.cuda.amp.autocast(): |
||||
similarity = image_features @ text_embeds.T |
||||
if reverse: |
||||
similarity = -similarity |
||||
_, top_labels = similarity.float().cpu().topk(top_count, dim=-1) |
||||
return [top_labels[0][i].numpy() for i in range(top_count)] |
||||
|
||||
def rank(self, image_features: torch.Tensor, top_count: int=1, reverse: bool=False) -> List[str]: |
||||
if len(self.labels) <= self.chunk_size: |
||||
tops = self._rank(image_features, self.embeds, top_count=top_count, reverse=reverse) |
||||
return [self.labels[i] for i in tops] |
||||
|
||||
num_chunks = int(math.ceil(len(self.labels)/self.chunk_size)) |
||||
keep_per_chunk = int(self.chunk_size / num_chunks) |
||||
|
||||
top_labels, top_embeds = [], [] |
||||
for chunk_idx in tqdm(range(num_chunks), disable=self.config.quiet): |
||||
start = chunk_idx*self.chunk_size |
||||
stop = min(start+self.chunk_size, len(self.embeds)) |
||||
tops = self._rank(image_features, self.embeds[start:stop], top_count=keep_per_chunk, reverse=reverse) |
||||
top_labels.extend([self.labels[start+i] for i in tops]) |
||||
top_embeds.extend([self.embeds[start+i] for i in tops]) |
||||
|
||||
tops = self._rank(image_features, top_embeds, top_count=top_count) |
||||
return [top_labels[i] for i in tops] |
||||
|
||||
|
||||
def _download_file(url: str, filepath: str, chunk_size: int = 4*1024*1024, quiet: bool = False): |
||||
r = requests.get(url, stream=True) |
||||
if r.status_code != 200: |
||||
return |
||||
|
||||
file_size = int(r.headers.get("Content-Length", 0)) |
||||
filename = url.split("/")[-1] |
||||
progress = tqdm(total=file_size, unit="B", unit_scale=True, desc=filename, disable=quiet) |
||||
with open(filepath, "wb") as f: |
||||
for chunk in r.iter_content(chunk_size=chunk_size): |
||||
if chunk: |
||||
f.write(chunk) |
||||
progress.update(len(chunk)) |
||||
progress.close() |
||||
|
||||
def _merge_tables(tables: List[LabelTable], ci: Interrogator) -> LabelTable: |
||||
m = LabelTable([], None, ci) |
||||
for table in tables: |
||||
m.labels.extend(table.labels) |
||||
m.embeds.extend(table.embeds) |
||||
return m |
||||
|
||||
def _prompt_at_max_len(text: str, tokenize) -> bool: |
||||
tokens = tokenize([text]) |
||||
return tokens[0][-1] != 0 |
||||
|
||||
def _truncate_to_fit(text: str, tokenize) -> str: |
||||
parts = text.split(', ') |
||||
new_text = parts[0] |
||||
for part in parts[1:]: |
||||
if _prompt_at_max_len(new_text + part, tokenize): |
||||
break |
||||
new_text += ', ' + part |
||||
return new_text |
||||
|
||||
def list_caption_models() -> List[str]: |
||||
return list(CAPTION_MODELS.keys()) |
||||
|
||||
def list_clip_models() -> List[str]: |
||||
return ['/'.join(x) for x in open_clip.list_pretrained()] |
||||
|
||||
def load_list(data_path: str, filename: Optional[str] = None) -> List[str]: |
||||
"""Load a list of strings from a file.""" |
||||
if filename is not None: |
||||
data_path = os.path.join(data_path, filename) |
||||
with open(data_path, 'r', encoding='utf-8', errors='replace') as f: |
||||
items = [line.strip() for line in f.readlines()] |
||||
return items |
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,41 @@
|
||||
3d |
||||
b&w |
||||
bad anatomy |
||||
bad art |
||||
blur |
||||
blurry |
||||
cartoon |
||||
childish |
||||
close up |
||||
deformed |
||||
disconnected limbs |
||||
disfigured |
||||
disgusting |
||||
extra limb |
||||
extra limbs |
||||
floating limbs |
||||
grain |
||||
illustration |
||||
kitsch |
||||
long body |
||||
long neck |
||||
low quality |
||||
low-res |
||||
malformed hands |
||||
mangled |
||||
missing limb |
||||
mutated |
||||
mutation |
||||
mutilated |
||||
noisy |
||||
old |
||||
out of focus |
||||
over saturation |
||||
oversaturated |
||||
poorly drawn |
||||
poorly drawn face |
||||
poorly drawn hands |
||||
render |
||||
surreal |
||||
ugly |
||||
weird colors |
@ -0,0 +1,16 @@
|
||||
build: |
||||
gpu: true |
||||
cuda: "11.8" |
||||
python_version: "3.10" |
||||
system_packages: |
||||
- "libgl1-mesa-glx" |
||||
- "libglib2.0-0" |
||||
python_packages: |
||||
- "Pillow==10.0.0" |
||||
- "safetensors==0.3.3" |
||||
- "tqdm==4.66.1" |
||||
- "open_clip_torch==2.20.0" |
||||
- "accelerate==0.22.0" |
||||
- "transformers==4.33.1" |
||||
|
||||
predict: "predict.py:Predictor" |
@ -1,397 +0,0 @@
|
||||
#film |
||||
#myportfolio |
||||
#pixelart |
||||
#screenshotsaturday |
||||
#vfxfriday |
||||
1920s |
||||
1970s |
||||
1990s |
||||
20 megapixels |
||||
2d |
||||
2d game art |
||||
32k uhd |
||||
35mm lens |
||||
3840x2160 |
||||
3d |
||||
4k |
||||
8k |
||||
8k 3d |
||||
8k resolution |
||||
I can't believe how beautiful this is |
||||
academic art |
||||
acrylic art |
||||
adafruit |
||||
aesthetic |
||||
aftereffects |
||||
airbrush art |
||||
ambient occlusion |
||||
ambrotype |
||||
american propaganda |
||||
anaglyph effect |
||||
anaglyph filter |
||||
anamorphic lens flare |
||||
androgynous |
||||
angelic photograph |
||||
angular |
||||
anime |
||||
anime aesthetic |
||||
antichrist |
||||
apocalypse art |
||||
apocalypse landscape |
||||
art |
||||
art deco |
||||
art on instagram |
||||
artstation hd |
||||
artstation hq |
||||
artwork |
||||
associated press photo |
||||
atmospheric |
||||
award winning |
||||
award-winning |
||||
backlight |
||||
beautiful |
||||
behance hd |
||||
bioluminescence |
||||
biomorphic |
||||
black and white |
||||
black background |
||||
blueprint |
||||
bob ross |
||||
bokeh |
||||
booru |
||||
bryce 3d |
||||
calotype |
||||
chalk art |
||||
character |
||||
charcoal drawing |
||||
chiaroscuro |
||||
childs drawing |
||||
chillwave |
||||
chromatic |
||||
cinematic |
||||
cinematic lighting |
||||
cinematic view |
||||
circuitry |
||||
cityscape |
||||
clean |
||||
close up |
||||
cluttered |
||||
colorful |
||||
colorized |
||||
commission for |
||||
complementary colors |
||||
concept art |
||||
concert poster |
||||
congruent |
||||
constructivism |
||||
contest winner |
||||
contrasting |
||||
cosmic horror |
||||
creative commons attribution |
||||
creepypasta |
||||
criterion collection |
||||
cryengine |
||||
cubism |
||||
cyanotype |
||||
d&d |
||||
da vinci |
||||
dark |
||||
dark and mysterious |
||||
darksynth |
||||
datamosh |
||||
daz3d |
||||
dc comics |
||||
demonic photograph |
||||
depth of field |
||||
destructive |
||||
detailed |
||||
detailed painting |
||||
deviantart |
||||
deviantart hd |
||||
digital illustration |
||||
digital painting |
||||
digitally enhanced |
||||
diorama |
||||
dramatic |
||||
dramatic lighting |
||||
dslr |
||||
dslr camera |
||||
dutch golden age |
||||
dye-transfer |
||||
dynamic composition |
||||
dynamic pose |
||||
dystopian art |
||||
egyptian art |
||||
elegant |
||||
elite |
||||
enchanting |
||||
epic |
||||
ethereal |
||||
extremely gendered |
||||
fantasy |
||||
fauvism |
||||
feminine |
||||
film grain |
||||
filmic |
||||
fine art |
||||
fisheye lens |
||||
flat colors |
||||
flat shading |
||||
flemish baroque |
||||
flickering light |
||||
flickr |
||||
fractalism |
||||
freakshow |
||||
fresco |
||||
full body |
||||
full of details |
||||
furaffinity |
||||
future tech |
||||
futuristic |
||||
genderless |
||||
geometric |
||||
glitch art |
||||
glitchy |
||||
glitter |
||||
global illumination |
||||
glorious |
||||
glowing lights |
||||
glowing neon |
||||
god rays |
||||
golden ratio |
||||
goth |
||||
gothic |
||||
greeble |
||||
groovy |
||||
grotesque |
||||
hall of mirrors |
||||
handsome |
||||
hard surface modeling |
||||
hd |
||||
hd mod |
||||
hdr |
||||
hellish |
||||
hellish background |
||||
henry moore |
||||
high contrast |
||||
high definition |
||||
high detail |
||||
high detailed |
||||
high dynamic range |
||||
high quality |
||||
high quality photo |
||||
high resolution |
||||
holographic |
||||
horror film |
||||
hyper realism |
||||
hyper-realistic |
||||
hypnotic |
||||
ilford hp5 |
||||
ilya kuvshinov |
||||
imax |
||||
impressionism |
||||
infrared |
||||
ink drawing |
||||
inspirational |
||||
instax |
||||
intricate |
||||
intricate patterns |
||||
iridescent |
||||
irridescent |
||||
iso 200 |
||||
isometric |
||||
kinetic |
||||
kodak ektar |
||||
kodak gold 200 |
||||
kodak portra |
||||
lighthearted |
||||
logo |
||||
lomo |
||||
long exposure |
||||
long lens |
||||
lovecraftian |
||||
lovely |
||||
low contrast |
||||
low poly |
||||
lowbrow |
||||
luminescence |
||||
macabre |
||||
macro lens |
||||
macro photography |
||||
made of all of the above |
||||
made of beads and yarn |
||||
made of cardboard |
||||
made of cheese |
||||
made of crystals |
||||
made of feathers |
||||
made of flowers |
||||
made of glass |
||||
made of insects |
||||
made of liquid metal |
||||
made of mist |
||||
made of paperclips |
||||
made of plastic |
||||
made of rubber |
||||
made of trash |
||||
made of vines |
||||
made of wire |
||||
made of wrought iron |
||||
majestic |
||||
marble sculpture |
||||
marvel comics |
||||
masculine |
||||
masterpiece |
||||
matte background |
||||
matte drawing |
||||
matte painting |
||||
matte photo |
||||
maximalist |
||||
messy |
||||
minimalist |
||||
minimalistic |
||||
mist |
||||
mixed media |
||||
movie poster |
||||
movie still |
||||
multiple exposure |
||||
muted |
||||
mystical |
||||
national geographic photo |
||||
neon |
||||
nightmare |
||||
nightscape |
||||
octane render |
||||
official art |
||||
oil on canvas |
||||
ominous |
||||
ominous vibe |
||||
ornate |
||||
orthogonal |
||||
outlined art |
||||
outrun |
||||
painterly |
||||
panorama |
||||
parallax |
||||
pencil sketch |
||||
phallic |
||||
photo |
||||
photo taken with ektachrome |
||||
photo taken with fujifilm superia |
||||
photo taken with nikon d750 |
||||
photo taken with provia |
||||
photocollage |
||||
photocopy |
||||
photoillustration |
||||
photorealistic |
||||
physically based rendering |
||||
picasso |
||||
pixel perfect |
||||
pixiv |
||||
playstation 5 screenshot |
||||
polished |
||||
polycount |
||||
pop art |
||||
post processing |
||||
poster art |
||||
pre-raphaelite |
||||
prerendered graphics |
||||
pretty |
||||
provia |
||||
ps1 graphics |
||||
psychedelic |
||||
quantum wavetracing |
||||
ray tracing |
||||
realism |
||||
redshift |
||||
reimagined by industrial light and magic |
||||
renaissance painting |
||||
rendered in cinema4d |
||||
rendered in maya |
||||
rendered in unreal engine |
||||
repeating pattern |
||||
retrowave |
||||
rich color palette |
||||
rim light |
||||
rococo |
||||
rough |
||||
rtx |
||||
rtx on |
||||
sabattier effect |
||||
sabattier filter |
||||
sanctuary |
||||
sci-fi |
||||
seapunk |
||||
sense of awe |
||||
sensual |
||||
shallow depth of field |
||||
sharp focus |
||||
shiny |
||||
shiny eyes |
||||
shot on 70mm |
||||
sketchfab |
||||
skeuomorphic |
||||
smokey background |
||||
smooth |
||||
soft light |
||||
soft mist |
||||
soviet propaganda |
||||
speedpainting |
||||
stained glass |
||||
steampunk |
||||
stipple |
||||
stock photo |
||||
stockphoto |
||||
storybook illustration |
||||
strange |
||||
streetscape |
||||
studio light |
||||
studio lighting |
||||
studio photography |
||||
studio portrait |
||||
stylish |
||||
sunrays shine upon it |
||||
surrealist |
||||
symmetrical |
||||
synthwave |
||||
tarot card |
||||
tattoo |
||||
telephoto lens |
||||
terragen |
||||
tesseract |
||||
thx sound |
||||
tilt shift |
||||
tintype photograph |
||||
toonami |
||||
trance compilation cd |
||||
trypophobia |
||||
ue5 |
||||
uhd image |
||||
ukiyo-e |
||||
ultra detailed |
||||
ultra hd |
||||
ultra realistic |
||||
ultrafine detail |
||||
unreal engine |
||||
unreal engine 5 |
||||
vaporwave |
||||
velvia |
||||
vibrant colors |
||||
vivid colors |
||||
volumetric lighting |
||||
voxel art |
||||
vray |
||||
vray tracing |
||||
wallpaper |
||||
watercolor |
||||
wavy |
||||
whimsical |
||||
white background |
||||
wiccan |
||||
wide lens |
||||
wimmelbilder |
||||
windows vista |
||||
windows xp |
||||
woodcut |
||||
xbox 360 graphics |
||||
y2k aesthetic |
||||
zbrush |
@ -0,0 +1,45 @@
|
||||
import sys |
||||
from PIL import Image |
||||
from cog import BasePredictor, Input, Path |
||||
|
||||
from clip_interrogator import Config, Interrogator |
||||
|
||||
|
||||
class Predictor(BasePredictor): |
||||
def setup(self): |
||||
self.ci = Interrogator(Config( |
||||
clip_model_name="ViT-L-14/openai", |
||||
clip_model_path='cache', |
||||
device='cuda:0', |
||||
)) |
||||
|
||||
def predict( |
||||
self, |
||||
image: Path = Input(description="Input image"), |
||||
clip_model_name: str = Input( |
||||
default="ViT-L-14/openai", |
||||
choices=["ViT-L-14/openai", "ViT-H-14/laion2b_s32b_b79k", "ViT-bigG-14/laion2b_s39b_b160k"], |
||||
description="Choose ViT-L for Stable Diffusion 1, ViT-H for Stable Diffusion 2, or ViT-bigG for Stable Diffusion XL.", |
||||
), |
||||
mode: str = Input( |
||||
default="best", |
||||
choices=["best", "classic", "fast", "negative"], |
||||
description="Prompt mode (best takes 10-20 seconds, fast takes 1-2 seconds).", |
||||
), |
||||
) -> str: |
||||
"""Run a single prediction on the model""" |
||||
image = Image.open(str(image)).convert("RGB") |
||||
self.switch_model(clip_model_name) |
||||
if mode == 'best': |
||||
return self.ci.interrogate(image) |
||||
elif mode == 'classic': |
||||
return self.ci.interrogate_classic(image) |
||||
elif mode == 'fast': |
||||
return self.ci.interrogate_fast(image) |
||||
elif mode == 'negative': |
||||
return self.ci.interrogate_negative(image) |
||||
|
||||
def switch_model(self, clip_model_name: str): |
||||
if clip_model_name != self.ci.config.clip_model_name: |
||||
self.ci.config.clip_model_name = clip_model_name |
||||
self.ci.load_clip_model() |
@ -0,0 +1,3 @@
|
||||
[build-system] |
||||
requires = ["setuptools"] |
||||
build-backend = "setuptools.build_meta" |
@ -0,0 +1,9 @@
|
||||
torch>=1.13.0 |
||||
torchvision |
||||
Pillow |
||||
requests |
||||
safetensors |
||||
tqdm |
||||
open_clip_torch |
||||
accelerate |
||||
transformers>=4.27.1 |
@ -0,0 +1,95 @@
|
||||
#!/usr/bin/env python3 |
||||
import argparse |
||||
import csv |
||||
import os |
||||
import requests |
||||
import torch |
||||
from PIL import Image |
||||
from clip_interrogator import Interrogator, Config, list_clip_models |
||||
|
||||
def inference(ci, image, mode): |
||||
image = image.convert('RGB') |
||||
if mode == 'best': |
||||
return ci.interrogate(image) |
||||
elif mode == 'classic': |
||||
return ci.interrogate_classic(image) |
||||
else: |
||||
return ci.interrogate_fast(image) |
||||
|
||||
def main(): |
||||
parser = argparse.ArgumentParser() |
||||
parser.add_argument('-c', '--clip', default='ViT-L-14/openai', help='name of CLIP model to use') |
||||
parser.add_argument('-d', '--device', default='auto', help='device to use (auto, cuda or cpu)') |
||||
parser.add_argument('-f', '--folder', help='path to folder of images') |
||||
parser.add_argument('-i', '--image', help='image file or url') |
||||
parser.add_argument('-m', '--mode', default='best', help='best, classic, or fast') |
||||
parser.add_argument("--lowvram", action='store_true', help="Optimize settings for low VRAM") |
||||
|
||||
args = parser.parse_args() |
||||
if not args.folder and not args.image: |
||||
parser.print_help() |
||||
exit(1) |
||||
|
||||
if args.folder is not None and args.image is not None: |
||||
print("Specify a folder or batch processing or a single image, not both") |
||||
exit(1) |
||||
|
||||
# validate clip model name |
||||
models = list_clip_models() |
||||
if args.clip not in models: |
||||
print(f"Could not find CLIP model {args.clip}!") |
||||
print(f" available models: {models}") |
||||
exit(1) |
||||
|
||||
# select device |
||||
if args.device == 'auto': |
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
||||
if not torch.cuda.is_available(): |
||||
print("CUDA is not available, using CPU. Warning: this will be very slow!") |
||||
else: |
||||
device = torch.device(args.device) |
||||
|
||||
# generate a nice prompt |
||||
config = Config(device=device, clip_model_name=args.clip) |
||||
if args.lowvram: |
||||
config.apply_low_vram_defaults() |
||||
ci = Interrogator(config) |
||||
|
||||
# process single image |
||||
if args.image is not None: |
||||
image_path = args.image |
||||
if str(image_path).startswith('http://') or str(image_path).startswith('https://'): |
||||
image = Image.open(requests.get(image_path, stream=True).raw).convert('RGB') |
||||
else: |
||||
image = Image.open(image_path).convert('RGB') |
||||
if not image: |
||||
print(f'Error opening image {image_path}') |
||||
exit(1) |
||||
print(inference(ci, image, args.mode)) |
||||
|
||||
# process folder of images |
||||
elif args.folder is not None: |
||||
if not os.path.exists(args.folder): |
||||
print(f'The folder {args.folder} does not exist!') |
||||
exit(1) |
||||
|
||||
files = [f for f in os.listdir(args.folder) if f.endswith('.jpg') or f.endswith('.png')] |
||||
prompts = [] |
||||
for file in files: |
||||
image = Image.open(os.path.join(args.folder, file)).convert('RGB') |
||||
prompt = inference(ci, image, args.mode) |
||||
prompts.append(prompt) |
||||
print(prompt) |
||||
|
||||
if len(prompts): |
||||
csv_path = os.path.join(args.folder, 'desc.csv') |
||||
with open(csv_path, 'w', encoding='utf-8', newline='') as f: |
||||
w = csv.writer(f, quoting=csv.QUOTE_MINIMAL) |
||||
w.writerow(['image', 'prompt']) |
||||
for file, prompt in zip(files, prompts): |
||||
w.writerow([file, prompt]) |
||||
|
||||
print(f"\n\n\n\nGenerated {len(prompts)} and saved to {csv_path}, enjoy!") |
||||
|
||||
if __name__ == "__main__": |
||||
main() |
@ -0,0 +1,34 @@
|
||||
import os |
||||
|
||||
import pkg_resources |
||||
from setuptools import setup, find_packages |
||||
|
||||
setup( |
||||
name="clip-interrogator", |
||||
version="0.6.0", |
||||
license='MIT', |
||||
author='pharmapsychotic', |
||||
author_email='me@pharmapsychotic.com', |
||||
url='https://github.com/pharmapsychotic/clip-interrogator', |
||||
description="Generate a prompt from an image", |
||||
long_description=open('README.md', encoding='utf-8').read(), |
||||
long_description_content_type="text/markdown", |
||||
packages=find_packages(), |
||||
install_requires=[ |
||||
str(r) |
||||
for r in pkg_resources.parse_requirements( |
||||
open(os.path.join(os.path.dirname(__file__), "requirements.txt")) |
||||
) |
||||
], |
||||
include_package_data=True, |
||||
extras_require={'dev': ['pytest']}, |
||||
classifiers=[ |
||||
'Intended Audience :: Developers', |
||||
'Intended Audience :: Science/Research', |
||||
'License :: OSI Approved :: MIT License', |
||||
'Topic :: Education', |
||||
'Topic :: Scientific/Engineering', |
||||
'Topic :: Scientific/Engineering :: Artificial Intelligence', |
||||
], |
||||
keywords=['blip','clip','prompt-engineering','stable-diffusion','text-to-image'], |
||||
) |
Loading…
Reference in new issue