diff --git a/clip_interrogator.ipynb b/clip_interrogator.ipynb
index 809ab9c..af5d8e4 100644
--- a/clip_interrogator.ipynb
+++ b/clip_interrogator.ipynb
@@ -1,29 +1,19 @@
{
"cells": [
{
+ "attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "3jm8RYrLqvzz"
},
"source": [
- "# CLIP Interrogator 2.2 by [@pharmapsychotic](https://twitter.com/pharmapsychotic) \n",
+ "# CLIP Interrogator 2.3 [negative prompt experiment!]\n",
"\n",
- "Want to figure out what a good prompt might be to create new images like an existing one? The CLIP Interrogator is here to get you answers!\n",
+ "This experimental version of CLIP Interrogator supports finding good \"negative\" prompts for Stable Diffusion 2. Note this is very *WIP* and more work needs to be done building out the dataset to support this (and perhaps a reverse BLIP) so for many images it may struggle to find a well aligned negative prompt. Alignments are displayed to help see how well it did.\n",
"\n",
"
\n",
"\n",
- "For Stable Diffusion 1.X choose the **ViT-L** model and for Stable Diffusion 2.0+ choose the **ViT-H** CLIP Model.\n",
- "\n",
- "This version is specialized for producing nice prompts for use with Stable Diffusion and achieves higher alignment between generated text prompt and source image. You can try out the old [version 1](https://colab.research.google.com/github/pharmapsychotic/clip-interrogator/blob/v1/clip_interrogator.ipynb) to see how different CLIP models ranks terms. \n",
- "\n",
- "You can also run this on HuggingFace and Replicate
\n",
- "[](https://huggingface.co/spaces/pharma/CLIP-Interrogator) [](https://replicate.com/pharmapsychotic/clip-interrogator)\n",
- "\n",
- "
\n",
- "\n",
- "If this notebook is helpful to you please consider buying me a coffee via [ko-fi](https://ko-fi.com/pharmapsychotic) or following me on [twitter](https://twitter.com/pharmapsychotic) for more cool Ai stuff. 🙂\n",
- "\n",
- "And if you're looking for more Ai art tools check out my [Ai generative art tools list](https://pharmapsychotic.com/tools.html).\n"
+ "For Stable Diffusion 1.X choose the **ViT-L** model and for Stable Diffusion 2.0+ choose the **ViT-H** CLIP Model.\n"
]
},
{
@@ -56,7 +46,7 @@
" ['pip', 'install', 'ftfy', 'gradio', 'regex', 'tqdm', 'transformers==4.21.2', 'timm', 'fairscale', 'requests'],\n",
" ['pip', 'install', 'open_clip_torch'],\n",
" ['pip', 'install', '-e', 'git+https://github.com/pharmapsychotic/BLIP.git@lib#egg=blip'],\n",
- " ['git', 'clone', 'https://github.com/pharmapsychotic/clip-interrogator.git']\n",
+ " ['git', 'clone', '-b', 'negative', 'https://github.com/pharmapsychotic/clip-interrogator.git']\n",
" ]\n",
" for cmd in install_cmds:\n",
" print(subprocess.run(cmd, stdout=subprocess.PIPE).stdout.decode('utf-8'))\n",
@@ -64,7 +54,7 @@
"setup()\n",
"\n",
"\n",
- "clip_model_name = 'ViT-L-14/openai' #@param [\"ViT-L-14/openai\", \"ViT-H-14/laion2b_s32b_b79k\"]\n",
+ "clip_model_name = 'ViT-H-14/laion2b_s32b_b79k' #@param [\"ViT-L-14/openai\", \"ViT-H-14/laion2b_s32b_b79k\"]\n",
"\n",
"\n",
"print(\"Download preprocessed cache files...\")\n",
@@ -99,16 +89,24 @@
"config.clip_model_name = clip_model_name\n",
"ci = Interrogator(config)\n",
"\n",
- "def inference(image, mode, best_max_flavors=32):\n",
+ "def inference(image, mode):\n",
" ci.config.chunk_size = 2048 if ci.config.clip_model_name == \"ViT-L-14/openai\" else 1024\n",
" ci.config.flavor_intermediate_count = 2048 if ci.config.clip_model_name == \"ViT-L-14/openai\" else 1024\n",
" image = image.convert('RGB')\n",
+ " prompt = \"\"\n",
" if mode == 'best':\n",
- " return ci.interrogate(image, max_flavors=int(best_max_flavors))\n",
+ " prompt = ci.interrogate(image)\n",
" elif mode == 'classic':\n",
- " return ci.interrogate_classic(image)\n",
- " else:\n",
- " return ci.interrogate_fast(image)\n"
+ " prompt = ci.interrogate_classic(image)\n",
+ " elif mode == 'fast':\n",
+ " prompt = ci.interrogate_fast(image)\n",
+ " elif mode == 'negative':\n",
+ " image_features = ci.image_to_features(image)\n",
+ " flaves = ci.flavors.rank(image_features, ci.config.flavor_intermediate_count, reverse=True)\n",
+ " flaves = flaves + ci.negative.labels\n",
+ " prompt = ci.chain(image_features, flaves, max_count=32, reverse=True, desc=\"Negative chain\")\n",
+ " sim = ci.similarity(ci.image_to_features(image), prompt)\n",
+ " return prompt, sim\n"
]
},
{
@@ -162,11 +160,11 @@
" \n",
"inputs = [\n",
" gr.inputs.Image(type='pil'),\n",
- " gr.Radio(['best', 'fast'], label='', value='best'),\n",
- " gr.Number(value=16, label='best mode max flavors'),\n",
+ " gr.Radio(['best', 'fast', 'negative'], label='Mode', value='best'),\n",
"]\n",
"outputs = [\n",
" gr.outputs.Textbox(label=\"Output\"),\n",
+ " gr.Number(label=\"Alignment\"),\n",
"]\n",
"\n",
"io = gr.Interface(\n",
@@ -279,7 +277,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.7.15"
+ "version": "3.7.15 (default, Nov 24 2022, 18:44:54) [MSC v.1916 64 bit (AMD64)]"
},
"orig_nbformat": 4,
"vscode": {
diff --git a/clip_interrogator/clip_interrogator.py b/clip_interrogator/clip_interrogator.py
index a978e43..6078288 100644
--- a/clip_interrogator/clip_interrogator.py
+++ b/clip_interrogator/clip_interrogator.py
@@ -107,11 +107,51 @@ class Interrogator():
self.mediums = LabelTable(_load_list(config.data_path, 'mediums.txt'), "mediums", self.clip_model, self.tokenize, config)
self.movements = LabelTable(_load_list(config.data_path, 'movements.txt'), "movements", self.clip_model, self.tokenize, config)
self.trendings = LabelTable(trending_list, "trendings", self.clip_model, self.tokenize, config)
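+        # table of common Stable Diffusion negative-prompt terms (data/negative.txt)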
+ self.negative = LabelTable(_load_list(config.data_path, 'negative.txt'), "negative", self.clip_model, self.tokenize, config)
end_time = time.time()
if not config.quiet:
print(f"Loaded CLIP model and data in {end_time-start_time:.2f} seconds.")
+
+    def chain(
+ self,
+ image_features: torch.Tensor,
+ phrases: List[str],
+ best_prompt: str="",
+ best_sim: float=0,
+ max_count: int=32,
+ desc="Chaining",
+ reverse: bool=False
+ ) -> str:
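+        # Greedily build a comma-separated prompt: seed with the best-ranked phrase
+        # (or worst, when reverse=True), then repeatedly append whichever remaining
+        # phrase most improves the (possibly negated) CLIP similarity, stopping once
+        # no phrase helps or the prompt hits the CLIP token limit.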
+ phrases = set(phrases)
+ if not best_prompt:
+            best_prompt = self.rank_top(image_features, list(phrases), reverse=reverse)
+ best_sim = self.similarity(image_features, best_prompt)
+ phrases.remove(best_prompt)
+
+ def check(addition: str) -> bool:
+ nonlocal best_prompt, best_sim
+ prompt = best_prompt + ", " + addition
+ sim = self.similarity(image_features, prompt)
+ if reverse:
+ sim = -sim
+ if sim > best_sim:
+ best_sim = sim
+ best_prompt = prompt
+ return True
+ return False
+
+ for _ in tqdm(range(max_count), desc=desc, disable=self.config.quiet):
+ best = self.rank_top(image_features, [f"{best_prompt}, {f}" for f in phrases], reverse=reverse)
+ flave = best[len(best_prompt)+2:]
+ if not check(flave):
+ break
+ if _prompt_at_max_len(best_prompt, self.tokenize):
+ break
+ phrases.remove(flave)
+
+ return best_prompt
+
def generate_caption(self, pil_image: Image) -> str:
if self.config.blip_offload:
self.blip_model = self.blip_model.to(self.device)
@@ -204,24 +244,16 @@ class Interrogator():
check_multi_batch([best_medium, best_artist, best_trending, best_movement])
- extended_flavors = set(flaves)
- for _ in tqdm(range(max_flavors), desc="Flavor chain", disable=self.config.quiet):
- best = self.rank_top(image_features, [f"{best_prompt}, {f}" for f in extended_flavors])
- flave = best[len(best_prompt)+2:]
- if not check(flave):
- break
- if _prompt_at_max_len(best_prompt, self.tokenize):
- break
- extended_flavors.remove(flave)
-
- return best_prompt
+ return self.chain(image_features, flaves, best_prompt, best_sim, max_count=max_flavors, desc="Flavor chain")
- def rank_top(self, image_features: torch.Tensor, text_array: List[str]) -> str:
+ def rank_top(self, image_features: torch.Tensor, text_array: List[str], reverse: bool=False) -> str:
text_tokens = self.tokenize([text for text in text_array]).to(self.device)
with torch.no_grad(), torch.cuda.amp.autocast():
text_features = self.clip_model.encode_text(text_tokens)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = text_features @ image_features.T
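+        # negate the scores so argmax selects the *least* similar text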
+ if reverse:
+ similarity = -similarity
return text_array[similarity.argmax().item()]
def similarity(self, image_features: torch.Tensor, text: str) -> float:
@@ -283,17 +315,19 @@ class LabelTable():
if self.device == 'cpu' or self.device == torch.device('cpu'):
self.embeds = [e.astype(np.float32) for e in self.embeds]
- def _rank(self, image_features: torch.Tensor, text_embeds: torch.Tensor, top_count: int=1) -> str:
+    def _rank(self, image_features: torch.Tensor, text_embeds: torch.Tensor, top_count: int=1, reverse: bool=False) -> List[int]:
top_count = min(top_count, len(text_embeds))
text_embeds = torch.stack([torch.from_numpy(t) for t in text_embeds]).to(self.device)
with torch.cuda.amp.autocast():
similarity = image_features @ text_embeds.T
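+            # negate so topk returns the least similar labels instead of the most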
+ if reverse:
+ similarity = -similarity
_, top_labels = similarity.float().cpu().topk(top_count, dim=-1)
return [top_labels[0][i].numpy() for i in range(top_count)]
- def rank(self, image_features: torch.Tensor, top_count: int=1) -> List[str]:
+ def rank(self, image_features: torch.Tensor, top_count: int=1, reverse: bool=False) -> List[str]:
if len(self.labels) <= self.chunk_size:
- tops = self._rank(image_features, self.embeds, top_count=top_count)
+ tops = self._rank(image_features, self.embeds, top_count=top_count, reverse=reverse)
return [self.labels[i] for i in tops]
num_chunks = int(math.ceil(len(self.labels)/self.chunk_size))
@@ -303,7 +337,7 @@ class LabelTable():
for chunk_idx in tqdm(range(num_chunks), disable=self.config.quiet):
start = chunk_idx*self.chunk_size
stop = min(start+self.chunk_size, len(self.embeds))
- tops = self._rank(image_features, self.embeds[start:stop], top_count=keep_per_chunk)
+ tops = self._rank(image_features, self.embeds[start:stop], top_count=keep_per_chunk, reverse=reverse)
top_labels.extend([self.labels[start+i] for i in tops])
top_embeds.extend([self.embeds[start+i] for i in tops])
diff --git a/clip_interrogator/data/negative.txt b/clip_interrogator/data/negative.txt
new file mode 100644
index 0000000..7d39d47
--- /dev/null
+++ b/clip_interrogator/data/negative.txt
@@ -0,0 +1,41 @@
+3d
+b&w
+bad anatomy
+bad art
+blur
+blurry
+cartoon
+childish
+close up
+deformed
+disconnected limbs
+disfigured
+disgusting
+extra limb
+extra limbs
+floating limbs
+grain
+illustration
+kitsch
+long body
+long neck
+low quality
+low-res
+malformed hands
+mangled
+missing limb
+mutated
+mutation
+mutilated
+noisy
+old
+out of focus
+over saturation
+oversaturated
+poorly drawn
+poorly drawn face
+poorly drawn hands
+render
+surreal
+ugly
+weird colors
\ No newline at end of file