Browse Source

Merge pull request #18 from pharmapsychotic/open-clip-switch

Switch to OpenCLIP to support ViT-H
pull/22/head
pharmapsychotic 2 years ago committed by GitHub
parent
commit
5c4872d1f7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 13
      README.md
  2. 46
      clip_interrogator.ipynb
  3. 134
      clip_interrogator/clip_interrogator.py
  4. 13
      cog.yaml
  5. 34
      predict.py
  6. 1
      requirements.txt
  7. 9
      run_cli.py
  8. 20
      run_gradio.py

13
README.md

@ -10,6 +10,12 @@ Run Version 2 on Colab, HuggingFace, and Replicate!
<br> <br>
For **Stable Diffusion 2.0** prompting use the `ViT-H` version:
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pharmapsychotic/clip-interrogator/blob/open-clip/clip_interrogator.ipynb) [![Generic badge](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue.svg)](https://huggingface.co/spaces/fffiloni/CLIP-Interrogator-2)
<br>
Version 1 still available in Colab for comparing different CLIP models Version 1 still available in Colab for comparing different CLIP models
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pharmapsychotic/clip-interrogator/blob/v1/clip_interrogator.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pharmapsychotic/clip-interrogator/blob/v1/clip_interrogator.ipynb)
@ -30,7 +36,6 @@ source ci_env/bin/activate
Install with PIP Install with PIP
``` ```
pip install -e git+https://github.com/openai/CLIP.git@main#egg=clip
pip install -e git+https://github.com/pharmapsychotic/BLIP.git@lib#egg=blip pip install -e git+https://github.com/pharmapsychotic/BLIP.git@lib#egg=blip
pip install clip-interrogator pip install clip-interrogator
``` ```
@ -40,6 +45,10 @@ You can then use it in your script
from PIL import Image from PIL import Image
from clip_interrogator import Interrogator, Config from clip_interrogator import Interrogator, Config
image = Image.open(image_path).convert('RGB') image = Image.open(image_path).convert('RGB')
ci = Interrogator(Config(clip_model_name="ViT-L/14")) ci = Interrogator(Config(clip_model_name="ViT-L-14/openai"))
print(ci.interrogate(image)) print(ci.interrogate(image))
``` ```
CLIP Interrogator uses OpenCLIP which supports many different pretrained CLIP models. For the best prompts for
Stable Diffusion 1.X use `ViT-L-14/openai` for clip_model_name. For Stable Diffusion 2.0 use `ViT-H-14/laion2b_s32b_b79k`

46
clip_interrogator.ipynb

@ -46,12 +46,12 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"#@title Setup\n", "#@title Setup\n",
"import subprocess\n", "import os, subprocess\n",
"\n", "\n",
"def setup():\n", "def setup():\n",
" install_cmds = [\n", " install_cmds = [\n",
" ['pip', 'install', 'ftfy', 'gradio', 'regex', 'tqdm', 'transformers==4.21.2', 'timm', 'fairscale', 'requests'],\n", " ['pip', 'install', 'ftfy', 'gradio', 'regex', 'tqdm', 'transformers==4.21.2', 'timm', 'fairscale', 'requests'],\n",
" ['pip', 'install', '-e', 'git+https://github.com/openai/CLIP.git@main#egg=clip'],\n", " ['pip', 'install', 'open_clip_torch'],\n",
" ['pip', 'install', '-e', 'git+https://github.com/pharmapsychotic/BLIP.git@lib#egg=blip'],\n", " ['pip', 'install', '-e', 'git+https://github.com/pharmapsychotic/BLIP.git@lib#egg=blip'],\n",
" ['git', 'clone', 'https://github.com/pharmapsychotic/clip-interrogator.git']\n", " ['git', 'clone', 'https://github.com/pharmapsychotic/clip-interrogator.git']\n",
" ]\n", " ]\n",
@ -60,20 +60,41 @@
"\n", "\n",
"setup()\n", "setup()\n",
"\n", "\n",
"# download cache files\n",
"print(\"Download preprocessed cache files...\")\n",
"CACHE_URLS = [\n",
" 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_artists.pkl',\n",
" 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_flavors.pkl',\n",
" 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_mediums.pkl',\n",
" 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_movements.pkl',\n",
" 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_trendings.pkl',\n",
"]\n",
"os.makedirs('cache', exist_ok=True)\n",
"for url in CACHE_URLS:\n",
" print(subprocess.run(['wget', url, '-P', 'cache'], stdout=subprocess.PIPE).stdout.decode('utf-8'))\n",
"\n",
"import sys\n", "import sys\n",
"sys.path.append('src/blip')\n", "sys.path.append('src/blip')\n",
"sys.path.append('src/clip')\n",
"sys.path.append('clip-interrogator')\n", "sys.path.append('clip-interrogator')\n",
"\n", "\n",
"import gradio as gr\n", "import gradio as gr\n",
"from clip_interrogator import Config, Interrogator\n", "from clip_interrogator import Config, Interrogator\n",
"\n", "\n",
"ci = Interrogator(Config())\n", "config = Config()\n",
"config.blip_num_beams = 64\n",
"config.blip_offload = False\n",
"config.chunk_size = 2048\n",
"config.flavor_intermediate_count = 2048\n",
"\n",
"ci = Interrogator(config)\n",
"\n", "\n",
"def inference(image, mode):\n", "def inference(image, mode, clip_model_name, best_max_flavors=32):\n",
" if clip_model_name != ci.config.clip_model_name:\n",
" ci.config.clip_model_name = clip_model_name\n",
" ci.load_clip_model()\n",
" image = image.convert('RGB')\n", " image = image.convert('RGB')\n",
" if mode == 'best':\n", " if mode == 'best':\n",
" return ci.interrogate(image)\n", " return ci.interrogate(image, max_flavors=int(best_max_flavors))\n",
" elif mode == 'classic':\n", " elif mode == 'classic':\n",
" return ci.interrogate_classic(image)\n", " return ci.interrogate_classic(image)\n",
" else:\n", " else:\n",
@ -132,6 +153,8 @@
"inputs = [\n", "inputs = [\n",
" gr.inputs.Image(type='pil'),\n", " gr.inputs.Image(type='pil'),\n",
" gr.Radio(['best', 'classic', 'fast'], label='', value='best'),\n", " gr.Radio(['best', 'classic', 'fast'], label='', value='best'),\n",
" gr.Dropdown([\"ViT-L-14/openai\", \"ViT-H-14/laion2b_s32b_b79k\"], value='ViT-L-14/openai', label='CLIP Model'),\n",
" gr.Number(value=16, label='best mode max flavors'),\n",
"]\n", "]\n",
"outputs = [\n", "outputs = [\n",
" gr.outputs.Textbox(label=\"Output\"),\n", " gr.outputs.Textbox(label=\"Output\"),\n",
@ -170,9 +193,10 @@
"from tqdm import tqdm\n", "from tqdm import tqdm\n",
"\n", "\n",
"folder_path = \"/content/my_images\" #@param {type:\"string\"}\n", "folder_path = \"/content/my_images\" #@param {type:\"string\"}\n",
"prompt_mode = 'best' #@param [\"best\",\"classic\", \"fast\"]\n", "prompt_mode = 'best' #@param [\"best\",\"fast\"]\n",
"output_mode = 'rename' #@param [\"desc.csv\",\"rename\"]\n", "output_mode = 'rename' #@param [\"desc.csv\",\"rename\"]\n",
"max_filename_len = 128 #@param {type:\"integer\"}\n", "max_filename_len = 128 #@param {type:\"integer\"}\n",
"best_max_flavors = 16 #@param {type:\"integer\"}\n",
"\n", "\n",
"\n", "\n",
"def sanitize_for_filename(prompt: str, max_len: int) -> str:\n", "def sanitize_for_filename(prompt: str, max_len: int) -> str:\n",
@ -189,7 +213,7 @@
" clear_output(wait=True)\n", " clear_output(wait=True)\n",
"\n", "\n",
" image = Image.open(os.path.join(folder_path, file)).convert('RGB')\n", " image = Image.open(os.path.join(folder_path, file)).convert('RGB')\n",
" prompt = inference(image, prompt_mode)\n", " prompt = inference(image, prompt_mode, best_max_flavors=best_max_flavors)\n",
" prompts.append(prompt)\n", " prompts.append(prompt)\n",
"\n", "\n",
" print(prompt)\n", " print(prompt)\n",
@ -232,7 +256,7 @@
"provenance": [] "provenance": []
}, },
"kernelspec": { "kernelspec": {
"display_name": "Python 3.9.5 ('venv': venv)", "display_name": "Python 3.8.10 ('ci')",
"language": "python", "language": "python",
"name": "python3" "name": "python3"
}, },
@ -246,12 +270,12 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.9.5" "version": "3.8.10"
}, },
"orig_nbformat": 4, "orig_nbformat": 4,
"vscode": { "vscode": {
"interpreter": { "interpreter": {
"hash": "10f7ca63a88f18f789e6adaf7a045f1bcd3706c5534a32f168d622925241605d" "hash": "90daa5087f97972f35e673cab20894a33c1e0ca77092ccdd163e60b53596983a"
} }
} }
}, },

134
clip_interrogator/clip_interrogator.py

@ -1,10 +1,11 @@
import clip
import hashlib import hashlib
import inspect import inspect
import math import math
import numpy as np import numpy as np
import open_clip
import os import os
import pickle import pickle
import time
import torch import torch
from dataclasses import dataclass from dataclasses import dataclass
@ -28,9 +29,11 @@ class Config:
blip_max_length: int = 32 blip_max_length: int = 32
blip_model_url: str = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth' blip_model_url: str = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
blip_num_beams: int = 8 blip_num_beams: int = 8
blip_offload: bool = False
# clip settings # clip settings
clip_model_name: str = 'ViT-L/14' clip_model_name: str = 'ViT-H-14/laion2b_s32b_b79k'
clip_model_path: str = None
# interrogator settings # interrogator settings
cache_path: str = 'cache' cache_path: str = 'cache'
@ -64,14 +67,30 @@ class Interrogator():
else: else:
self.blip_model = config.blip_model self.blip_model = config.blip_model
self.load_clip_model()
def load_clip_model(self):
start_time = time.time()
config = self.config
if config.clip_model is None: if config.clip_model is None:
if not config.quiet: if not config.quiet:
print("Loading CLIP model...") print("Loading CLIP model...")
self.clip_model, self.clip_preprocess = clip.load(config.clip_model_name, device=config.device)
self.clip_model.to(config.device).eval() clip_model_name, clip_model_pretrained_name = config.clip_model_name.split('/', 2)
self.clip_model, _, self.clip_preprocess = open_clip.create_model_and_transforms(
clip_model_name,
pretrained=clip_model_pretrained_name,
precision='fp16',
device=config.device,
jit=False,
cache_dir=config.clip_model_path
)
self.clip_model.half().to(config.device).eval()
else: else:
self.clip_model = config.clip_model self.clip_model = config.clip_model
self.clip_preprocess = config.clip_preprocess self.clip_preprocess = config.clip_preprocess
self.tokenize = open_clip.get_tokenizer(clip_model_name)
sites = ['Artstation', 'behance', 'cg society', 'cgsociety', 'deviantart', 'dribble', 'flickr', 'instagram', 'pexels', 'pinterest', 'pixabay', 'pixiv', 'polycount', 'reddit', 'shutterstock', 'tumblr', 'unsplash', 'zbrush central'] sites = ['Artstation', 'behance', 'cg society', 'cgsociety', 'deviantart', 'dribble', 'flickr', 'instagram', 'pexels', 'pinterest', 'pixabay', 'pixiv', 'polycount', 'reddit', 'shutterstock', 'tumblr', 'unsplash', 'zbrush central']
trending_list = [site for site in sites] trending_list = [site for site in sites]
@ -83,13 +102,19 @@ class Interrogator():
artists = [f"by {a}" for a in raw_artists] artists = [f"by {a}" for a in raw_artists]
artists.extend([f"inspired by {a}" for a in raw_artists]) artists.extend([f"inspired by {a}" for a in raw_artists])
self.artists = LabelTable(artists, "artists", self.clip_model, config) self.artists = LabelTable(artists, "artists", self.clip_model, self.tokenize, config)
self.flavors = LabelTable(_load_list(config.data_path, 'flavors.txt'), "flavors", self.clip_model, config) self.flavors = LabelTable(_load_list(config.data_path, 'flavors.txt'), "flavors", self.clip_model, self.tokenize, config)
self.mediums = LabelTable(_load_list(config.data_path, 'mediums.txt'), "mediums", self.clip_model, config) self.mediums = LabelTable(_load_list(config.data_path, 'mediums.txt'), "mediums", self.clip_model, self.tokenize, config)
self.movements = LabelTable(_load_list(config.data_path, 'movements.txt'), "movements", self.clip_model, config) self.movements = LabelTable(_load_list(config.data_path, 'movements.txt'), "movements", self.clip_model, self.tokenize, config)
self.trendings = LabelTable(trending_list, "trendings", self.clip_model, config) self.trendings = LabelTable(trending_list, "trendings", self.clip_model, self.tokenize, config)
end_time = time.time()
if not config.quiet:
print(f"Loaded CLIP model and data in {end_time-start_time:.2f} seconds.")
def generate_caption(self, pil_image: Image) -> str: def generate_caption(self, pil_image: Image) -> str:
if self.config.blip_offload:
self.blip_model = self.blip_model.to(self.device)
size = self.config.blip_image_eval_size size = self.config.blip_image_eval_size
gpu_image = transforms.Compose([ gpu_image = transforms.Compose([
transforms.Resize((size, size), interpolation=InterpolationMode.BICUBIC), transforms.Resize((size, size), interpolation=InterpolationMode.BICUBIC),
@ -105,12 +130,14 @@ class Interrogator():
max_length=self.config.blip_max_length, max_length=self.config.blip_max_length,
min_length=5 min_length=5
) )
if self.config.blip_offload:
self.blip_model = self.blip_model.to("cpu")
return caption[0] return caption[0]
def image_to_features(self, image: Image) -> torch.Tensor: def image_to_features(self, image: Image) -> torch.Tensor:
images = self.clip_preprocess(image).unsqueeze(0).to(self.device) images = self.clip_preprocess(image).unsqueeze(0).to(self.device)
with torch.no_grad(): with torch.no_grad(), torch.cuda.amp.autocast():
image_features = self.clip_model.encode_image(images).float() image_features = self.clip_model.encode_image(images)
image_features /= image_features.norm(dim=-1, keepdim=True) image_features /= image_features.norm(dim=-1, keepdim=True)
return image_features return image_features
@ -129,14 +156,14 @@ class Interrogator():
else: else:
prompt = f"{caption}, {medium} {artist}, {trending}, {movement}, {flaves}" prompt = f"{caption}, {medium} {artist}, {trending}, {movement}, {flaves}"
return _truncate_to_fit(prompt) return _truncate_to_fit(prompt, self.tokenize)
def interrogate_fast(self, image: Image, max_flavors: int = 32) -> str: def interrogate_fast(self, image: Image, max_flavors: int = 32) -> str:
caption = self.generate_caption(image) caption = self.generate_caption(image)
image_features = self.image_to_features(image) image_features = self.image_to_features(image)
merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config) merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config)
tops = merged.rank(image_features, max_flavors) tops = merged.rank(image_features, max_flavors)
return _truncate_to_fit(caption + ", " + ", ".join(tops)) return _truncate_to_fit(caption + ", " + ", ".join(tops), self.tokenize)
def interrogate(self, image: Image, max_flavors: int=32) -> str: def interrogate(self, image: Image, max_flavors: int=32) -> str:
caption = self.generate_caption(image) caption = self.generate_caption(image)
@ -171,7 +198,7 @@ class Interrogator():
prompt += ", " + opts[bit] prompt += ", " + opts[bit]
prompts.append(prompt) prompts.append(prompt)
t = LabelTable(prompts, None, self.clip_model, self.config) t = LabelTable(prompts, None, self.clip_model, self.tokenize, self.config)
best_prompt = t.rank(image_features, 1)[0] best_prompt = t.rank(image_features, 1)[0]
best_sim = self.similarity(image_features, best_prompt) best_sim = self.similarity(image_features, best_prompt)
@ -179,47 +206,41 @@ class Interrogator():
extended_flavors = set(flaves) extended_flavors = set(flaves)
for _ in tqdm(range(max_flavors), desc="Flavor chain", disable=self.config.quiet): for _ in tqdm(range(max_flavors), desc="Flavor chain", disable=self.config.quiet):
try:
best = self.rank_top(image_features, [f"{best_prompt}, {f}" for f in extended_flavors]) best = self.rank_top(image_features, [f"{best_prompt}, {f}" for f in extended_flavors])
flave = best[len(best_prompt)+2:] flave = best[len(best_prompt)+2:]
if not check(flave): if not check(flave):
break break
extended_flavors.remove(flave) if _prompt_at_max_len(best_prompt, self.tokenize):
except:
# exceeded max prompt length
break break
extended_flavors.remove(flave)
return best_prompt return best_prompt
def rank_top(self, image_features, text_array: List[str]) -> str: def rank_top(self, image_features: torch.Tensor, text_array: List[str]) -> str:
text_tokens = clip.tokenize([text for text in text_array]).to(self.device) text_tokens = self.tokenize([text for text in text_array]).to(self.device)
with torch.no_grad(): with torch.no_grad(), torch.cuda.amp.autocast():
text_features = self.clip_model.encode_text(text_tokens).float() text_features = self.clip_model.encode_text(text_tokens)
text_features /= text_features.norm(dim=-1, keepdim=True) text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = text_features @ image_features.T
return text_array[similarity.argmax().item()]
similarity = torch.zeros((1, len(text_array)), device=self.device) def similarity(self, image_features: torch.Tensor, text: str) -> float:
for i in range(image_features.shape[0]): text_tokens = self.tokenize([text]).to(self.device)
similarity += (image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1) with torch.no_grad(), torch.cuda.amp.autocast():
text_features = self.clip_model.encode_text(text_tokens)
_, top_labels = similarity.cpu().topk(1, dim=-1)
return text_array[top_labels[0][0].numpy()]
def similarity(self, image_features, text) -> np.float32:
text_tokens = clip.tokenize([text]).to(self.device)
with torch.no_grad():
text_features = self.clip_model.encode_text(text_tokens).float()
text_features /= text_features.norm(dim=-1, keepdim=True) text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T similarity = text_features @ image_features.T
return similarity[0][0] return similarity[0][0].item()
class LabelTable(): class LabelTable():
def __init__(self, labels:List[str], desc:str, clip_model, config: Config): def __init__(self, labels:List[str], desc:str, clip_model, tokenize, config: Config):
self.chunk_size = config.chunk_size self.chunk_size = config.chunk_size
self.config = config self.config = config
self.device = config.device self.device = config.device
self.embeds = [] self.embeds = []
self.labels = labels self.labels = labels
self.tokenize = tokenize
hash = hashlib.sha256(",".join(labels).encode()).hexdigest() hash = hashlib.sha256(",".join(labels).encode()).hexdigest()
@ -239,9 +260,9 @@ class LabelTable():
self.embeds = [] self.embeds = []
chunks = np.array_split(self.labels, max(1, len(self.labels)/config.chunk_size)) chunks = np.array_split(self.labels, max(1, len(self.labels)/config.chunk_size))
for chunk in tqdm(chunks, desc=f"Preprocessing {desc}" if desc else None, disable=self.config.quiet): for chunk in tqdm(chunks, desc=f"Preprocessing {desc}" if desc else None, disable=self.config.quiet):
text_tokens = clip.tokenize(chunk).to(self.device) text_tokens = self.tokenize(chunk).to(self.device)
with torch.no_grad(): with torch.no_grad(), torch.cuda.amp.autocast():
text_features = clip_model.encode_text(text_tokens).float() text_features = clip_model.encode_text(text_tokens)
text_features /= text_features.norm(dim=-1, keepdim=True) text_features /= text_features.norm(dim=-1, keepdim=True)
text_features = text_features.half().cpu().numpy() text_features = text_features.half().cpu().numpy()
for i in range(text_features.shape[0]): for i in range(text_features.shape[0]):
@ -256,16 +277,15 @@ class LabelTable():
"model": config.clip_model_name "model": config.clip_model_name
}, f) }, f)
def _rank(self, image_features, text_embeds, top_count=1): def _rank(self, image_features: torch.Tensor, text_embeds: torch.Tensor, top_count: int=1) -> str:
top_count = min(top_count, len(text_embeds)) top_count = min(top_count, len(text_embeds))
similarity = torch.zeros((1, len(text_embeds))).to(self.device) text_embeds = torch.stack([torch.from_numpy(t) for t in text_embeds]).to(self.device)
text_embeds = torch.stack([torch.from_numpy(t) for t in text_embeds]).float().to(self.device) with torch.cuda.amp.autocast():
for i in range(image_features.shape[0]): similarity = image_features @ text_embeds.T
similarity += (image_features[i].unsqueeze(0) @ text_embeds.T).softmax(dim=-1) _, top_labels = similarity.float().cpu().topk(top_count, dim=-1)
_, top_labels = similarity.cpu().topk(top_count, dim=-1)
return [top_labels[0][i].numpy() for i in range(top_count)] return [top_labels[0][i].numpy() for i in range(top_count)]
def rank(self, image_features, top_count=1) -> List[str]: def rank(self, image_features: torch.Tensor, top_count: int=1) -> List[str]:
if len(self.labels) <= self.chunk_size: if len(self.labels) <= self.chunk_size:
tops = self._rank(image_features, self.embeds, top_count=top_count) tops = self._rank(image_features, self.embeds, top_count=top_count)
return [self.labels[i] for i in tops] return [self.labels[i] for i in tops]
@ -285,23 +305,27 @@ class LabelTable():
return [top_labels[i] for i in tops] return [top_labels[i] for i in tops]
def _load_list(data_path, filename) -> List[str]: def _load_list(data_path: str, filename: str) -> List[str]:
with open(os.path.join(data_path, filename), 'r', encoding='utf-8', errors='replace') as f: with open(os.path.join(data_path, filename), 'r', encoding='utf-8', errors='replace') as f:
items = [line.strip() for line in f.readlines()] items = [line.strip() for line in f.readlines()]
return items return items
def _merge_tables(tables: List[LabelTable], config: Config) -> LabelTable: def _merge_tables(tables: List[LabelTable], config: Config) -> LabelTable:
m = LabelTable([], None, None, config) m = LabelTable([], None, None, None, config)
for table in tables: for table in tables:
m.labels.extend(table.labels) m.labels.extend(table.labels)
m.embeds.extend(table.embeds) m.embeds.extend(table.embeds)
return m return m
def _truncate_to_fit(text: str) -> str: def _prompt_at_max_len(text: str, tokenize) -> bool:
while True: tokens = tokenize([text])
try: return tokens[0][-1] != 0
_ = clip.tokenize([text])
return text
except:
text = ",".join(text.split(",")[:-1])
def _truncate_to_fit(text: str, tokenize) -> str:
parts = text.split(', ')
new_text = parts[0]
for part in parts[1:]:
if _prompt_at_max_len(new_text + part, tokenize):
break
new_text += ', ' + part
return new_text

13
cog.yaml

@ -1,6 +1,6 @@
build: build:
gpu: true gpu: true
cuda: "11.3" cuda: "11.6"
python_version: "3.8" python_version: "3.8"
system_packages: system_packages:
- "libgl1-mesa-glx" - "libgl1-mesa-glx"
@ -10,11 +10,12 @@ build:
- "fairscale==0.4.12" - "fairscale==0.4.12"
- "transformers==4.21.2" - "transformers==4.21.2"
- "ftfy==6.1.1" - "ftfy==6.1.1"
- "torch==1.11.0 --extra-index-url=https://download.pytorch.org/whl/cu113" - "torch==1.13.0 --extra-index-url=https://download.pytorch.org/whl/cu116"
- "torchvision==0.12.0 --extra-index-url=https://download.pytorch.org/whl/cu113" - "torchvision==0.14.0 --extra-index-url=https://download.pytorch.org/whl/cu116"
- "open_clip_torch==2.7.0"
- "timm==0.4.12"
- "pycocoevalcap==1.2"
run: run:
- pip install -e git+https://github.com/pharmapsychotic/BLIP.git@lib#egg=blip - git clone https://github.com/salesforce/BLIP /root/blip
- pip install -e git+https://github.com/openai/CLIP.git@main#egg=clip
- mkdir -p /root/.cache/clip && wget --output-document "/root/.cache/clip/ViT-L-14.pt" "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"
predict: "predict.py:Predictor" predict: "predict.py:Predictor"

34
predict.py

@ -2,17 +2,43 @@ import sys
from PIL import Image from PIL import Image
from cog import BasePredictor, Input, Path from cog import BasePredictor, Input, Path
sys.path.extend(["src/clip", "src/blip"]) sys.path.append('/root/blip')
from clip_interrogator import Interrogator, Config from clip_interrogator import Interrogator, Config
class Predictor(BasePredictor): class Predictor(BasePredictor):
def setup(self): def setup(self):
config = Config(device="cuda:0", clip_model_name='ViT-L/14') self.ci = Interrogator(Config(
self.ci = Interrogator(config) blip_model_url='cache/model_large_caption.pth',
clip_model_name="ViT-L-14/openai",
clip_model_path='cache',
device='cuda:0',
))
def predict(self, image: Path = Input(description="Input image")) -> str: def predict(
self,
image: Path = Input(description="Input image"),
clip_model_name: str = Input(
default="ViT-L-14/openai",
choices=["ViT-L-14/openai", "ViT-H-14/laion2b_s32b_b79k"],
description="Choose ViT-L for Stable Diffusion 1, and ViT-H for Stable Diffusion 2",
),
mode: str = Input(
default="best",
choices=["best", "fast"],
description="Prompt mode (best takes 10-20 seconds, fast takes 1-2 seconds).",
),
) -> str:
"""Run a single prediction on the model""" """Run a single prediction on the model"""
image = Image.open(str(image)).convert("RGB") image = Image.open(str(image)).convert("RGB")
self.switch_model(clip_model_name)
if mode == "best":
return self.ci.interrogate(image) return self.ci.interrogate(image)
else:
return self.ci.interrogate_fast(image)
def switch_model(self, clip_model_name: str):
if clip_model_name != self.ci.config.clip_model_name:
self.ci.config.clip_model_name = clip_model_name
self.ci.load_clip_model()

1
requirements.txt

@ -3,3 +3,4 @@ torchvision
Pillow Pillow
requests requests
tqdm tqdm
open_clip_torch

9
run_cli.py

@ -1,7 +1,7 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import argparse import argparse
import clip
import csv import csv
import open_clip
import os import os
import requests import requests
import torch import torch
@ -19,7 +19,7 @@ def inference(ci, image, mode):
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-c', '--clip', default='ViT-L/14', help='name of CLIP model to use') parser.add_argument('-c', '--clip', default='ViT-L-14/openai', help='name of CLIP model to use')
parser.add_argument('-f', '--folder', help='path to folder of images') parser.add_argument('-f', '--folder', help='path to folder of images')
parser.add_argument('-i', '--image', help='image file or url') parser.add_argument('-i', '--image', help='image file or url')
parser.add_argument('-m', '--mode', default='best', help='best, classic, or fast') parser.add_argument('-m', '--mode', default='best', help='best, classic, or fast')
@ -34,9 +34,10 @@ def main():
exit(1) exit(1)
# validate clip model name # validate clip model name
if args.clip not in clip.available_models(): models = ['/'.join(x) for x in open_clip.list_pretrained()]
if args.clip not in models:
print(f"Could not find CLIP model {args.clip}!") print(f"Could not find CLIP model {args.clip}!")
print(f" available models: {clip.available_models()}") print(f" available models: {models}")
exit(1) exit(1)
# generate a nice prompt # generate a nice prompt

20
run_gradio.py

@ -1,14 +1,19 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import clip import argparse
import gradio as gr import gradio as gr
import open_clip
from clip_interrogator import Interrogator, Config from clip_interrogator import Interrogator, Config
ci = Interrogator(Config()) parser = argparse.ArgumentParser()
parser.add_argument('-s', '--share', action='store_true', help='Create a public link')
args = parser.parse_args()
ci = Interrogator(Config(cache_path="cache", clip_model_path="cache"))
def inference(image, mode, clip_model_name, blip_max_length, blip_num_beams): def inference(image, mode, clip_model_name, blip_max_length, blip_num_beams):
global ci
if clip_model_name != ci.config.clip_model_name: if clip_model_name != ci.config.clip_model_name:
ci = Interrogator(Config(clip_model_name=clip_model_name)) ci.config.clip_model_name = clip_model_name
ci.load_clip_model()
ci.config.blip_max_length = int(blip_max_length) ci.config.blip_max_length = int(blip_max_length)
ci.config.blip_num_beams = int(blip_num_beams) ci.config.blip_num_beams = int(blip_num_beams)
@ -20,10 +25,12 @@ def inference(image, mode, clip_model_name, blip_max_length, blip_num_beams):
else: else:
return ci.interrogate_fast(image) return ci.interrogate_fast(image)
models = ['/'.join(x) for x in open_clip.list_pretrained()]
inputs = [ inputs = [
gr.inputs.Image(type='pil'), gr.inputs.Image(type='pil'),
gr.Radio(['best', 'classic', 'fast'], label='Mode', value='best'), gr.Radio(['best', 'classic', 'fast'], label='Mode', value='best'),
gr.Dropdown(clip.available_models(), value='ViT-L/14', label='CLIP Model'), gr.Dropdown(models, value='ViT-L-14/openai', label='CLIP Model'),
gr.Number(value=32, label='Caption Max Length'), gr.Number(value=32, label='Caption Max Length'),
gr.Number(value=64, label='Caption Num Beams'), gr.Number(value=64, label='Caption Num Beams'),
] ]
@ -38,4 +45,5 @@ io = gr.Interface(
title="🕵 CLIP Interrogator 🕵", title="🕵 CLIP Interrogator 🕵",
allow_flagging=False, allow_flagging=False,
) )
io.launch() io.launch(share=args.share)

Loading…
Cancel
Save