Browse Source

Bunch of updates! (#40)

- auto download the cache files from huggingface
- experimental negative prompt mode
- slight quality and performance improvement to best mode
- analyze tab in Colab and run_gradio to get table of ranked terms
pull/43/head
pharmapsychotic 2 years ago committed by GitHub
parent
commit
42b3cf4d9e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      README.md
  2. 103
      clip_interrogator.ipynb
  3. 2
      clip_interrogator/__init__.py
  4. 175
      clip_interrogator/clip_interrogator.py
  5. 41
      clip_interrogator/data/negative.txt
  6. 84
      run_gradio.py
  7. 2
      setup.py

2
README.md

@ -36,7 +36,7 @@ Install with PIP
pip3 install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu117 pip3 install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu117
# install clip-interrogator # install clip-interrogator
pip install clip-interrogator==0.3.5 pip install clip-interrogator==0.4.0
``` ```
You can then use it in your script You can then use it in your script

103
clip_interrogator.ipynb

@ -1,12 +1,13 @@
{ {
"cells": [ "cells": [
{ {
"attachments": {},
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"id": "3jm8RYrLqvzz" "id": "3jm8RYrLqvzz"
}, },
"source": [ "source": [
"# CLIP Interrogator 2.2 by [@pharmapsychotic](https://twitter.com/pharmapsychotic) \n", "# CLIP Interrogator 2.3 by [@pharmapsychotic](https://twitter.com/pharmapsychotic) \n",
"\n", "\n",
"Want to figure out what a good prompt might be to create new images like an existing one? The CLIP Interrogator is here to get you answers!\n", "Want to figure out what a good prompt might be to create new images like an existing one? The CLIP Interrogator is here to get you answers!\n",
"\n", "\n",
@ -56,7 +57,6 @@
" ['pip', 'install', 'gradio'],\n", " ['pip', 'install', 'gradio'],\n",
" ['pip', 'install', 'open_clip_torch'],\n", " ['pip', 'install', 'open_clip_torch'],\n",
" ['pip', 'install', 'clip-interrogator'],\n", " ['pip', 'install', 'clip-interrogator'],\n",
" ['pip', 'install', 'git+https://github.com/pharmapsychotic/BLIP.git'],\n",
" ]\n", " ]\n",
" for cmd in install_cmds:\n", " for cmd in install_cmds:\n",
" print(subprocess.run(cmd, stdout=subprocess.PIPE).stdout.decode('utf-8'))\n", " print(subprocess.run(cmd, stdout=subprocess.PIPE).stdout.decode('utf-8'))\n",
@ -67,25 +67,6 @@
"clip_model_name = 'ViT-L-14/openai' #@param [\"ViT-L-14/openai\", \"ViT-H-14/laion2b_s32b_b79k\"]\n", "clip_model_name = 'ViT-L-14/openai' #@param [\"ViT-L-14/openai\", \"ViT-H-14/laion2b_s32b_b79k\"]\n",
"\n", "\n",
"\n", "\n",
"print(\"Download preprocessed cache files...\")\n",
"CACHE_URLS = [\n",
" 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_artists.pkl',\n",
" 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_flavors.pkl',\n",
" 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_mediums.pkl',\n",
" 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_movements.pkl',\n",
" 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_trendings.pkl',\n",
"] if clip_model_name == 'ViT-L-14/openai' else [\n",
" 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_artists.pkl',\n",
" 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_flavors.pkl',\n",
" 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_mediums.pkl',\n",
" 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_movements.pkl',\n",
" 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_trendings.pkl',\n",
"]\n",
"os.makedirs('cache', exist_ok=True)\n",
"for url in CACHE_URLS:\n",
" print(subprocess.run(['wget', url, '-P', 'cache'], stdout=subprocess.PIPE).stdout.decode('utf-8'))\n",
"\n",
"\n",
"import gradio as gr\n", "import gradio as gr\n",
"from clip_interrogator import Config, Interrogator\n", "from clip_interrogator import Config, Interrogator\n",
"\n", "\n",
@ -95,16 +76,37 @@
"config.clip_model_name = clip_model_name\n", "config.clip_model_name = clip_model_name\n",
"ci = Interrogator(config)\n", "ci = Interrogator(config)\n",
"\n", "\n",
"def inference(image, mode, best_max_flavors=32):\n", "def image_analysis(image):\n",
" image = image.convert('RGB')\n",
" image_features = ci.image_to_features(image)\n",
"\n",
" top_mediums = ci.mediums.rank(image_features, 5)\n",
" top_artists = ci.artists.rank(image_features, 5)\n",
" top_movements = ci.movements.rank(image_features, 5)\n",
" top_trendings = ci.trendings.rank(image_features, 5)\n",
" top_flavors = ci.flavors.rank(image_features, 5)\n",
"\n",
" medium_ranks = {medium: sim for medium, sim in zip(top_mediums, ci.similarities(image_features, top_mediums))}\n",
" artist_ranks = {artist: sim for artist, sim in zip(top_artists, ci.similarities(image_features, top_artists))}\n",
" movement_ranks = {movement: sim for movement, sim in zip(top_movements, ci.similarities(image_features, top_movements))}\n",
" trending_ranks = {trending: sim for trending, sim in zip(top_trendings, ci.similarities(image_features, top_trendings))}\n",
" flavor_ranks = {flavor: sim for flavor, sim in zip(top_flavors, ci.similarities(image_features, top_flavors))}\n",
" \n",
" return medium_ranks, artist_ranks, movement_ranks, trending_ranks, flavor_ranks\n",
"\n",
"def image_to_prompt(image, mode):\n",
" ci.config.chunk_size = 2048 if ci.config.clip_model_name == \"ViT-L-14/openai\" else 1024\n", " ci.config.chunk_size = 2048 if ci.config.clip_model_name == \"ViT-L-14/openai\" else 1024\n",
" ci.config.flavor_intermediate_count = 2048 if ci.config.clip_model_name == \"ViT-L-14/openai\" else 1024\n", " ci.config.flavor_intermediate_count = 2048 if ci.config.clip_model_name == \"ViT-L-14/openai\" else 1024\n",
" image = image.convert('RGB')\n", " image = image.convert('RGB')\n",
" if mode == 'best':\n", " if mode == 'best':\n",
" return ci.interrogate(image, max_flavors=int(best_max_flavors))\n", " return ci.interrogate(image)\n",
" elif mode == 'classic':\n", " elif mode == 'classic':\n",
" return ci.interrogate_classic(image)\n", " return ci.interrogate_classic(image)\n",
" else:\n", " elif mode == 'fast':\n",
" return ci.interrogate_fast(image)\n" " return ci.interrogate_fast(image)\n",
" elif mode == 'negative':\n",
" return ci.interrogate_negative(image)\n",
" "
] ]
}, },
{ {
@ -156,22 +158,36 @@
"source": [ "source": [
"#@title Image to prompt! 🖼 -> 📝\n", "#@title Image to prompt! 🖼 -> 📝\n",
" \n", " \n",
"inputs = [\n", "def prompt_tab():\n",
" gr.inputs.Image(type='pil'),\n", " with gr.Column():\n",
" gr.Radio(['best', 'fast'], label='', value='best'),\n", " with gr.Row():\n",
" gr.Number(value=16, label='best mode max flavors'),\n", " image = gr.Image(type='pil', label=\"Image\")\n",
"]\n", " with gr.Column():\n",
"outputs = [\n", " mode = gr.Radio(['best', 'fast', 'classic', 'negative'], label='Mode', value='best')\n",
" gr.outputs.Textbox(label=\"Output\"),\n", " prompt = gr.Textbox(label=\"Prompt\")\n",
"]\n", " button = gr.Button(\"Generate prompt\")\n",
"\n", " button.click(image_to_prompt, inputs=[image, mode], outputs=prompt)\n",
"io = gr.Interface(\n", "\n",
" inference, \n", "def analyze_tab():\n",
" inputs, \n", " with gr.Column():\n",
" outputs, \n", " with gr.Row():\n",
" allow_flagging=False,\n", " image = gr.Image(type='pil', label=\"Image\")\n",
")\n", " with gr.Row():\n",
"io.launch(debug=False)\n" " medium = gr.Label(label=\"Medium\", num_top_classes=5)\n",
" artist = gr.Label(label=\"Artist\", num_top_classes=5) \n",
" movement = gr.Label(label=\"Movement\", num_top_classes=5)\n",
" trending = gr.Label(label=\"Trending\", num_top_classes=5)\n",
" flavor = gr.Label(label=\"Flavor\", num_top_classes=5)\n",
" button = gr.Button(\"Analyze\")\n",
" button.click(image_analysis, inputs=image, outputs=[medium, artist, movement, trending, flavor])\n",
"\n",
"with gr.Blocks() as ui:\n",
" with gr.Tab(\"Prompt\"):\n",
" prompt_tab()\n",
" with gr.Tab(\"Analyze\"):\n",
" analyze_tab()\n",
"\n",
"ui.launch(show_api=False, debug=False)\n"
] ]
}, },
{ {
@ -198,10 +214,9 @@
"from tqdm import tqdm\n", "from tqdm import tqdm\n",
"\n", "\n",
"folder_path = \"/content/my_images\" #@param {type:\"string\"}\n", "folder_path = \"/content/my_images\" #@param {type:\"string\"}\n",
"prompt_mode = 'best' #@param [\"best\",\"fast\"]\n", "prompt_mode = 'best' #@param [\"best\",\"fast\",\"classic\",\"negative\"]\n",
"output_mode = 'rename' #@param [\"desc.csv\",\"rename\"]\n", "output_mode = 'rename' #@param [\"desc.csv\",\"rename\"]\n",
"max_filename_len = 128 #@param {type:\"integer\"}\n", "max_filename_len = 128 #@param {type:\"integer\"}\n",
"best_max_flavors = 16 #@param {type:\"integer\"}\n",
"\n", "\n",
"\n", "\n",
"def sanitize_for_filename(prompt: str, max_len: int) -> str:\n", "def sanitize_for_filename(prompt: str, max_len: int) -> str:\n",
@ -218,7 +233,7 @@
" clear_output(wait=True)\n", " clear_output(wait=True)\n",
"\n", "\n",
" image = Image.open(os.path.join(folder_path, file)).convert('RGB')\n", " image = Image.open(os.path.join(folder_path, file)).convert('RGB')\n",
" prompt = inference(image, prompt_mode, best_max_flavors=best_max_flavors)\n", " prompt = image_to_prompt(image, prompt_mode)\n",
" prompts.append(prompt)\n", " prompts.append(prompt)\n",
"\n", "\n",
" print(prompt)\n", " print(prompt)\n",

2
clip_interrogator/__init__.py

@ -1,4 +1,4 @@
from .clip_interrogator import Interrogator, Config from .clip_interrogator import Interrogator, Config
__version__ = '0.3.5' __version__ = '0.4.0'
__author__ = 'pharmapsychotic' __author__ = 'pharmapsychotic'

175
clip_interrogator/clip_interrogator.py

@ -5,6 +5,7 @@ import numpy as np
import open_clip import open_clip
import os import os
import pickle import pickle
import requests
import time import time
import torch import torch
@ -21,6 +22,23 @@ BLIP_MODELS = {
'large': 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth' 'large': 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
} }
CACHE_URLS_VITL = [
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_artists.pkl',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_flavors.pkl',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_mediums.pkl',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_movements.pkl',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_trendings.pkl',
]
CACHE_URLS_VITH = [
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_artists.pkl',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_flavors.pkl',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_mediums.pkl',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_movements.pkl',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_trendings.pkl',
]
@dataclass @dataclass
class Config: class Config:
# models can optionally be passed in directly # models can optionally be passed in directly
@ -40,13 +58,15 @@ class Config:
clip_model_path: str = None clip_model_path: str = None
# interrogator settings # interrogator settings
cache_path: str = 'cache' cache_path: str = 'cache' # path to store cached text embeddings
chunk_size: int = 2048 download_cache: bool = True # when true, cached embeds are downloaded from huggingface
chunk_size: int = 2048 # batch size for CLIP, use smaller for lower VRAM
data_path: str = os.path.join(os.path.dirname(__file__), 'data') data_path: str = os.path.join(os.path.dirname(__file__), 'data')
device: str = ("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu") device: str = ("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
flavor_intermediate_count: int = 2048 flavor_intermediate_count: int = 2048
quiet: bool = False # when quiet progress bars are not shown quiet: bool = False # when quiet progress bars are not shown
class Interrogator(): class Interrogator():
def __init__(self, config: Config): def __init__(self, config: Config):
self.config = config self.config = config
@ -72,6 +92,21 @@ class Interrogator():
self.load_clip_model() self.load_clip_model()
def download_cache(self, clip_model_name: str):
if clip_model_name == 'ViT-L-14/openai':
cache_urls = CACHE_URLS_VITL
elif clip_model_name == 'ViT-H-14/laion2b_s32b_b79k':
cache_urls = CACHE_URLS_VITH
else:
# text embeddings will be precomputed and cached locally
return
os.makedirs(self.config.cache_path, exist_ok=True)
for url in cache_urls:
filepath = os.path.join(self.config.cache_path, url.split('/')[-1])
if not os.path.exists(filepath):
_download_file(url, filepath, quiet=self.config.quiet)
def load_clip_model(self): def load_clip_model(self):
start_time = time.time() start_time = time.time()
config = self.config config = self.config
@ -105,16 +140,58 @@ class Interrogator():
artists = [f"by {a}" for a in raw_artists] artists = [f"by {a}" for a in raw_artists]
artists.extend([f"inspired by {a}" for a in raw_artists]) artists.extend([f"inspired by {a}" for a in raw_artists])
self.download_cache(config.clip_model_name)
self.artists = LabelTable(artists, "artists", self.clip_model, self.tokenize, config) self.artists = LabelTable(artists, "artists", self.clip_model, self.tokenize, config)
self.flavors = LabelTable(_load_list(config.data_path, 'flavors.txt'), "flavors", self.clip_model, self.tokenize, config) self.flavors = LabelTable(_load_list(config.data_path, 'flavors.txt'), "flavors", self.clip_model, self.tokenize, config)
self.mediums = LabelTable(_load_list(config.data_path, 'mediums.txt'), "mediums", self.clip_model, self.tokenize, config) self.mediums = LabelTable(_load_list(config.data_path, 'mediums.txt'), "mediums", self.clip_model, self.tokenize, config)
self.movements = LabelTable(_load_list(config.data_path, 'movements.txt'), "movements", self.clip_model, self.tokenize, config) self.movements = LabelTable(_load_list(config.data_path, 'movements.txt'), "movements", self.clip_model, self.tokenize, config)
self.trendings = LabelTable(trending_list, "trendings", self.clip_model, self.tokenize, config) self.trendings = LabelTable(trending_list, "trendings", self.clip_model, self.tokenize, config)
self.negative = LabelTable(_load_list(config.data_path, 'negative.txt'), "negative", self.clip_model, self.tokenize, config)
end_time = time.time() end_time = time.time()
if not config.quiet: if not config.quiet:
print(f"Loaded CLIP model and data in {end_time-start_time:.2f} seconds.") print(f"Loaded CLIP model and data in {end_time-start_time:.2f} seconds.")
def chain(
self,
image_features: torch.Tensor,
phrases: List[str],
best_prompt: str="",
best_sim: float=0,
max_count: int=32,
desc="Chaining",
reverse: bool=False
) -> str:
phrases = set(phrases)
if not best_prompt:
best_prompt = self.rank_top(image_features, [f for f in phrases], reverse=reverse)
best_sim = self.similarity(image_features, best_prompt)
phrases.remove(best_prompt)
def check(addition: str) -> bool:
nonlocal best_prompt, best_sim
prompt = best_prompt + ", " + addition
sim = self.similarity(image_features, prompt)
if reverse:
sim = -sim
if sim > best_sim:
best_sim = sim
best_prompt = prompt
return True
return False
for _ in tqdm(range(max_count), desc=desc, disable=self.config.quiet):
best = self.rank_top(image_features, [f"{best_prompt}, {f}" for f in phrases], reverse=reverse)
flave = best[len(best_prompt)+2:]
if not check(flave):
break
if _prompt_at_max_len(best_prompt, self.tokenize):
break
phrases.remove(flave)
return best_prompt
def generate_caption(self, pil_image: Image) -> str: def generate_caption(self, pil_image: Image) -> str:
if self.config.blip_offload: if self.config.blip_offload:
self.blip_model = self.blip_model.to(self.device) self.blip_model = self.blip_model.to(self.device)
@ -145,6 +222,8 @@ class Interrogator():
return image_features return image_features
def interrogate_classic(self, image: Image, max_flavors: int=3) -> str: def interrogate_classic(self, image: Image, max_flavors: int=3) -> str:
"""Classic mode creates a prompt in a standard format first describing the image,
then listing the artist, trending, movement, and flavor text modifiers."""
caption = self.generate_caption(image) caption = self.generate_caption(image)
image_features = self.image_to_features(image) image_features = self.image_to_features(image)
@ -162,69 +241,43 @@ class Interrogator():
return _truncate_to_fit(prompt, self.tokenize) return _truncate_to_fit(prompt, self.tokenize)
def interrogate_fast(self, image: Image, max_flavors: int = 32) -> str: def interrogate_fast(self, image: Image, max_flavors: int = 32) -> str:
"""Fast mode simply adds the top ranked terms after a caption. It generally results in
better similarity between generated prompt and image than classic mode, but the prompts
are less readable."""
caption = self.generate_caption(image) caption = self.generate_caption(image)
image_features = self.image_to_features(image) image_features = self.image_to_features(image)
merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config) merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config)
tops = merged.rank(image_features, max_flavors) tops = merged.rank(image_features, max_flavors)
return _truncate_to_fit(caption + ", " + ", ".join(tops), self.tokenize) return _truncate_to_fit(caption + ", " + ", ".join(tops), self.tokenize)
def interrogate_negative(self, image: Image, max_flavors: int = 32) -> str:
"""Negative mode chains together the most dissimilar terms to the image. It can be used
to help build a negative prompt to pair with the regular positive prompt and often
improve the results of generated images particularly with Stable Diffusion 2."""
image_features = self.image_to_features(image)
flaves = self.flavors.rank(image_features, self.config.flavor_intermediate_count, reverse=True)
flaves = flaves + self.negative.labels
return self.chain(image_features, flaves, max_count=max_flavors, reverse=True, desc="Negative chain")
def interrogate(self, image: Image, max_flavors: int=32) -> str: def interrogate(self, image: Image, max_flavors: int=32) -> str:
caption = self.generate_caption(image) caption = self.generate_caption(image)
image_features = self.image_to_features(image) image_features = self.image_to_features(image)
flaves = self.flavors.rank(image_features, self.config.flavor_intermediate_count) merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config)
best_medium = self.mediums.rank(image_features, 1)[0] flaves = merged.rank(image_features, self.config.flavor_intermediate_count)
best_artist = self.artists.rank(image_features, 1)[0]
best_trending = self.trendings.rank(image_features, 1)[0]
best_movement = self.movements.rank(image_features, 1)[0]
best_prompt = caption best_prompt = caption
best_sim = self.similarity(image_features, best_prompt) best_sim = self.similarity(image_features, best_prompt)
def check(addition: str) -> bool: return self.chain(image_features, flaves, best_prompt, best_sim, max_count=max_flavors, desc="Flavor chain")
nonlocal best_prompt, best_sim
prompt = best_prompt + ", " + addition
sim = self.similarity(image_features, prompt)
if sim > best_sim:
best_sim = sim
best_prompt = prompt
return True
return False
def check_multi_batch(opts: List[str]):
nonlocal best_prompt, best_sim
prompts = []
for i in range(2**len(opts)):
prompt = best_prompt
for bit in range(len(opts)):
if i & (1 << bit):
prompt += ", " + opts[bit]
prompts.append(prompt)
t = LabelTable(prompts, None, self.clip_model, self.tokenize, self.config)
best_prompt = t.rank(image_features, 1)[0]
best_sim = self.similarity(image_features, best_prompt)
check_multi_batch([best_medium, best_artist, best_trending, best_movement])
extended_flavors = set(flaves)
for _ in tqdm(range(max_flavors), desc="Flavor chain", disable=self.config.quiet):
best = self.rank_top(image_features, [f"{best_prompt}, {f}" for f in extended_flavors])
flave = best[len(best_prompt)+2:]
if not check(flave):
break
if _prompt_at_max_len(best_prompt, self.tokenize):
break
extended_flavors.remove(flave)
return best_prompt
def rank_top(self, image_features: torch.Tensor, text_array: List[str]) -> str: def rank_top(self, image_features: torch.Tensor, text_array: List[str], reverse: bool=False) -> str:
text_tokens = self.tokenize([text for text in text_array]).to(self.device) text_tokens = self.tokenize([text for text in text_array]).to(self.device)
with torch.no_grad(), torch.cuda.amp.autocast(): with torch.no_grad(), torch.cuda.amp.autocast():
text_features = self.clip_model.encode_text(text_tokens) text_features = self.clip_model.encode_text(text_tokens)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = text_features @ image_features.T similarity = text_features @ image_features.T
if reverse:
similarity = -similarity
return text_array[similarity.argmax().item()] return text_array[similarity.argmax().item()]
def similarity(self, image_features: torch.Tensor, text: str) -> float: def similarity(self, image_features: torch.Tensor, text: str) -> float:
@ -235,6 +288,14 @@ class Interrogator():
similarity = text_features @ image_features.T similarity = text_features @ image_features.T
return similarity[0][0].item() return similarity[0][0].item()
def similarities(self, image_features: torch.Tensor, text_array: List[str]) -> List[float]:
text_tokens = self.tokenize([text for text in text_array]).to(self.device)
with torch.no_grad(), torch.cuda.amp.autocast():
text_features = self.clip_model.encode_text(text_tokens)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = text_features @ image_features.T
return similarity.T[0].tolist()
class LabelTable(): class LabelTable():
def __init__(self, labels:List[str], desc:str, clip_model, tokenize, config: Config): def __init__(self, labels:List[str], desc:str, clip_model, tokenize, config: Config):
@ -286,17 +347,19 @@ class LabelTable():
if self.device == 'cpu' or self.device == torch.device('cpu'): if self.device == 'cpu' or self.device == torch.device('cpu'):
self.embeds = [e.astype(np.float32) for e in self.embeds] self.embeds = [e.astype(np.float32) for e in self.embeds]
def _rank(self, image_features: torch.Tensor, text_embeds: torch.Tensor, top_count: int=1) -> str: def _rank(self, image_features: torch.Tensor, text_embeds: torch.Tensor, top_count: int=1, reverse: bool=False) -> str:
top_count = min(top_count, len(text_embeds)) top_count = min(top_count, len(text_embeds))
text_embeds = torch.stack([torch.from_numpy(t) for t in text_embeds]).to(self.device) text_embeds = torch.stack([torch.from_numpy(t) for t in text_embeds]).to(self.device)
with torch.cuda.amp.autocast(): with torch.cuda.amp.autocast():
similarity = image_features @ text_embeds.T similarity = image_features @ text_embeds.T
if reverse:
similarity = -similarity
_, top_labels = similarity.float().cpu().topk(top_count, dim=-1) _, top_labels = similarity.float().cpu().topk(top_count, dim=-1)
return [top_labels[0][i].numpy() for i in range(top_count)] return [top_labels[0][i].numpy() for i in range(top_count)]
def rank(self, image_features: torch.Tensor, top_count: int=1) -> List[str]: def rank(self, image_features: torch.Tensor, top_count: int=1, reverse: bool=False) -> List[str]:
if len(self.labels) <= self.chunk_size: if len(self.labels) <= self.chunk_size:
tops = self._rank(image_features, self.embeds, top_count=top_count) tops = self._rank(image_features, self.embeds, top_count=top_count, reverse=reverse)
return [self.labels[i] for i in tops] return [self.labels[i] for i in tops]
num_chunks = int(math.ceil(len(self.labels)/self.chunk_size)) num_chunks = int(math.ceil(len(self.labels)/self.chunk_size))
@ -306,7 +369,7 @@ class LabelTable():
for chunk_idx in tqdm(range(num_chunks), disable=self.config.quiet): for chunk_idx in tqdm(range(num_chunks), disable=self.config.quiet):
start = chunk_idx*self.chunk_size start = chunk_idx*self.chunk_size
stop = min(start+self.chunk_size, len(self.embeds)) stop = min(start+self.chunk_size, len(self.embeds))
tops = self._rank(image_features, self.embeds[start:stop], top_count=keep_per_chunk) tops = self._rank(image_features, self.embeds[start:stop], top_count=keep_per_chunk, reverse=reverse)
top_labels.extend([self.labels[start+i] for i in tops]) top_labels.extend([self.labels[start+i] for i in tops])
top_embeds.extend([self.embeds[start+i] for i in tops]) top_embeds.extend([self.embeds[start+i] for i in tops])
@ -314,6 +377,18 @@ class LabelTable():
return [top_labels[i] for i in tops] return [top_labels[i] for i in tops]
def _download_file(url: str, filepath: str, chunk_size: int = 64*1024, quiet: bool = False):
r = requests.get(url, stream=True)
file_size = int(r.headers.get("Content-Length", 0))
filename = url.split("/")[-1]
progress = tqdm(total=file_size, unit="B", unit_scale=True, desc=filename, disable=quiet)
with open(filepath, "wb") as f:
for chunk in r.iter_content(chunk_size=chunk_size):
if chunk:
f.write(chunk)
progress.update(len(chunk))
progress.close()
def _load_list(data_path: str, filename: str) -> List[str]: def _load_list(data_path: str, filename: str) -> List[str]:
with open(os.path.join(data_path, filename), 'r', encoding='utf-8', errors='replace') as f: with open(os.path.join(data_path, filename), 'r', encoding='utf-8', errors='replace') as f:
items = [line.strip() for line in f.readlines()] items = [line.strip() for line in f.readlines()]

41
clip_interrogator/data/negative.txt

@ -0,0 +1,41 @@
3d
b&w
bad anatomy
bad art
blur
blurry
cartoon
childish
close up
deformed
disconnected limbs
disfigured
disgusting
extra limb
extra limbs
floating limbs
grain
illustration
kitsch
long body
long neck
low quality
low-res
malformed hands
mangled
missing limb
mutated
mutation
mutilated
noisy
old
out of focus
over saturation
oversaturated
poorly drawn
poorly drawn face
poorly drawn hands
render
surreal
ugly
weird colors

84
run_gradio.py

@ -3,7 +3,7 @@ import argparse
import gradio as gr import gradio as gr
import open_clip import open_clip
import torch import torch
from clip_interrogator import Interrogator, Config from clip_interrogator import Config, Interrogator
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-s', '--share', action='store_true', help='Create a public link') parser.add_argument('-s', '--share', action='store_true', help='Create a public link')
@ -14,40 +14,76 @@ if not torch.cuda.is_available():
ci = Interrogator(Config(cache_path="cache", clip_model_path="cache")) ci = Interrogator(Config(cache_path="cache", clip_model_path="cache"))
def inference(image, mode, clip_model_name, blip_max_length, blip_num_beams): def image_analysis(image, clip_model_name):
if clip_model_name != ci.config.clip_model_name:
ci.config.clip_model_name = clip_model_name
ci.load_clip_model()
image = image.convert('RGB')
image_features = ci.image_to_features(image)
top_mediums = ci.mediums.rank(image_features, 5)
top_artists = ci.artists.rank(image_features, 5)
top_movements = ci.movements.rank(image_features, 5)
top_trendings = ci.trendings.rank(image_features, 5)
top_flavors = ci.flavors.rank(image_features, 5)
medium_ranks = {medium: sim for medium, sim in zip(top_mediums, ci.similarities(image_features, top_mediums))}
artist_ranks = {artist: sim for artist, sim in zip(top_artists, ci.similarities(image_features, top_artists))}
movement_ranks = {movement: sim for movement, sim in zip(top_movements, ci.similarities(image_features, top_movements))}
trending_ranks = {trending: sim for trending, sim in zip(top_trendings, ci.similarities(image_features, top_trendings))}
flavor_ranks = {flavor: sim for flavor, sim in zip(top_flavors, ci.similarities(image_features, top_flavors))}
return medium_ranks, artist_ranks, movement_ranks, trending_ranks, flavor_ranks
def image_to_prompt(image, mode, clip_model_name):
if clip_model_name != ci.config.clip_model_name: if clip_model_name != ci.config.clip_model_name:
ci.config.clip_model_name = clip_model_name ci.config.clip_model_name = clip_model_name
ci.load_clip_model() ci.load_clip_model()
ci.config.blip_max_length = int(blip_max_length)
ci.config.blip_num_beams = int(blip_num_beams)
image = image.convert('RGB') image = image.convert('RGB')
if mode == 'best': if mode == 'best':
return ci.interrogate(image) return ci.interrogate(image)
elif mode == 'classic': elif mode == 'classic':
return ci.interrogate_classic(image) return ci.interrogate_classic(image)
else: elif mode == 'fast':
return ci.interrogate_fast(image) return ci.interrogate_fast(image)
elif mode == 'negative':
return ci.interrogate_negative(image)
models = ['/'.join(x) for x in open_clip.list_pretrained()] models = ['/'.join(x) for x in open_clip.list_pretrained()]
inputs = [ def prompt_tab():
gr.inputs.Image(type='pil'), with gr.Column():
gr.Radio(['best', 'classic', 'fast'], label='Mode', value='best'), with gr.Row():
gr.Dropdown(models, value='ViT-L-14/openai', label='CLIP Model'), image = gr.Image(type='pil', label="Image")
gr.Number(value=32, label='Caption Max Length'), with gr.Column():
gr.Number(value=64, label='Caption Num Beams'), mode = gr.Radio(['best', 'fast', 'classic', 'negative'], label='Mode', value='best')
] model = gr.Dropdown(models, value='ViT-L-14/openai', label='CLIP Model')
outputs = [ prompt = gr.Textbox(label="Prompt")
gr.outputs.Textbox(label="Output"), button = gr.Button("Generate prompt")
] button.click(image_to_prompt, inputs=[image, mode, model], outputs=prompt)
io = gr.Interface( def analyze_tab():
inference, with gr.Column():
inputs, with gr.Row():
outputs, image = gr.Image(type='pil', label="Image")
title="🕵 CLIP Interrogator 🕵", model = gr.Dropdown(models, value='ViT-L-14/openai', label='CLIP Model')
allow_flagging=False, with gr.Row():
) medium = gr.Label(label="Medium", num_top_classes=5)
io.launch(share=args.share) artist = gr.Label(label="Artist", num_top_classes=5)
movement = gr.Label(label="Movement", num_top_classes=5)
trending = gr.Label(label="Trending", num_top_classes=5)
flavor = gr.Label(label="Flavor", num_top_classes=5)
button = gr.Button("Analyze")
button.click(image_analysis, inputs=[image, model], outputs=[medium, artist, movement, trending, flavor])
with gr.Blocks() as ui:
gr.Markdown("# <center>🕵 CLIP Interrogator 🕵</center>")
with gr.Tab("Prompt"):
prompt_tab()
with gr.Tab("Analyze"):
analyze_tab()
ui.launch(show_api=False, debug=True, share=args.share)

2
setup.py

@ -5,7 +5,7 @@ from setuptools import setup, find_packages
setup( setup(
name="clip-interrogator", name="clip-interrogator",
version="0.3.5", version="0.4.0",
license='MIT', license='MIT',
author='pharmapsychotic', author='pharmapsychotic',
author_email='me@pharmapsychotic.com', author_email='me@pharmapsychotic.com',

Loading…
Cancel
Save