diff --git a/clip_interrogator.ipynb b/clip_interrogator.ipynb index 7329580..4e78074 100644 --- a/clip_interrogator.ipynb +++ b/clip_interrogator.ipynb @@ -6,7 +6,7 @@ "id": "3jm8RYrLqvzz" }, "source": [ - "# CLIP Interrogator 2.1 by [@pharmapsychotic](https://twitter.com/pharmapsychotic) \n", + "# CLIP Interrogator 2.1 ViTH special edition!\n", "\n", "
\n", "\n", @@ -14,13 +14,7 @@ "\n", "
\n", "\n", - "This version is specialized for producing nice prompts for use with Stable Diffusion and achieves higher alignment between generated text prompt and source image. You can try out the old [version 1](https://colab.research.google.com/github/pharmapsychotic/clip-interrogator/blob/v1/clip_interrogator.ipynb) to see how different CLIP models ranks terms. \n", - "\n", - "
\n", - "\n", - "If this notebook is helpful to you please consider buying me a coffee via [ko-fi](https://ko-fi.com/pharmapsychotic) or following me on [twitter](https://twitter.com/pharmapsychotic) for more cool Ai stuff. 🙂\n", - "\n", - "And if you're looking for more Ai art tools check out my [Ai generative art tools list](https://pharmapsychotic.com/tools.html).\n" + "This version is specialized for producing nice prompts for use with **Stable Diffusion 2.0** using the ViT-H-14 OpenCLIP model!" ] }, { @@ -46,23 +40,35 @@ "outputs": [], "source": [ "#@title Setup\n", - "import subprocess\n", + "import os, subprocess\n", "\n", "def setup():\n", " install_cmds = [\n", " ['pip', 'install', 'ftfy', 'gradio', 'regex', 'tqdm', 'transformers==4.21.2', 'timm', 'fairscale', 'requests'],\n", - " ['pip', 'install', '-e', 'git+https://github.com/openai/CLIP.git@main#egg=clip'],\n", + " ['pip', 'install', 'open_clip_torch'],\n", " ['pip', 'install', '-e', 'git+https://github.com/pharmapsychotic/BLIP.git@lib#egg=blip'],\n", - " ['git', 'clone', 'https://github.com/pharmapsychotic/clip-interrogator.git']\n", + " ['git', 'clone', '-b', 'open-clip', 'https://github.com/pharmapsychotic/clip-interrogator.git']\n", " ]\n", " for cmd in install_cmds:\n", " print(subprocess.run(cmd, stdout=subprocess.PIPE).stdout.decode('utf-8'))\n", "\n", "setup()\n", "\n", + "# download cache files\n", + "print(\"Download preprocessed cache files...\")\n", + "CACHE_URLS = [\n", + " 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_artists.pkl',\n", + " 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_flavors.pkl',\n", + " 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_mediums.pkl',\n", + " 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_movements.pkl',\n", + " 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_trendings.pkl',\n", + "]\n", + "os.makedirs('cache', exist_ok=True)\n", + "for url in CACHE_URLS:\n", + " print(subprocess.run(['wget', url, '-P', 'cache'], stdout=subprocess.PIPE).stdout.decode('utf-8'))\n", + "\n", "import sys\n", "sys.path.append('src/blip')\n", - "sys.path.append('src/clip')\n", "sys.path.append('clip-interrogator')\n", "\n", "import gradio as gr\n", @@ -232,7 +238,7 @@ "provenance": [] }, "kernelspec": { - "display_name": "Python 3.9.5 ('venv': venv)", + "display_name": "Python 3.8.10 ('venv': venv)", "language": "python", "name": "python3" }, @@ -246,12 +252,12 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.5" + "version": "3.8.10" }, "orig_nbformat": 4, "vscode": { "interpreter": { - "hash": "10f7ca63a88f18f789e6adaf7a045f1bcd3706c5534a32f168d622925241605d" + "hash": "f7a8d9541664ade9cff251487a19c76f2dd1b4c864d158f07ee26d1b0fd5c9a1" } } }, diff --git a/clip_interrogator/clip_interrogator.py b/clip_interrogator/clip_interrogator.py index fb26215..dd5bfa0 100644 --- a/clip_interrogator/clip_interrogator.py +++ b/clip_interrogator/clip_interrogator.py @@ -1,8 +1,8 @@ -import clip import hashlib import inspect import math import numpy as np +import open_clip import os import pickle import torch @@ -30,7 +30,7 @@ class Config: blip_num_beams: int = 8 # clip settings - clip_model_name: str = 'ViT-L/14' + clip_model_name: str = 'ViT-H-14/laion2b_s32b_b79k' # interrogator settings cache_path: str = 'cache' @@ -67,11 +67,14 @@ class Interrogator(): if config.clip_model is None: if not config.quiet: print("Loading CLIP model...") - self.clip_model, self.clip_preprocess = clip.load(config.clip_model_name, device=config.device) + + clip_model_name, clip_model_pretrained_name = config.clip_model_name.split('/', 2) + self.clip_model, _, self.clip_preprocess = open_clip.create_model_and_transforms(clip_model_name, pretrained=clip_model_pretrained_name) self.clip_model.to(config.device).eval() else: self.clip_model = config.clip_model self.clip_preprocess = config.clip_preprocess + self.tokenize = open_clip.get_tokenizer(clip_model_name) sites = ['Artstation', 'behance', 'cg society', 'cgsociety', 'deviantart', 'dribble', 'flickr', 'instagram', 'pexels', 'pinterest', 'pixabay', 'pixiv', 'polycount', 'reddit', 'shutterstock', 'tumblr', 'unsplash', 'zbrush central'] trending_list = [site for site in sites] @@ -83,11 +86,11 @@ class Interrogator(): artists = [f"by {a}" for a in raw_artists] artists.extend([f"inspired by {a}" for a in raw_artists]) - self.artists = LabelTable(artists, "artists", self.clip_model, config) - self.flavors = LabelTable(_load_list(config.data_path, 'flavors.txt'), "flavors", self.clip_model, config) - self.mediums = LabelTable(_load_list(config.data_path, 'mediums.txt'), "mediums", self.clip_model, config) - self.movements = LabelTable(_load_list(config.data_path, 'movements.txt'), "movements", self.clip_model, config) - self.trendings = LabelTable(trending_list, "trendings", self.clip_model, config) + self.artists = LabelTable(artists, "artists", self.clip_model, self.tokenize, config) + self.flavors = LabelTable(_load_list(config.data_path, 'flavors.txt'), "flavors", self.clip_model, self.tokenize, config) + self.mediums = LabelTable(_load_list(config.data_path, 'mediums.txt'), "mediums", self.clip_model, self.tokenize, config) + self.movements = LabelTable(_load_list(config.data_path, 'movements.txt'), "movements", self.clip_model, self.tokenize, config) + self.trendings = LabelTable(trending_list, "trendings", self.clip_model, self.tokenize, config) def generate_caption(self, pil_image: Image) -> str: size = self.config.blip_image_eval_size @@ -129,14 +132,14 @@ class Interrogator(): else: prompt = f"{caption}, {medium} {artist}, {trending}, {movement}, {flaves}" - return _truncate_to_fit(prompt) + return _truncate_to_fit(prompt, self.tokenize) def interrogate_fast(self, image: Image, max_flavors: int = 32) -> str: caption = self.generate_caption(image) image_features = self.image_to_features(image) merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config) tops = merged.rank(image_features, max_flavors) - return _truncate_to_fit(caption + ", " + ", ".join(tops)) + return _truncate_to_fit(caption + ", " + ", ".join(tops), self.tokenize) def interrogate(self, image: Image, max_flavors: int=32) -> str: caption = self.generate_caption(image) @@ -171,7 +174,7 @@ class Interrogator(): prompt += ", " + opts[bit] prompts.append(prompt) - t = LabelTable(prompts, None, self.clip_model, self.config) + t = LabelTable(prompts, None, self.clip_model, self.tokenize, self.config) best_prompt = t.rank(image_features, 1)[0] best_sim = self.similarity(image_features, best_prompt) @@ -192,7 +195,7 @@ class Interrogator(): return best_prompt def rank_top(self, image_features, text_array: List[str]) -> str: - text_tokens = clip.tokenize([text for text in text_array]).to(self.device) + text_tokens = self.tokenize([text for text in text_array]).to(self.device) with torch.no_grad(): text_features = self.clip_model.encode_text(text_tokens).float() text_features /= text_features.norm(dim=-1, keepdim=True) @@ -205,7 +208,7 @@ class Interrogator(): return text_array[top_labels[0][0].numpy()] def similarity(self, image_features, text) -> np.float32: - text_tokens = clip.tokenize([text]).to(self.device) + text_tokens = self.tokenize([text]).to(self.device) with torch.no_grad(): text_features = self.clip_model.encode_text(text_tokens).float() text_features /= text_features.norm(dim=-1, keepdim=True) @@ -214,12 +217,13 @@ class Interrogator(): class LabelTable(): - def __init__(self, labels:List[str], desc:str, clip_model, config: Config): + def __init__(self, labels:List[str], desc:str, clip_model, tokenize, config: Config): self.chunk_size = config.chunk_size self.config = config self.device = config.device self.embeds = [] self.labels = labels + self.tokenize = tokenize hash = hashlib.sha256(",".join(labels).encode()).hexdigest() @@ -239,7 +243,7 @@ class LabelTable(): self.embeds = [] chunks = np.array_split(self.labels, max(1, len(self.labels)/config.chunk_size)) for chunk in tqdm(chunks, desc=f"Preprocessing {desc}" if desc else None, disable=self.config.quiet): - text_tokens = clip.tokenize(chunk).to(self.device) + text_tokens = self.tokenize(chunk).to(self.device) with torch.no_grad(): text_features = clip_model.encode_text(text_tokens).float() text_features /= text_features.norm(dim=-1, keepdim=True) @@ -291,16 +295,16 @@ def _load_list(data_path, filename) -> List[str]: return items def _merge_tables(tables: List[LabelTable], config: Config) -> LabelTable: - m = LabelTable([], None, None, config) + m = LabelTable([], None, None, None, config) for table in tables: m.labels.extend(table.labels) m.embeds.extend(table.embeds) return m -def _truncate_to_fit(text: str) -> str: +def _truncate_to_fit(text: str, tokenize) -> str: while True: try: - _ = clip.tokenize([text]) + _ = tokenize([text]) return text except: text = ",".join(text.split(",")[:-1])