diff --git a/clip_interrogator.ipynb b/clip_interrogator.ipynb index 809ab9c..af5d8e4 100644 --- a/clip_interrogator.ipynb +++ b/clip_interrogator.ipynb @@ -1,29 +1,19 @@ { "cells": [ { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "3jm8RYrLqvzz" }, "source": [ - "# CLIP Interrogator 2.2 by [@pharmapsychotic](https://twitter.com/pharmapsychotic) \n", + "# CLIP Interrogator 2.3 [negative prompt experiment!]\n", "\n", - "Want to figure out what a good prompt might be to create new images like an existing one? The CLIP Interrogator is here to get you answers!\n", + "This experimental version of CLIP Interrogator supports finding good \"negative\" prompts for Stable Diffusion 2. Note this is very *WIP* and more work needs to be done building out the dataset to support this (and perhaps a reverse BLIP) so for many images it may struggle to find a well aligned negative prompt. Alignments are displayed to help see how well it did.\n", "\n", "
\n", "\n", - "For Stable Diffusion 1.X choose the **ViT-L** model and for Stable Diffusion 2.0+ choose the **ViT-H** CLIP Model.\n", - "\n", - "This version is specialized for producing nice prompts for use with Stable Diffusion and achieves higher alignment between generated text prompt and source image. You can try out the old [version 1](https://colab.research.google.com/github/pharmapsychotic/clip-interrogator/blob/v1/clip_interrogator.ipynb) to see how different CLIP models ranks terms. \n", - "\n", - "You can also run this on HuggingFace and Replicate
\n", - "[![Generic badge](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue.svg)](https://huggingface.co/spaces/pharma/CLIP-Interrogator) [![Replicate](https://replicate.com/pharmapsychotic/clip-interrogator/badge)](https://replicate.com/pharmapsychotic/clip-interrogator)\n", - "\n", - "
\n", - "\n", - "If this notebook is helpful to you please consider buying me a coffee via [ko-fi](https://ko-fi.com/pharmapsychotic) or following me on [twitter](https://twitter.com/pharmapsychotic) for more cool Ai stuff. 🙂\n", - "\n", - "And if you're looking for more Ai art tools check out my [Ai generative art tools list](https://pharmapsychotic.com/tools.html).\n" + "For Stable Diffusion 1.X choose the **ViT-L** model and for Stable Diffusion 2.0+ choose the **ViT-H** CLIP Model.\n" ] }, { @@ -56,7 +46,7 @@ " ['pip', 'install', 'ftfy', 'gradio', 'regex', 'tqdm', 'transformers==4.21.2', 'timm', 'fairscale', 'requests'],\n", " ['pip', 'install', 'open_clip_torch'],\n", " ['pip', 'install', '-e', 'git+https://github.com/pharmapsychotic/BLIP.git@lib#egg=blip'],\n", - " ['git', 'clone', 'https://github.com/pharmapsychotic/clip-interrogator.git']\n", + " ['git', 'clone', '-b', 'negative', 'https://github.com/pharmapsychotic/clip-interrogator.git']\n", " ]\n", " for cmd in install_cmds:\n", " print(subprocess.run(cmd, stdout=subprocess.PIPE).stdout.decode('utf-8'))\n", @@ -64,7 +54,7 @@ "setup()\n", "\n", "\n", - "clip_model_name = 'ViT-L-14/openai' #@param [\"ViT-L-14/openai\", \"ViT-H-14/laion2b_s32b_b79k\"]\n", + "clip_model_name = 'ViT-H-14/laion2b_s32b_b79k' #@param [\"ViT-L-14/openai\", \"ViT-H-14/laion2b_s32b_b79k\"]\n", "\n", "\n", "print(\"Download preprocessed cache files...\")\n", @@ -99,16 +89,24 @@ "config.clip_model_name = clip_model_name\n", "ci = Interrogator(config)\n", "\n", - "def inference(image, mode, best_max_flavors=32):\n", + "def inference(image, mode):\n", " ci.config.chunk_size = 2048 if ci.config.clip_model_name == \"ViT-L-14/openai\" else 1024\n", " ci.config.flavor_intermediate_count = 2048 if ci.config.clip_model_name == \"ViT-L-14/openai\" else 1024\n", " image = image.convert('RGB')\n", + " prompt = \"\"\n", " if mode == 'best':\n", - " return ci.interrogate(image, max_flavors=int(best_max_flavors))\n", + " prompt = ci.interrogate(image)\n", " elif mode == 'classic':\n", - " return ci.interrogate_classic(image)\n", - " else:\n", - " return ci.interrogate_fast(image)\n" + " prompt = ci.interrogate_classic(image)\n", + " elif mode == 'fast':\n", + " prompt = ci.interrogate_fast(image)\n", + " elif mode == 'negative':\n", + " image_features = ci.image_to_features(image)\n", + " flaves = ci.flavors.rank(image_features, ci.config.flavor_intermediate_count, reverse=True)\n", + " flaves = flaves + ci.negative.labels\n", + " prompt = ci.chain(image_features, flaves, max_count=32, reverse=True, desc=\"Negative chain\")\n", + " sim = ci.similarity(ci.image_to_features(image), prompt)\n", + " return prompt, sim\n" ] }, { @@ -162,11 +160,11 @@ " \n", "inputs = [\n", " gr.inputs.Image(type='pil'),\n", - " gr.Radio(['best', 'fast'], label='', value='best'),\n", - " gr.Number(value=16, label='best mode max flavors'),\n", + " gr.Radio(['best', 'fast', 'negative'], label='Mode', value='best'),\n", "]\n", "outputs = [\n", " gr.outputs.Textbox(label=\"Output\"),\n", + " gr.Number(label=\"Alignment\"),\n", "]\n", "\n", "io = gr.Interface(\n", @@ -279,7 +277,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.15" + "version": "3.7.15 (default, Nov 24 2022, 18:44:54) [MSC v.1916 64 bit (AMD64)]" }, "orig_nbformat": 4, "vscode": { diff --git a/clip_interrogator/clip_interrogator.py b/clip_interrogator/clip_interrogator.py index a978e43..6078288 100644 --- a/clip_interrogator/clip_interrogator.py +++ b/clip_interrogator/clip_interrogator.py @@ -107,11 +107,51 @@ class Interrogator(): self.mediums = LabelTable(_load_list(config.data_path, 'mediums.txt'), "mediums", self.clip_model, self.tokenize, config) self.movements = LabelTable(_load_list(config.data_path, 'movements.txt'), "movements", self.clip_model, self.tokenize, config) self.trendings = LabelTable(trending_list, "trendings", self.clip_model, self.tokenize, config) + self.negative = LabelTable(_load_list(config.data_path, 'negative.txt'), "negative", self.clip_model, self.tokenize, config) end_time = time.time() if not config.quiet: print(f"Loaded CLIP model and data in {end_time-start_time:.2f} seconds.") + def chain( + self, + image_features: torch.Tensor, + phrases: List[str], + best_prompt: str="", + best_sim: float=0, + max_count: int=32, + desc="Chaining", + reverse: bool=False + ) -> str: + phrases = set(phrases) + if not best_prompt: + best_prompt = self.rank_top(image_features, [f for f in phrases], reverse=reverse) + best_sim = self.similarity(image_features, best_prompt) + phrases.remove(best_prompt) + + def check(addition: str) -> bool: + nonlocal best_prompt, best_sim + prompt = best_prompt + ", " + addition + sim = self.similarity(image_features, prompt) + if reverse: + sim = -sim + if sim > best_sim: + best_sim = sim + best_prompt = prompt + return True + return False + + for _ in tqdm(range(max_count), desc=desc, disable=self.config.quiet): + best = self.rank_top(image_features, [f"{best_prompt}, {f}" for f in phrases], reverse=reverse) + flave = best[len(best_prompt)+2:] + if not check(flave): + break + if _prompt_at_max_len(best_prompt, self.tokenize): + break + phrases.remove(flave) + + return best_prompt + def generate_caption(self, pil_image: Image) -> str: if self.config.blip_offload: self.blip_model = self.blip_model.to(self.device) @@ -204,24 +244,16 @@ class Interrogator(): check_multi_batch([best_medium, best_artist, best_trending, best_movement]) - extended_flavors = set(flaves) - for _ in tqdm(range(max_flavors), desc="Flavor chain", disable=self.config.quiet): - best = self.rank_top(image_features, [f"{best_prompt}, {f}" for f in extended_flavors]) - flave = best[len(best_prompt)+2:] - if not check(flave): - break - if _prompt_at_max_len(best_prompt, self.tokenize): - break - extended_flavors.remove(flave) - - return best_prompt + return self.chain(image_features, flaves, best_prompt, best_sim, max_count=max_flavors, desc="Flavor chain") - def rank_top(self, image_features: torch.Tensor, text_array: List[str]) -> str: + def rank_top(self, image_features: torch.Tensor, text_array: List[str], reverse: bool=False) -> str: text_tokens = self.tokenize([text for text in text_array]).to(self.device) with torch.no_grad(), torch.cuda.amp.autocast(): text_features = self.clip_model.encode_text(text_tokens) text_features /= text_features.norm(dim=-1, keepdim=True) similarity = text_features @ image_features.T + if reverse: + similarity = -similarity return text_array[similarity.argmax().item()] def similarity(self, image_features: torch.Tensor, text: str) -> float: @@ -283,17 +315,19 @@ class LabelTable(): if self.device == 'cpu' or self.device == torch.device('cpu'): self.embeds = [e.astype(np.float32) for e in self.embeds] - def _rank(self, image_features: torch.Tensor, text_embeds: torch.Tensor, top_count: int=1) -> str: + def _rank(self, image_features: torch.Tensor, text_embeds: torch.Tensor, top_count: int=1, reverse: bool=False) -> str: top_count = min(top_count, len(text_embeds)) text_embeds = torch.stack([torch.from_numpy(t) for t in text_embeds]).to(self.device) with torch.cuda.amp.autocast(): similarity = image_features @ text_embeds.T + if reverse: + similarity = -similarity _, top_labels = similarity.float().cpu().topk(top_count, dim=-1) return [top_labels[0][i].numpy() for i in range(top_count)] - def rank(self, image_features: torch.Tensor, top_count: int=1) -> List[str]: + def rank(self, image_features: torch.Tensor, top_count: int=1, reverse: bool=False) -> List[str]: if len(self.labels) <= self.chunk_size: - tops = self._rank(image_features, self.embeds, top_count=top_count) + tops = self._rank(image_features, self.embeds, top_count=top_count, reverse=reverse) return [self.labels[i] for i in tops] num_chunks = int(math.ceil(len(self.labels)/self.chunk_size)) @@ -303,7 +337,7 @@ class LabelTable(): for chunk_idx in tqdm(range(num_chunks), disable=self.config.quiet): start = chunk_idx*self.chunk_size stop = min(start+self.chunk_size, len(self.embeds)) - tops = self._rank(image_features, self.embeds[start:stop], top_count=keep_per_chunk) + tops = self._rank(image_features, self.embeds[start:stop], top_count=keep_per_chunk, reverse=reverse) top_labels.extend([self.labels[start+i] for i in tops]) top_embeds.extend([self.embeds[start+i] for i in tops]) diff --git a/clip_interrogator/data/negative.txt b/clip_interrogator/data/negative.txt new file mode 100644 index 0000000..7d39d47 --- /dev/null +++ b/clip_interrogator/data/negative.txt @@ -0,0 +1,41 @@ +3d +b&w +bad anatomy +bad art +blur +blurry +cartoon +childish +close up +deformed +disconnected limbs +disfigured +disgusting +extra limb +extra limbs +floating limbs +grain +illustration +kitsch +long body +long neck +low quality +low-res +malformed hands +mangled +missing limb +mutated +mutation +mutilated +noisy +old +out of focus +over saturation +oversaturated +poorly drawn +poorly drawn face +poorly drawn hands +render +surreal +ugly +weird colors \ No newline at end of file