Bunch of updates! (quick usage sketch after the list)

- auto-download the cache files from huggingface
- experimental negative prompt mode
- slight quality and performance improvement to best mode
- analyze tab in Colab and run_gradio to get a table of ranked terms
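
For orientation, here is a minimal sketch of how the updated library might be driven from Python once this commit is installed. The image path example.jpg is a placeholder; Config, Interrogator, interrogate, interrogate_negative, image_to_features, rank, and similarities are the pieces touched by the diffs below.

from PIL import Image
from clip_interrogator import Config, Interrogator

# download_cache defaults to True, so the precomputed .pkl embeddings are fetched automatically
ci = Interrogator(Config(clip_model_name='ViT-L-14/openai'))

image = Image.open('example.jpg').convert('RGB')   # placeholder image path
print(ci.interrogate(image))                       # 'best' mode prompt
print(ci.interrogate_negative(image))              # experimental negative prompt

# the same ranked-terms table the new Analyze tab displays, shown for one category
features = ci.image_to_features(image)
top_mediums = ci.mediums.rank(features, 5)
for medium, sim in zip(top_mediums, ci.similarities(features, top_mediums)):
    print(f"{medium}: {sim:.3f}")
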
pull/40/head
pharmapsychotic, 2 years ago
commit 9e2f107309
Changed files:
1. clip_interrogator.ipynb (135)
2. clip_interrogator/clip_interrogator.py (109)
3. run_gradio.py (84)

clip_interrogator.ipynb (135 changed lines)

@@ -7,13 +7,24 @@
    "id": "3jm8RYrLqvzz"
   },
   "source": [
-    "# CLIP Interrogator 2.3 [negative prompt experiment!]\n",
+    "# CLIP Interrogator 2.3 by [@pharmapsychotic](https://twitter.com/pharmapsychotic) \n",
     "\n",
-    "This experimental version of CLIP Interrogator supports finding good \"negative\" prompts for Stable Diffusion 2. Note this is very *WIP* and more work needs to be done building out the dataset to support this (and perhaps a reverse BLIP) so for many images it may struggle to find a well aligned negative prompt. Alignments are displayed to help see how well it did.\n",
+    "Want to figure out what a good prompt might be to create new images like an existing one? The CLIP Interrogator is here to get you answers!\n",
     "\n",
     "<br>\n",
     "\n",
-    "For Stable Diffusion 1.X choose the **ViT-L** model and for Stable Diffusion 2.0+ choose the **ViT-H** CLIP Model.\n"
+    "For Stable Diffusion 1.X choose the **ViT-L** model and for Stable Diffusion 2.0+ choose the **ViT-H** CLIP Model.\n",
+    "\n",
+    "This version is specialized for producing nice prompts for use with Stable Diffusion and achieves higher alignment between generated text prompt and source image. You can try out the old [version 1](https://colab.research.google.com/github/pharmapsychotic/clip-interrogator/blob/v1/clip_interrogator.ipynb) to see how different CLIP models ranks terms. \n",
+    "\n",
+    "You can also run this on HuggingFace and Replicate<br>\n",
+    "[![Generic badge](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue.svg)](https://huggingface.co/spaces/pharma/CLIP-Interrogator) [![Replicate](https://replicate.com/pharmapsychotic/clip-interrogator/badge)](https://replicate.com/pharmapsychotic/clip-interrogator)\n",
+    "\n",
+    "<br>\n",
+    "\n",
+    "If this notebook is helpful to you please consider buying me a coffee via [ko-fi](https://ko-fi.com/pharmapsychotic) or following me on [twitter](https://twitter.com/pharmapsychotic) for more cool Ai stuff. 🙂\n",
+    "\n",
+    "And if you're looking for more Ai art tools check out my [Ai generative art tools list](https://pharmapsychotic.com/tools.html).\n"
    ]
   },
   {
@@ -45,8 +56,7 @@
     "    install_cmds = [\n",
     "        ['pip', 'install', 'gradio'],\n",
     "        ['pip', 'install', 'open_clip_torch'],\n",
-    "        ['pip', 'install', 'git+https://github.com/pharmapsychotic/BLIP.git'],\n",
-    "        ['git', 'clone', '-b', 'negative', 'https://github.com/pharmapsychotic/clip-interrogator.git']\n",
+    "        ['pip', 'install', 'clip-interrogator'],\n",
     "    ]\n",
     "    for cmd in install_cmds:\n",
     "        print(subprocess.run(cmd, stdout=subprocess.PIPE).stdout.decode('utf-8'))\n",
@@ -54,29 +64,8 @@
     "setup()\n",
     "\n",
     "\n",
-    "clip_model_name = 'ViT-H-14/laion2b_s32b_b79k' #@param [\"ViT-L-14/openai\", \"ViT-H-14/laion2b_s32b_b79k\"]\n",
-    "\n",
+    "clip_model_name = 'ViT-L-14/openai' #@param [\"ViT-L-14/openai\", \"ViT-H-14/laion2b_s32b_b79k\"]\n",
     "\n",
-    "print(\"Download preprocessed cache files...\")\n",
-    "CACHE_URLS = [\n",
-    "    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_artists.pkl',\n",
-    "    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_flavors.pkl',\n",
-    "    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_mediums.pkl',\n",
-    "    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_movements.pkl',\n",
-    "    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_trendings.pkl',\n",
-    "] if clip_model_name == 'ViT-L-14/openai' else [\n",
-    "    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_artists.pkl',\n",
-    "    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_flavors.pkl',\n",
-    "    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_mediums.pkl',\n",
-    "    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_movements.pkl',\n",
-    "    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_trendings.pkl',\n",
-    "]\n",
-    "os.makedirs('cache', exist_ok=True)\n",
-    "for url in CACHE_URLS:\n",
-    "    print(subprocess.run(['wget', url, '-P', 'cache'], stdout=subprocess.PIPE).stdout.decode('utf-8'))\n",
-    "\n",
-    "import sys\n",
-    "sys.path.append('clip-interrogator')\n",
     "\n",
     "import gradio as gr\n",
     "from clip_interrogator import Config, Interrogator\n",
@@ -87,24 +76,37 @@
     "config.clip_model_name = clip_model_name\n",
     "ci = Interrogator(config)\n",
     "\n",
-    "def inference(image, mode):\n",
+    "def image_analysis(image):\n",
+    "    image = image.convert('RGB')\n",
+    "    image_features = ci.image_to_features(image)\n",
+    "\n",
+    "    top_mediums = ci.mediums.rank(image_features, 5)\n",
+    "    top_artists = ci.artists.rank(image_features, 5)\n",
+    "    top_movements = ci.movements.rank(image_features, 5)\n",
+    "    top_trendings = ci.trendings.rank(image_features, 5)\n",
+    "    top_flavors = ci.flavors.rank(image_features, 5)\n",
+    "\n",
+    "    medium_ranks = {medium: sim for medium, sim in zip(top_mediums, ci.similarities(image_features, top_mediums))}\n",
+    "    artist_ranks = {artist: sim for artist, sim in zip(top_artists, ci.similarities(image_features, top_artists))}\n",
+    "    movement_ranks = {movement: sim for movement, sim in zip(top_movements, ci.similarities(image_features, top_movements))}\n",
+    "    trending_ranks = {trending: sim for trending, sim in zip(top_trendings, ci.similarities(image_features, top_trendings))}\n",
+    "    flavor_ranks = {flavor: sim for flavor, sim in zip(top_flavors, ci.similarities(image_features, top_flavors))}\n",
+    "    \n",
+    "    return medium_ranks, artist_ranks, movement_ranks, trending_ranks, flavor_ranks\n",
+    "\n",
+    "def image_to_prompt(image, mode):\n",
     "    ci.config.chunk_size = 2048 if ci.config.clip_model_name == \"ViT-L-14/openai\" else 1024\n",
     "    ci.config.flavor_intermediate_count = 2048 if ci.config.clip_model_name == \"ViT-L-14/openai\" else 1024\n",
     "    image = image.convert('RGB')\n",
-    "    prompt = \"\"\n",
     "    if mode == 'best':\n",
-    "        prompt = ci.interrogate(image)\n",
+    "        return ci.interrogate(image)\n",
     "    elif mode == 'classic':\n",
-    "        prompt = ci.interrogate_classic(image)\n",
+    "        return ci.interrogate_classic(image)\n",
     "    elif mode == 'fast':\n",
-    "        prompt = ci.interrogate_fast(image)\n",
+    "        return ci.interrogate_fast(image)\n",
     "    elif mode == 'negative':\n",
-    "        image_features = ci.image_to_features(image)\n",
-    "        flaves = ci.flavors.rank(image_features, ci.config.flavor_intermediate_count, reverse=True)\n",
-    "        flaves = flaves + ci.negative.labels\n",
-    "        prompt = ci.chain(image_features, flaves, max_count=32, reverse=True, desc=\"Negative chain\")\n",
-    "    sim = ci.similarity(ci.image_to_features(image), prompt)\n",
-    "    return prompt, sim"
+    "        return ci.interrogate_negative(image)\n",
+    "    "
    ]
   },
   {
@@ -156,22 +158,36 @@
   "source": [
     "#@title Image to prompt! 🖼 -> 📝\n",
     " \n",
-    "inputs = [\n",
-    "    gr.inputs.Image(type='pil'),\n",
-    "    gr.Radio(['best', 'fast', 'negative'], label='Mode', value='best'),\n",
-    "]\n",
-    "outputs = [\n",
-    "    gr.outputs.Textbox(label=\"Output\"),\n",
-    "    gr.Number(label=\"Alignment\"),\n",
-    "]\n",
-    "\n",
-    "io = gr.Interface(\n",
-    "    inference, \n",
-    "    inputs, \n",
-    "    outputs, \n",
-    "    allow_flagging=False,\n",
-    ")\n",
-    "io.launch(debug=False)\n"
+    "def prompt_tab():\n",
+    "    with gr.Column():\n",
+    "        with gr.Row():\n",
+    "            image = gr.Image(type='pil', label=\"Image\")\n",
+    "            with gr.Column():\n",
+    "                mode = gr.Radio(['best', 'fast', 'classic', 'negative'], label='Mode', value='best')\n",
+    "        prompt = gr.Textbox(label=\"Prompt\")\n",
+    "        button = gr.Button(\"Generate prompt\")\n",
+    "        button.click(image_to_prompt, inputs=[image, mode], outputs=prompt)\n",
+    "\n",
+    "def analyze_tab():\n",
+    "    with gr.Column():\n",
+    "        with gr.Row():\n",
+    "            image = gr.Image(type='pil', label=\"Image\")\n",
+    "        with gr.Row():\n",
+    "            medium = gr.Label(label=\"Medium\", num_top_classes=5)\n",
+    "            artist = gr.Label(label=\"Artist\", num_top_classes=5) \n",
+    "            movement = gr.Label(label=\"Movement\", num_top_classes=5)\n",
+    "            trending = gr.Label(label=\"Trending\", num_top_classes=5)\n",
+    "            flavor = gr.Label(label=\"Flavor\", num_top_classes=5)\n",
+    "        button = gr.Button(\"Analyze\")\n",
+    "        button.click(image_analysis, inputs=image, outputs=[medium, artist, movement, trending, flavor])\n",
+    "\n",
+    "with gr.Blocks() as ui:\n",
+    "    with gr.Tab(\"Prompt\"):\n",
+    "        prompt_tab()\n",
+    "    with gr.Tab(\"Analyze\"):\n",
+    "        analyze_tab()\n",
+    "\n",
+    "ui.launch(show_api=False, debug=False)\n"
    ]
   },
   {
@@ -198,10 +214,9 @@
     "from tqdm import tqdm\n",
     "\n",
     "folder_path = \"/content/my_images\" #@param {type:\"string\"}\n",
-    "prompt_mode = 'best' #@param [\"best\",\"fast\"]\n",
+    "prompt_mode = 'best' #@param [\"best\",\"fast\",\"classic\",\"negative\"]\n",
     "output_mode = 'rename' #@param [\"desc.csv\",\"rename\"]\n",
     "max_filename_len = 128 #@param {type:\"integer\"}\n",
-    "best_max_flavors = 16 #@param {type:\"integer\"}\n",
     "\n",
     "\n",
     "def sanitize_for_filename(prompt: str, max_len: int) -> str:\n",
@@ -218,7 +233,7 @@
     "        clear_output(wait=True)\n",
     "\n",
     "    image = Image.open(os.path.join(folder_path, file)).convert('RGB')\n",
-    "    prompt = inference(image, prompt_mode, best_max_flavors=best_max_flavors)\n",
+    "    prompt = image_to_prompt(image, prompt_mode)\n",
     "    prompts.append(prompt)\n",
     "\n",
     "    print(prompt)\n",
@@ -261,7 +276,7 @@
    "provenance": []
   },
   "kernelspec": {
-   "display_name": "ci",
+   "display_name": "Python 3.7.15 ('py37')",
    "language": "python",
    "name": "python3"
   },
@@ -275,12 +290,12 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-  "version": "3.8.10 (default, Nov 14 2022, 12:59:47) \n[GCC 9.4.0]"
+  "version": "3.7.15 (default, Nov 24 2022, 18:44:54) [MSC v.1916 64 bit (AMD64)]"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
-   "hash": "90daa5087f97972f35e673cab20894a33c1e0ca77092ccdd163e60b53596983a"
+   "hash": "1f51d5616d3bc2b87a82685314c5be1ec9a49b6e0cb1f707bfa2acb6c45f3e5f"
   }
  }
 },

clip_interrogator/clip_interrogator.py (109 changed lines)

@@ -5,6 +5,7 @@ import numpy as np
 import open_clip
 import os
 import pickle
+import requests
 import time
 import torch
@@ -21,6 +22,23 @@ BLIP_MODELS = {
     'large': 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
 }
 
+CACHE_URLS_VITL = [
+    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_artists.pkl',
+    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_flavors.pkl',
+    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_mediums.pkl',
+    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_movements.pkl',
+    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_trendings.pkl',
+]
+
+CACHE_URLS_VITH = [
+    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_artists.pkl',
+    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_flavors.pkl',
+    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_mediums.pkl',
+    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_movements.pkl',
+    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_trendings.pkl',
+]
+
 
 @dataclass
 class Config:
     # models can optionally be passed in directly
@@ -40,13 +58,15 @@ class Config:
     clip_model_path: str = None
 
     # interrogator settings
-    cache_path: str = 'cache'
-    chunk_size: int = 2048
+    cache_path: str = 'cache'   # path to store cached text embeddings
+    download_cache: bool = True # when true, cached embeds are downloaded from huggingface
+    chunk_size: int = 2048      # batch size for CLIP, use smaller for lower VRAM
     data_path: str = os.path.join(os.path.dirname(__file__), 'data')
     device: str = ("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
     flavor_intermediate_count: int = 2048
     quiet: bool = False # when quiet progress bars are not shown
 
 
 class Interrogator():
     def __init__(self, config: Config):
         self.config = config
@@ -72,6 +92,21 @@ class Interrogator():
 
         self.load_clip_model()
 
+    def download_cache(self, clip_model_name: str):
+        if clip_model_name == 'ViT-L-14/openai':
+            cache_urls = CACHE_URLS_VITL
+        elif clip_model_name == 'ViT-H-14/laion2b_s32b_b79k':
+            cache_urls = CACHE_URLS_VITH
+        else:
+            # text embeddings will be precomputed and cached locally
+            return
+
+        os.makedirs(self.config.cache_path, exist_ok=True)
+        for url in cache_urls:
+            filepath = os.path.join(self.config.cache_path, url.split('/')[-1])
+            if not os.path.exists(filepath):
+                _download_file(url, filepath, quiet=self.config.quiet)
+
     def load_clip_model(self):
         start_time = time.time()
         config = self.config
@@ -105,6 +140,8 @@ class Interrogator():
         artists = [f"by {a}" for a in raw_artists]
         artists.extend([f"inspired by {a}" for a in raw_artists])
 
+        self.download_cache(config.clip_model_name)
+
         self.artists = LabelTable(artists, "artists", self.clip_model, self.tokenize, config)
         self.flavors = LabelTable(_load_list(config.data_path, 'flavors.txt'), "flavors", self.clip_model, self.tokenize, config)
         self.mediums = LabelTable(_load_list(config.data_path, 'mediums.txt'), "mediums", self.clip_model, self.tokenize, config)
@@ -185,6 +222,8 @@ class Interrogator():
         return image_features
 
     def interrogate_classic(self, image: Image, max_flavors: int=3) -> str:
+        """Classic mode creates a prompt in a standard format first describing the image,
+        then listing the artist, trending, movement, and flavor text modifiers."""
         caption = self.generate_caption(image)
         image_features = self.image_to_features(image)
@@ -202,58 +241,40 @@ class Interrogator():
         return _truncate_to_fit(prompt, self.tokenize)
 
     def interrogate_fast(self, image: Image, max_flavors: int = 32) -> str:
+        """Fast mode simply adds the top ranked terms after a caption. It generally results in
+        better similarity between generated prompt and image than classic mode, but the prompts
+        are less readable."""
         caption = self.generate_caption(image)
         image_features = self.image_to_features(image)
         merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config)
         tops = merged.rank(image_features, max_flavors)
         return _truncate_to_fit(caption + ", " + ", ".join(tops), self.tokenize)
 
+    def interrogate_negative(self, image: Image, max_flavors: int = 32) -> str:
+        """Negative mode chains together the most dissimilar terms to the image. It can be used
+        to help build a negative prompt to pair with the regular positive prompt and often
+        improve the results of generated images particularly with Stable Diffusion 2."""
+        image_features = self.image_to_features(image)
+        flaves = self.flavors.rank(image_features, self.config.flavor_intermediate_count, reverse=True)
+        flaves = flaves + self.negative.labels
+        return self.chain(image_features, flaves, max_count=max_flavors, reverse=True, desc="Negative chain")
+
     def interrogate(self, image: Image, max_flavors: int=32) -> str:
         caption = self.generate_caption(image)
         image_features = self.image_to_features(image)
 
-        flaves = self.flavors.rank(image_features, self.config.flavor_intermediate_count)
-        best_medium = self.mediums.rank(image_features, 1)[0]
-        best_artist = self.artists.rank(image_features, 1)[0]
-        best_trending = self.trendings.rank(image_features, 1)[0]
-        best_movement = self.movements.rank(image_features, 1)[0]
+        merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config)
+        flaves = merged.rank(image_features, self.config.flavor_intermediate_count)
 
         best_prompt = caption
         best_sim = self.similarity(image_features, best_prompt)
 
-        def check(addition: str) -> bool:
-            nonlocal best_prompt, best_sim
-            prompt = best_prompt + ", " + addition
-            sim = self.similarity(image_features, prompt)
-            if sim > best_sim:
-                best_sim = sim
-                best_prompt = prompt
-                return True
-            return False
-
-        def check_multi_batch(opts: List[str]):
-            nonlocal best_prompt, best_sim
-            prompts = []
-            for i in range(2**len(opts)):
-                prompt = best_prompt
-                for bit in range(len(opts)):
-                    if i & (1 << bit):
-                        prompt += ", " + opts[bit]
-                prompts.append(prompt)
-
-            t = LabelTable(prompts, None, self.clip_model, self.tokenize, self.config)
-            best_prompt = t.rank(image_features, 1)[0]
-            best_sim = self.similarity(image_features, best_prompt)
-
-        check_multi_batch([best_medium, best_artist, best_trending, best_movement])
-
         return self.chain(image_features, flaves, best_prompt, best_sim, max_count=max_flavors, desc="Flavor chain")
 
     def rank_top(self, image_features: torch.Tensor, text_array: List[str], reverse: bool=False) -> str:
         text_tokens = self.tokenize([text for text in text_array]).to(self.device)
         with torch.no_grad(), torch.cuda.amp.autocast():
             text_features = self.clip_model.encode_text(text_tokens)
+            text_features /= text_features.norm(dim=-1, keepdim=True)
             similarity = text_features @ image_features.T
             if reverse:
                 similarity = -similarity
@@ -267,6 +288,14 @@ class Interrogator():
             similarity = text_features @ image_features.T
         return similarity[0][0].item()
 
+    def similarities(self, image_features: torch.Tensor, text_array: List[str]) -> List[float]:
+        text_tokens = self.tokenize([text for text in text_array]).to(self.device)
+        with torch.no_grad(), torch.cuda.amp.autocast():
+            text_features = self.clip_model.encode_text(text_tokens)
+            text_features /= text_features.norm(dim=-1, keepdim=True)
+            similarity = text_features @ image_features.T
+        return similarity.T[0].tolist()
+
 
 class LabelTable():
     def __init__(self, labels:List[str], desc:str, clip_model, tokenize, config: Config):
@@ -348,6 +377,18 @@ class LabelTable():
         return [top_labels[i] for i in tops]
 
 
+def _download_file(url: str, filepath: str, chunk_size: int = 64*1024, quiet: bool = False):
+    r = requests.get(url, stream=True)
+    file_size = int(r.headers.get("Content-Length", 0))
+    filename = url.split("/")[-1]
+    progress = tqdm(total=file_size, unit="B", unit_scale=True, desc=filename, disable=quiet)
+    with open(filepath, "wb") as f:
+        for chunk in r.iter_content(chunk_size=chunk_size):
+            if chunk:
+                f.write(chunk)
+                progress.update(len(chunk))
+    progress.close()
+
+
 def _load_list(data_path: str, filename: str) -> List[str]:
     with open(os.path.join(data_path, filename), 'r', encoding='utf-8', errors='replace') as f:
         items = [line.strip() for line in f.readlines()]

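If you want to pre-seed the cache outside the library, the new auto-download path above amounts to fetching the listed .pkl files into Config.cache_path and skipping files that already exist. A rough standalone equivalent is sketched below; the helper name fetch is illustrative, the URLs and the default 'cache' directory come from the code above, and only one URL is listed to keep the sketch short.

import os
import requests
from tqdm import tqdm

CACHE_URLS_VITL = [
    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_artists.pkl',
    # ...the remaining ViT-L-14 pickle files listed in CACHE_URLS_VITL above
]

def fetch(url: str, cache_path: str = 'cache', chunk_size: int = 64 * 1024) -> str:
    # mirror of download_cache/_download_file: stream the file with a progress bar,
    # skipping it if it is already present in the cache directory
    os.makedirs(cache_path, exist_ok=True)
    filepath = os.path.join(cache_path, url.split('/')[-1])
    if os.path.exists(filepath):
        return filepath
    r = requests.get(url, stream=True)
    total = int(r.headers.get("Content-Length", 0))
    with open(filepath, "wb") as f, tqdm(total=total, unit="B", unit_scale=True, desc=filepath) as bar:
        for chunk in r.iter_content(chunk_size=chunk_size):
            if chunk:
                f.write(chunk)
                bar.update(len(chunk))
    return filepath

for url in CACHE_URLS_VITL:
    fetch(url)
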
run_gradio.py (84 changed lines)

@@ -3,7 +3,7 @@ import argparse
 import gradio as gr
 import open_clip
 import torch
-from clip_interrogator import Interrogator, Config
+from clip_interrogator import Config, Interrogator
 
 parser = argparse.ArgumentParser()
 parser.add_argument('-s', '--share', action='store_true', help='Create a public link')
@@ -14,40 +14,76 @@ if not torch.cuda.is_available():
 ci = Interrogator(Config(cache_path="cache", clip_model_path="cache"))
 
-def inference(image, mode, clip_model_name, blip_max_length, blip_num_beams):
+def image_analysis(image, clip_model_name):
+    if clip_model_name != ci.config.clip_model_name:
+        ci.config.clip_model_name = clip_model_name
+        ci.load_clip_model()
+
+    image = image.convert('RGB')
+    image_features = ci.image_to_features(image)
+
+    top_mediums = ci.mediums.rank(image_features, 5)
+    top_artists = ci.artists.rank(image_features, 5)
+    top_movements = ci.movements.rank(image_features, 5)
+    top_trendings = ci.trendings.rank(image_features, 5)
+    top_flavors = ci.flavors.rank(image_features, 5)
+
+    medium_ranks = {medium: sim for medium, sim in zip(top_mediums, ci.similarities(image_features, top_mediums))}
+    artist_ranks = {artist: sim for artist, sim in zip(top_artists, ci.similarities(image_features, top_artists))}
+    movement_ranks = {movement: sim for movement, sim in zip(top_movements, ci.similarities(image_features, top_movements))}
+    trending_ranks = {trending: sim for trending, sim in zip(top_trendings, ci.similarities(image_features, top_trendings))}
+    flavor_ranks = {flavor: sim for flavor, sim in zip(top_flavors, ci.similarities(image_features, top_flavors))}
+
+    return medium_ranks, artist_ranks, movement_ranks, trending_ranks, flavor_ranks
+
+def image_to_prompt(image, mode, clip_model_name):
     if clip_model_name != ci.config.clip_model_name:
         ci.config.clip_model_name = clip_model_name
         ci.load_clip_model()
-    ci.config.blip_max_length = int(blip_max_length)
-    ci.config.blip_num_beams = int(blip_num_beams)
 
     image = image.convert('RGB')
     if mode == 'best':
         return ci.interrogate(image)
     elif mode == 'classic':
         return ci.interrogate_classic(image)
-    else:
+    elif mode == 'fast':
         return ci.interrogate_fast(image)
+    elif mode == 'negative':
+        return ci.interrogate_negative(image)
 
 models = ['/'.join(x) for x in open_clip.list_pretrained()]
 
-inputs = [
-    gr.inputs.Image(type='pil'),
-    gr.Radio(['best', 'classic', 'fast'], label='Mode', value='best'),
-    gr.Dropdown(models, value='ViT-L-14/openai', label='CLIP Model'),
-    gr.Number(value=32, label='Caption Max Length'),
-    gr.Number(value=64, label='Caption Num Beams'),
-]
-outputs = [
-    gr.outputs.Textbox(label="Output"),
-]
-
-io = gr.Interface(
-    inference, 
-    inputs, 
-    outputs, 
-    title="🕵 CLIP Interrogator 🕵",
-    allow_flagging=False,
-)
-io.launch(share=args.share)
+def prompt_tab():
+    with gr.Column():
+        with gr.Row():
+            image = gr.Image(type='pil', label="Image")
+            with gr.Column():
+                mode = gr.Radio(['best', 'fast', 'classic', 'negative'], label='Mode', value='best')
+                model = gr.Dropdown(models, value='ViT-L-14/openai', label='CLIP Model')
+        prompt = gr.Textbox(label="Prompt")
+        button = gr.Button("Generate prompt")
+        button.click(image_to_prompt, inputs=[image, mode, model], outputs=prompt)
+
+def analyze_tab():
+    with gr.Column():
+        with gr.Row():
+            image = gr.Image(type='pil', label="Image")
+            model = gr.Dropdown(models, value='ViT-L-14/openai', label='CLIP Model')
+        with gr.Row():
+            medium = gr.Label(label="Medium", num_top_classes=5)
+            artist = gr.Label(label="Artist", num_top_classes=5)
+            movement = gr.Label(label="Movement", num_top_classes=5)
+            trending = gr.Label(label="Trending", num_top_classes=5)
+            flavor = gr.Label(label="Flavor", num_top_classes=5)
+        button = gr.Button("Analyze")
+        button.click(image_analysis, inputs=[image, model], outputs=[medium, artist, movement, trending, flavor])
+
+with gr.Blocks() as ui:
+    gr.Markdown("# <center>🕵 CLIP Interrogator 🕵</center>")
+    with gr.Tab("Prompt"):
+        prompt_tab()
+    with gr.Tab("Analyze"):
+        analyze_tab()
+
+ui.launch(show_api=False, debug=True, share=args.share)

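With the changes above, run_gradio.py serves a two-tab Blocks UI (Prompt and Analyze) in place of the old single gr.Interface. Typical invocations, using only the -s/--share flag the script already defines:

python run_gradio.py           # local UI with Prompt and Analyze tabs
python run_gradio.py --share   # also creates a public Gradio share link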