
First test version with OpenCLIP and ViTH!

Branch: pull/18/head
pharmapsychotic, 2 years ago
commit 429d490901

Changed files:
  1. clip_interrogator.ipynb (36 lines changed)
  2. clip_interrogator/clip_interrogator.py (40 lines changed)
clip_interrogator.ipynb

@@ -6,7 +6,7 @@
     "id": "3jm8RYrLqvzz"
     },
     "source": [
-    "# CLIP Interrogator 2.1 by [@pharmapsychotic](https://twitter.com/pharmapsychotic) \n",
+    "# CLIP Interrogator 2.1 ViTH special edition!\n",
     "\n",
     "<br>\n",
     "\n",
@@ -14,13 +14,7 @@
     "\n",
     "<br>\n",
     "\n",
-    "This version is specialized for producing nice prompts for use with Stable Diffusion and achieves higher alignment between generated text prompt and source image. You can try out the old [version 1](https://colab.research.google.com/github/pharmapsychotic/clip-interrogator/blob/v1/clip_interrogator.ipynb) to see how different CLIP models ranks terms. \n",
-    "\n",
-    "<br>\n",
-    "\n",
-    "If this notebook is helpful to you please consider buying me a coffee via [ko-fi](https://ko-fi.com/pharmapsychotic) or following me on [twitter](https://twitter.com/pharmapsychotic) for more cool Ai stuff. 🙂\n",
-    "\n",
-    "And if you're looking for more Ai art tools check out my [Ai generative art tools list](https://pharmapsychotic.com/tools.html).\n"
+    "This version is specialized for producing nice prompts for use with **Stable Diffusion 2.0** using the ViT-H-14 OpenCLIP model!"
     ]
     },
     {
@@ -46,23 +40,35 @@
     "outputs": [],
     "source": [
     "#@title Setup\n",
-    "import subprocess\n",
+    "import os, subprocess\n",
     "\n",
     "def setup():\n",
     " install_cmds = [\n",
     " ['pip', 'install', 'ftfy', 'gradio', 'regex', 'tqdm', 'transformers==4.21.2', 'timm', 'fairscale', 'requests'],\n",
-    " ['pip', 'install', '-e', 'git+https://github.com/openai/CLIP.git@main#egg=clip'],\n",
+    " ['pip', 'install', 'open_clip_torch'],\n",
     " ['pip', 'install', '-e', 'git+https://github.com/pharmapsychotic/BLIP.git@lib#egg=blip'],\n",
-    " ['git', 'clone', 'https://github.com/pharmapsychotic/clip-interrogator.git']\n",
+    " ['git', 'clone', '-b', 'open-clip', 'https://github.com/pharmapsychotic/clip-interrogator.git']\n",
     " ]\n",
     " for cmd in install_cmds:\n",
     " print(subprocess.run(cmd, stdout=subprocess.PIPE).stdout.decode('utf-8'))\n",
     "\n",
     "setup()\n",
     "\n",
+    "# download cache files\n",
+    "print(\"Download preprocessed cache files...\")\n",
+    "CACHE_URLS = [\n",
+    " 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_artists.pkl',\n",
+    " 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_flavors.pkl',\n",
+    " 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_mediums.pkl',\n",
+    " 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_movements.pkl',\n",
+    " 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_trendings.pkl',\n",
+    "]\n",
+    "os.makedirs('cache', exist_ok=True)\n",
+    "for url in CACHE_URLS:\n",
+    " print(subprocess.run(['wget', url, '-P', 'cache'], stdout=subprocess.PIPE).stdout.decode('utf-8'))\n",
+    "\n",
     "import sys\n",
     "sys.path.append('src/blip')\n",
-    "sys.path.append('src/clip')\n",
     "sys.path.append('clip-interrogator')\n",
     "\n",
     "import gradio as gr\n",
@@ -232,7 +238,7 @@
     "provenance": []
     },
     "kernelspec": {
-    "display_name": "Python 3.9.5 ('venv': venv)",
+    "display_name": "Python 3.8.10 ('venv': venv)",
     "language": "python",
     "name": "python3"
     },
@@ -246,12 +252,12 @@
     "name": "python",
     "nbconvert_exporter": "python",
     "pygments_lexer": "ipython3",
-    "version": "3.9.5"
+    "version": "3.8.10"
     },
     "orig_nbformat": 4,
     "vscode": {
     "interpreter": {
-    "hash": "10f7ca63a88f18f789e6adaf7a045f1bcd3706c5534a32f168d622925241605d"
+    "hash": "f7a8d9541664ade9cff251487a19c76f2dd1b4c864d158f07ee26d1b0fd5c9a1"
     }
     }
     },
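
For readers trying these notebook changes outside Colab, here is a minimal sketch of what the new Setup cell ultimately loads: the ViT-H-14 checkpoint trained on LAION-2B plus its matching tokenizer, via open_clip (the same calls the updated clip_interrogator.py below makes). It assumes open_clip_torch is installed and will download a multi-gigabyte checkpoint on first run.

```python
import torch
import open_clip

device = "cuda" if torch.cuda.is_available() else "cpu"

# create_model_and_transforms returns (model, train_transform, eval_transform);
# the interrogator keeps only the eval transform for scoring images.
clip_model, _, clip_preprocess = open_clip.create_model_and_transforms(
    "ViT-H-14", pretrained="laion2b_s32b_b79k"
)
clip_model.to(device).eval()

# open_clip resolves the tokenizer from the architecture name.
tokenize = open_clip.get_tokenizer("ViT-H-14")
tokens = tokenize(["a painting of a cat, trending on artstation"]).to(device)

with torch.no_grad():
    text_features = clip_model.encode_text(tokens).float()
    text_features /= text_features.norm(dim=-1, keepdim=True)

print(text_features.shape)  # ViT-H-14 text embeddings are 1024-dimensional
```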

clip_interrogator/clip_interrogator.py

@@ -1,8 +1,8 @@
-import clip
 import hashlib
 import inspect
 import math
 import numpy as np
+import open_clip
 import os
 import pickle
 import torch
@@ -30,7 +30,7 @@ class Config:
     blip_num_beams: int = 8

     # clip settings
-    clip_model_name: str = 'ViT-L/14'
+    clip_model_name: str = 'ViT-H-14/laion2b_s32b_b79k'

     # interrogator settings
     cache_path: str = 'cache'
@@ -67,11 +67,14 @@ class Interrogator():
         if config.clip_model is None:
             if not config.quiet:
                 print("Loading CLIP model...")
-            self.clip_model, self.clip_preprocess = clip.load(config.clip_model_name, device=config.device)
+            clip_model_name, clip_model_pretrained_name = config.clip_model_name.split('/', 2)
+            self.clip_model, _, self.clip_preprocess = open_clip.create_model_and_transforms(clip_model_name, pretrained=clip_model_pretrained_name)
             self.clip_model.to(config.device).eval()
         else:
             self.clip_model = config.clip_model
             self.clip_preprocess = config.clip_preprocess
+        self.tokenize = open_clip.get_tokenizer(clip_model_name)

         sites = ['Artstation', 'behance', 'cg society', 'cgsociety', 'deviantart', 'dribble', 'flickr', 'instagram', 'pexels', 'pinterest', 'pixabay', 'pixiv', 'polycount', 'reddit', 'shutterstock', 'tumblr', 'unsplash', 'zbrush central']
         trending_list = [site for site in sites]
@@ -83,11 +86,11 @@ class Interrogator():
         artists = [f"by {a}" for a in raw_artists]
         artists.extend([f"inspired by {a}" for a in raw_artists])

-        self.artists = LabelTable(artists, "artists", self.clip_model, config)
-        self.flavors = LabelTable(_load_list(config.data_path, 'flavors.txt'), "flavors", self.clip_model, config)
-        self.mediums = LabelTable(_load_list(config.data_path, 'mediums.txt'), "mediums", self.clip_model, config)
-        self.movements = LabelTable(_load_list(config.data_path, 'movements.txt'), "movements", self.clip_model, config)
-        self.trendings = LabelTable(trending_list, "trendings", self.clip_model, config)
+        self.artists = LabelTable(artists, "artists", self.clip_model, self.tokenize, config)
+        self.flavors = LabelTable(_load_list(config.data_path, 'flavors.txt'), "flavors", self.clip_model, self.tokenize, config)
+        self.mediums = LabelTable(_load_list(config.data_path, 'mediums.txt'), "mediums", self.clip_model, self.tokenize, config)
+        self.movements = LabelTable(_load_list(config.data_path, 'movements.txt'), "movements", self.clip_model, self.tokenize, config)
+        self.trendings = LabelTable(trending_list, "trendings", self.clip_model, self.tokenize, config)

     def generate_caption(self, pil_image: Image) -> str:
         size = self.config.blip_image_eval_size
@@ -129,14 +132,14 @@ class Interrogator():
         else:
             prompt = f"{caption}, {medium} {artist}, {trending}, {movement}, {flaves}"

-        return _truncate_to_fit(prompt)
+        return _truncate_to_fit(prompt, self.tokenize)

     def interrogate_fast(self, image: Image, max_flavors: int = 32) -> str:
         caption = self.generate_caption(image)
         image_features = self.image_to_features(image)
         merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config)
         tops = merged.rank(image_features, max_flavors)
-        return _truncate_to_fit(caption + ", " + ", ".join(tops))
+        return _truncate_to_fit(caption + ", " + ", ".join(tops), self.tokenize)

     def interrogate(self, image: Image, max_flavors: int=32) -> str:
         caption = self.generate_caption(image)
@@ -171,7 +174,7 @@ class Interrogator():
                 prompt += ", " + opts[bit]
             prompts.append(prompt)

-        t = LabelTable(prompts, None, self.clip_model, self.config)
+        t = LabelTable(prompts, None, self.clip_model, self.tokenize, self.config)
         best_prompt = t.rank(image_features, 1)[0]
         best_sim = self.similarity(image_features, best_prompt)
@@ -192,7 +195,7 @@ class Interrogator():
         return best_prompt

     def rank_top(self, image_features, text_array: List[str]) -> str:
-        text_tokens = clip.tokenize([text for text in text_array]).to(self.device)
+        text_tokens = self.tokenize([text for text in text_array]).to(self.device)
         with torch.no_grad():
             text_features = self.clip_model.encode_text(text_tokens).float()
             text_features /= text_features.norm(dim=-1, keepdim=True)
@@ -205,7 +208,7 @@ class Interrogator():
         return text_array[top_labels[0][0].numpy()]

     def similarity(self, image_features, text) -> np.float32:
-        text_tokens = clip.tokenize([text]).to(self.device)
+        text_tokens = self.tokenize([text]).to(self.device)
         with torch.no_grad():
             text_features = self.clip_model.encode_text(text_tokens).float()
             text_features /= text_features.norm(dim=-1, keepdim=True)
@@ -214,12 +217,13 @@

 class LabelTable():
-    def __init__(self, labels:List[str], desc:str, clip_model, config: Config):
+    def __init__(self, labels:List[str], desc:str, clip_model, tokenize, config: Config):
         self.chunk_size = config.chunk_size
         self.config = config
         self.device = config.device
         self.embeds = []
         self.labels = labels
+        self.tokenize = tokenize

         hash = hashlib.sha256(",".join(labels).encode()).hexdigest()
@@ -239,7 +243,7 @@ class LabelTable():
             self.embeds = []
             chunks = np.array_split(self.labels, max(1, len(self.labels)/config.chunk_size))
             for chunk in tqdm(chunks, desc=f"Preprocessing {desc}" if desc else None, disable=self.config.quiet):
-                text_tokens = clip.tokenize(chunk).to(self.device)
+                text_tokens = self.tokenize(chunk).to(self.device)
                 with torch.no_grad():
                     text_features = clip_model.encode_text(text_tokens).float()
                     text_features /= text_features.norm(dim=-1, keepdim=True)
@@ -291,16 +295,16 @@ def _load_list(data_path, filename) -> List[str]:
     return items

 def _merge_tables(tables: List[LabelTable], config: Config) -> LabelTable:
-    m = LabelTable([], None, None, config)
+    m = LabelTable([], None, None, None, config)
     for table in tables:
         m.labels.extend(table.labels)
         m.embeds.extend(table.embeds)
     return m

-def _truncate_to_fit(text: str) -> str:
+def _truncate_to_fit(text: str, tokenize) -> str:
     while True:
         try:
-            _ = clip.tokenize([text])
+            _ = tokenize([text])
             return text
         except:
             text = ",".join(text.split(",")[:-1])
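
The Python changes thread the OpenCLIP tokenizer through LabelTable and _truncate_to_fit in place of the global clip.tokenize, and the default clip_model_name now encodes both the architecture and the pretrained tag as '<model>/<pretrained>', split on '/'. A rough usage sketch against the updated class, assuming the package exposes Config and Interrogator at the top level, with a placeholder image path:

```python
from PIL import Image
from clip_interrogator import Config, Interrogator  # assumed top-level exports

config = Config()
config.clip_model_name = 'ViT-H-14/laion2b_s32b_b79k'  # new default: '<architecture>/<pretrained>'

ci = Interrogator(config)

image = Image.open('example.jpg').convert('RGB')  # placeholder path for illustration
print(ci.interrogate(image))       # best-prompt search across artists/flavors/mediums/movements/trendings
print(ci.interrogate_fast(image))  # single-pass ranking over the merged label tables
```

Since the tokenizer now lives on the Interrogator as self.tokenize, anything that builds a LabelTable or calls _truncate_to_fit has to pass it along, which is what the extra parameter in those signatures is for.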
