
Experimental negative prompt generator

Branch: pull/40/head
pharmapsychotic committed 2 years ago
commit 93b7ad0fbc
  1. clip_interrogator.ipynb (46)
  2. clip_interrogator/clip_interrogator.py (66)
  3. clip_interrogator/data/negative.txt (41)

clip_interrogator.ipynb (46)

@@ -1,29 +1,19 @@
 {
 "cells": [
 {
+"attachments": {},
 "cell_type": "markdown",
 "metadata": {
 "id": "3jm8RYrLqvzz"
 },
 "source": [
-"# CLIP Interrogator 2.2 by [@pharmapsychotic](https://twitter.com/pharmapsychotic) \n",
+"# CLIP Interrogator 2.3 [negative prompt experiment!]\n",
 "\n",
-"Want to figure out what a good prompt might be to create new images like an existing one? The CLIP Interrogator is here to get you answers!\n",
+"This experimental version of CLIP Interrogator supports finding good \"negative\" prompts for Stable Diffusion 2. Note this is very *WIP* and more work needs to be done building out the dataset to support this (and perhaps a reverse BLIP) so for many images it may struggle to find a well aligned negative prompt. Alignments are displayed to help see how well it did.\n",
 "\n",
 "<br>\n",
 "\n",
-"For Stable Diffusion 1.X choose the **ViT-L** model and for Stable Diffusion 2.0+ choose the **ViT-H** CLIP Model.\n",
-"\n",
-"This version is specialized for producing nice prompts for use with Stable Diffusion and achieves higher alignment between generated text prompt and source image. You can try out the old [version 1](https://colab.research.google.com/github/pharmapsychotic/clip-interrogator/blob/v1/clip_interrogator.ipynb) to see how different CLIP models ranks terms. \n",
-"\n",
-"You can also run this on HuggingFace and Replicate<br>\n",
-"[![Generic badge](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue.svg)](https://huggingface.co/spaces/pharma/CLIP-Interrogator) [![Replicate](https://replicate.com/pharmapsychotic/clip-interrogator/badge)](https://replicate.com/pharmapsychotic/clip-interrogator)\n",
-"\n",
-"<br>\n",
-"\n",
-"If this notebook is helpful to you please consider buying me a coffee via [ko-fi](https://ko-fi.com/pharmapsychotic) or following me on [twitter](https://twitter.com/pharmapsychotic) for more cool Ai stuff. 🙂\n",
-"\n",
-"And if you're looking for more Ai art tools check out my [Ai generative art tools list](https://pharmapsychotic.com/tools.html).\n"
+"For Stable Diffusion 1.X choose the **ViT-L** model and for Stable Diffusion 2.0+ choose the **ViT-H** CLIP Model.\n"
 ]
 },
 {
@@ -56,7 +46,7 @@
 " ['pip', 'install', 'ftfy', 'gradio', 'regex', 'tqdm', 'transformers==4.21.2', 'timm', 'fairscale', 'requests'],\n",
 " ['pip', 'install', 'open_clip_torch'],\n",
 " ['pip', 'install', '-e', 'git+https://github.com/pharmapsychotic/BLIP.git@lib#egg=blip'],\n",
-" ['git', 'clone', 'https://github.com/pharmapsychotic/clip-interrogator.git']\n",
+" ['git', 'clone', '-b', 'negative', 'https://github.com/pharmapsychotic/clip-interrogator.git']\n",
 " ]\n",
 " for cmd in install_cmds:\n",
 " print(subprocess.run(cmd, stdout=subprocess.PIPE).stdout.decode('utf-8'))\n",
@@ -64,7 +54,7 @@
 "setup()\n",
 "\n",
 "\n",
-"clip_model_name = 'ViT-L-14/openai' #@param [\"ViT-L-14/openai\", \"ViT-H-14/laion2b_s32b_b79k\"]\n",
+"clip_model_name = 'ViT-H-14/laion2b_s32b_b79k' #@param [\"ViT-L-14/openai\", \"ViT-H-14/laion2b_s32b_b79k\"]\n",
 "\n",
 "\n",
 "print(\"Download preprocessed cache files...\")\n",
@@ -99,16 +89,24 @@
 "config.clip_model_name = clip_model_name\n",
 "ci = Interrogator(config)\n",
 "\n",
-"def inference(image, mode, best_max_flavors=32):\n",
+"def inference(image, mode):\n",
 " ci.config.chunk_size = 2048 if ci.config.clip_model_name == \"ViT-L-14/openai\" else 1024\n",
 " ci.config.flavor_intermediate_count = 2048 if ci.config.clip_model_name == \"ViT-L-14/openai\" else 1024\n",
 " image = image.convert('RGB')\n",
+" prompt = \"\"\n",
 " if mode == 'best':\n",
-" return ci.interrogate(image, max_flavors=int(best_max_flavors))\n",
+" prompt = ci.interrogate(image)\n",
 " elif mode == 'classic':\n",
-" return ci.interrogate_classic(image)\n",
-" else:\n",
-" return ci.interrogate_fast(image)\n"
+" prompt = ci.interrogate_classic(image)\n",
+" elif mode == 'fast':\n",
+" prompt = ci.interrogate_fast(image)\n",
+" elif mode == 'negative':\n",
+" image_features = ci.image_to_features(image)\n",
+" flaves = ci.flavors.rank(image_features, ci.config.flavor_intermediate_count, reverse=True)\n",
+" flaves = flaves + ci.negative.labels\n",
+" prompt = ci.chain(image_features, flaves, max_count=32, reverse=True, desc=\"Negative chain\")\n",
+" sim = ci.similarity(ci.image_to_features(image), prompt)\n",
+" return prompt, sim\n"
 ]
 },
 {
@@ -162,11 +160,11 @@
 " \n",
 "inputs = [\n",
 " gr.inputs.Image(type='pil'),\n",
-" gr.Radio(['best', 'fast'], label='', value='best'),\n",
-" gr.Number(value=16, label='best mode max flavors'),\n",
+" gr.Radio(['best', 'fast', 'negative'], label='Mode', value='best'),\n",
 "]\n",
 "outputs = [\n",
 " gr.outputs.Textbox(label=\"Output\"),\n",
+" gr.Number(label=\"Alignment\"),\n",
 "]\n",
 "\n",
 "io = gr.Interface(\n",
@@ -279,7 +277,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.7.15"
+"version": "3.7.15 (default, Nov 24 2022, 18:44:54) [MSC v.1916 64 bit (AMD64)]"
 },
 "orig_nbformat": 4,
 "vscode": {

clip_interrogator/clip_interrogator.py (66)

@@ -107,11 +107,51 @@ class Interrogator():
         self.mediums = LabelTable(_load_list(config.data_path, 'mediums.txt'), "mediums", self.clip_model, self.tokenize, config)
         self.movements = LabelTable(_load_list(config.data_path, 'movements.txt'), "movements", self.clip_model, self.tokenize, config)
         self.trendings = LabelTable(trending_list, "trendings", self.clip_model, self.tokenize, config)
+        self.negative = LabelTable(_load_list(config.data_path, 'negative.txt'), "negative", self.clip_model, self.tokenize, config)

         end_time = time.time()
         if not config.quiet:
             print(f"Loaded CLIP model and data in {end_time-start_time:.2f} seconds.")

+    def chain(
+        self,
+        image_features: torch.Tensor,
+        phrases: List[str],
+        best_prompt: str="",
+        best_sim: float=0,
+        max_count: int=32,
+        desc="Chaining",
+        reverse: bool=False
+    ) -> str:
+        phrases = set(phrases)
+        if not best_prompt:
+            best_prompt = self.rank_top(image_features, [f for f in phrases], reverse=reverse)
+            best_sim = self.similarity(image_features, best_prompt)
+            phrases.remove(best_prompt)
+
+        def check(addition: str) -> bool:
+            nonlocal best_prompt, best_sim
+            prompt = best_prompt + ", " + addition
+            sim = self.similarity(image_features, prompt)
+            if reverse:
+                sim = -sim
+            if sim > best_sim:
+                best_sim = sim
+                best_prompt = prompt
+                return True
+            return False
+
+        for _ in tqdm(range(max_count), desc=desc, disable=self.config.quiet):
+            best = self.rank_top(image_features, [f"{best_prompt}, {f}" for f in phrases], reverse=reverse)
+            flave = best[len(best_prompt)+2:]
+            if not check(flave):
+                break
+            if _prompt_at_max_len(best_prompt, self.tokenize):
+                break
+            phrases.remove(flave)
+
+        return best_prompt
+
     def generate_caption(self, pil_image: Image) -> str:
         if self.config.blip_offload:
             self.blip_model = self.blip_model.to(self.device)
@@ -204,24 +244,16 @@ class Interrogator():
         check_multi_batch([best_medium, best_artist, best_trending, best_movement])

-        extended_flavors = set(flaves)
-        for _ in tqdm(range(max_flavors), desc="Flavor chain", disable=self.config.quiet):
-            best = self.rank_top(image_features, [f"{best_prompt}, {f}" for f in extended_flavors])
-            flave = best[len(best_prompt)+2:]
-            if not check(flave):
-                break
-            if _prompt_at_max_len(best_prompt, self.tokenize):
-                break
-            extended_flavors.remove(flave)
-
-        return best_prompt
+        return self.chain(image_features, flaves, best_prompt, best_sim, max_count=max_flavors, desc="Flavor chain")

-    def rank_top(self, image_features: torch.Tensor, text_array: List[str]) -> str:
+    def rank_top(self, image_features: torch.Tensor, text_array: List[str], reverse: bool=False) -> str:
         text_tokens = self.tokenize([text for text in text_array]).to(self.device)
         with torch.no_grad(), torch.cuda.amp.autocast():
             text_features = self.clip_model.encode_text(text_tokens)
             text_features /= text_features.norm(dim=-1, keepdim=True)
             similarity = text_features @ image_features.T
+            if reverse:
+                similarity = -similarity
             return text_array[similarity.argmax().item()]

     def similarity(self, image_features: torch.Tensor, text: str) -> float:
@@ -283,17 +315,19 @@ class LabelTable():
         if self.device == 'cpu' or self.device == torch.device('cpu'):
             self.embeds = [e.astype(np.float32) for e in self.embeds]

-    def _rank(self, image_features: torch.Tensor, text_embeds: torch.Tensor, top_count: int=1) -> str:
+    def _rank(self, image_features: torch.Tensor, text_embeds: torch.Tensor, top_count: int=1, reverse: bool=False) -> str:
         top_count = min(top_count, len(text_embeds))
         text_embeds = torch.stack([torch.from_numpy(t) for t in text_embeds]).to(self.device)
         with torch.cuda.amp.autocast():
             similarity = image_features @ text_embeds.T
+            if reverse:
+                similarity = -similarity
             _, top_labels = similarity.float().cpu().topk(top_count, dim=-1)
             return [top_labels[0][i].numpy() for i in range(top_count)]

-    def rank(self, image_features: torch.Tensor, top_count: int=1) -> List[str]:
+    def rank(self, image_features: torch.Tensor, top_count: int=1, reverse: bool=False) -> List[str]:
         if len(self.labels) <= self.chunk_size:
-            tops = self._rank(image_features, self.embeds, top_count=top_count)
+            tops = self._rank(image_features, self.embeds, top_count=top_count, reverse=reverse)
             return [self.labels[i] for i in tops]

         num_chunks = int(math.ceil(len(self.labels)/self.chunk_size))
@@ -303,7 +337,7 @@ class LabelTable():
         for chunk_idx in tqdm(range(num_chunks), disable=self.config.quiet):
             start = chunk_idx*self.chunk_size
             stop = min(start+self.chunk_size, len(self.embeds))
-            tops = self._rank(image_features, self.embeds[start:stop], top_count=keep_per_chunk)
+            tops = self._rank(image_features, self.embeds[start:stop], top_count=keep_per_chunk, reverse=reverse)
             top_labels.extend([self.labels[start+i] for i in tops])
             top_embeds.extend([self.embeds[start+i] for i in tops])
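
The key mechanism in this file is the `reverse` flag: `rank`, `rank_top`, and `chain` all negate the cosine similarity when it is set, so the same greedy chaining used for regular prompts selects the terms an image is *least* aligned with. A sketch of how the new pieces compose into a negative prompt (this mirrors the notebook's 'negative' mode; `ci` is an `Interrogator` and `image` a PIL image created elsewhere):

```python
# Sketch of the negative-prompt flow built from the methods added above.
# Assumes `ci` is an Interrogator and `image` is a PIL.Image set up elsewhere.
image_features = ci.image_to_features(image)

# reverse=True negates the similarity, so rank() returns the flavor terms the
# image matches *worst* rather than best.
flaves = ci.flavors.rank(image_features, ci.config.flavor_intermediate_count, reverse=True)
flaves = flaves + ci.negative.labels  # add the curated terms from negative.txt

# chain() greedily appends whichever candidate most improves the (negated)
# similarity, stopping when nothing helps or the prompt hits the token limit.
negative_prompt = ci.chain(image_features, flaves, max_count=32, reverse=True, desc="Negative chain")
print(negative_prompt, ci.similarity(image_features, negative_prompt))
```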

clip_interrogator/data/negative.txt (41)

@@ -0,0 +1,41 @@
3d
b&w
bad anatomy
bad art
blur
blurry
cartoon
childish
close up
deformed
disconnected limbs
disfigured
disgusting
extra limb
extra limbs
floating limbs
grain
illustration
kitsch
long body
long neck
low quality
low-res
malformed hands
mangled
missing limb
mutated
mutation
mutilated
noisy
old
out of focus
over saturation
oversaturated
poorly drawn
poorly drawn face
poorly drawn hands
render
surreal
ugly
weird colors
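
These terms are loaded into the new `negative` LabelTable at startup (see the `Interrogator` change above). A small sketch of ranking them against an image directly, using only the methods shown in this diff (the image path is hypothetical):

```python
# Sketch: scoring the negative.txt terms against a single image.
# Assumes `ci` is an Interrogator; "example.jpg" is a hypothetical path.
from PIL import Image

image = Image.open("example.jpg").convert('RGB')
image_features = ci.image_to_features(image)

# ci.negative is the LabelTable built from this file; reverse=True surfaces
# the terms the image is least aligned with, matching the notebook's
# 'negative' mode.
worst_fits = ci.negative.rank(image_features, top_count=10, reverse=True)
print(", ".join(worst_fits))
```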