
Bunch of updates! (#40)

- auto download the cache files from huggingface
- experimental negative prompt mode
- slight quality and performance improvement to best mode
- analyze tab in Colab and run_gradio to get table of ranked terms (see the usage sketch below)
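A minimal sketch of the new ranked-terms analysis from plain Python, mirroring what the new Analyze tab does in the Colab notebook and run_gradio. The class and method names are taken from this commit; the image path is a placeholder.

```python
from PIL import Image
from clip_interrogator import Config, Interrogator

# Preprocessed cache files are now fetched from huggingface automatically
# (Config.download_cache defaults to True in this commit).
ci = Interrogator(Config(clip_model_name="ViT-L-14/openai"))

image = Image.open("example.jpg").convert("RGB")  # placeholder path
features = ci.image_to_features(image)

# Same ranking the Analyze tab shows: top 5 terms per category with similarity scores.
tables = {
    "medium": ci.mediums,
    "artist": ci.artists,
    "movement": ci.movements,
    "trending": ci.trendings,
    "flavor": ci.flavors,
}
for name, table in tables.items():
    top = table.rank(features, 5)
    for label, sim in zip(top, ci.similarities(features, top)):
        print(f"{name:9s} {sim:.3f} {label}")
```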
pharmapsychotic committed 2 years ago via GitHub · commit 42b3cf4d9e
Changed files:
1. README.md (2 changes)
2. clip_interrogator.ipynb (103 changes)
3. clip_interrogator/__init__.py (2 changes)
4. clip_interrogator/clip_interrogator.py (175 changes)
5. clip_interrogator/data/negative.txt (41 changes)
6. run_gradio.py (84 changes)
7. setup.py (2 changes)

README.md (2 changes)

@@ -36,7 +36,7 @@ Install with PIP
pip3 install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu117
# install clip-interrogator
pip install clip-interrogator==0.3.5
pip install clip-interrogator==0.4.0
```
You can then use it in your script
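(The usage snippet that follows this line in the README is outside this hunk; as context, a minimal sketch consistent with the 0.4.0 API shown elsewhere in this commit, using a placeholder image path.)

```python
from PIL import Image
from clip_interrogator import Config, Interrogator

image = Image.open("image.jpg").convert("RGB")  # placeholder path
ci = Interrogator(Config(clip_model_name="ViT-L-14/openai"))
print(ci.interrogate(image))
```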

clip_interrogator.ipynb (103 changes)

@@ -1,12 +1,13 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "3jm8RYrLqvzz"
},
"source": [
"# CLIP Interrogator 2.2 by [@pharmapsychotic](https://twitter.com/pharmapsychotic) \n",
"# CLIP Interrogator 2.3 by [@pharmapsychotic](https://twitter.com/pharmapsychotic) \n",
"\n",
"Want to figure out what a good prompt might be to create new images like an existing one? The CLIP Interrogator is here to get you answers!\n",
"\n",
@@ -56,7 +57,6 @@
" ['pip', 'install', 'gradio'],\n",
" ['pip', 'install', 'open_clip_torch'],\n",
" ['pip', 'install', 'clip-interrogator'],\n",
" ['pip', 'install', 'git+https://github.com/pharmapsychotic/BLIP.git'],\n",
" ]\n",
" for cmd in install_cmds:\n",
" print(subprocess.run(cmd, stdout=subprocess.PIPE).stdout.decode('utf-8'))\n",
@@ -67,25 +67,6 @@
"clip_model_name = 'ViT-L-14/openai' #@param [\"ViT-L-14/openai\", \"ViT-H-14/laion2b_s32b_b79k\"]\n",
"\n",
"\n",
"print(\"Download preprocessed cache files...\")\n",
"CACHE_URLS = [\n",
" 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_artists.pkl',\n",
" 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_flavors.pkl',\n",
" 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_mediums.pkl',\n",
" 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_movements.pkl',\n",
" 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_trendings.pkl',\n",
"] if clip_model_name == 'ViT-L-14/openai' else [\n",
" 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_artists.pkl',\n",
" 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_flavors.pkl',\n",
" 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_mediums.pkl',\n",
" 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_movements.pkl',\n",
" 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_trendings.pkl',\n",
"]\n",
"os.makedirs('cache', exist_ok=True)\n",
"for url in CACHE_URLS:\n",
" print(subprocess.run(['wget', url, '-P', 'cache'], stdout=subprocess.PIPE).stdout.decode('utf-8'))\n",
"\n",
"\n",
"import gradio as gr\n",
"from clip_interrogator import Config, Interrogator\n",
"\n",
@@ -95,16 +76,37 @@
"config.clip_model_name = clip_model_name\n",
"ci = Interrogator(config)\n",
"\n",
"def inference(image, mode, best_max_flavors=32):\n",
"def image_analysis(image):\n",
" image = image.convert('RGB')\n",
" image_features = ci.image_to_features(image)\n",
"\n",
" top_mediums = ci.mediums.rank(image_features, 5)\n",
" top_artists = ci.artists.rank(image_features, 5)\n",
" top_movements = ci.movements.rank(image_features, 5)\n",
" top_trendings = ci.trendings.rank(image_features, 5)\n",
" top_flavors = ci.flavors.rank(image_features, 5)\n",
"\n",
" medium_ranks = {medium: sim for medium, sim in zip(top_mediums, ci.similarities(image_features, top_mediums))}\n",
" artist_ranks = {artist: sim for artist, sim in zip(top_artists, ci.similarities(image_features, top_artists))}\n",
" movement_ranks = {movement: sim for movement, sim in zip(top_movements, ci.similarities(image_features, top_movements))}\n",
" trending_ranks = {trending: sim for trending, sim in zip(top_trendings, ci.similarities(image_features, top_trendings))}\n",
" flavor_ranks = {flavor: sim for flavor, sim in zip(top_flavors, ci.similarities(image_features, top_flavors))}\n",
" \n",
" return medium_ranks, artist_ranks, movement_ranks, trending_ranks, flavor_ranks\n",
"\n",
"def image_to_prompt(image, mode):\n",
" ci.config.chunk_size = 2048 if ci.config.clip_model_name == \"ViT-L-14/openai\" else 1024\n",
" ci.config.flavor_intermediate_count = 2048 if ci.config.clip_model_name == \"ViT-L-14/openai\" else 1024\n",
" image = image.convert('RGB')\n",
" if mode == 'best':\n",
" return ci.interrogate(image, max_flavors=int(best_max_flavors))\n",
" return ci.interrogate(image)\n",
" elif mode == 'classic':\n",
" return ci.interrogate_classic(image)\n",
" else:\n",
" return ci.interrogate_fast(image)\n"
" elif mode == 'fast':\n",
" return ci.interrogate_fast(image)\n",
" elif mode == 'negative':\n",
" return ci.interrogate_negative(image)\n",
" "
]
},
{
@@ -156,22 +158,36 @@
"source": [
"#@title Image to prompt! 🖼 -> 📝\n",
" \n",
"inputs = [\n",
" gr.inputs.Image(type='pil'),\n",
" gr.Radio(['best', 'fast'], label='', value='best'),\n",
" gr.Number(value=16, label='best mode max flavors'),\n",
"]\n",
"outputs = [\n",
" gr.outputs.Textbox(label=\"Output\"),\n",
"]\n",
"\n",
"io = gr.Interface(\n",
" inference, \n",
" inputs, \n",
" outputs, \n",
" allow_flagging=False,\n",
")\n",
"io.launch(debug=False)\n"
"def prompt_tab():\n",
" with gr.Column():\n",
" with gr.Row():\n",
" image = gr.Image(type='pil', label=\"Image\")\n",
" with gr.Column():\n",
" mode = gr.Radio(['best', 'fast', 'classic', 'negative'], label='Mode', value='best')\n",
" prompt = gr.Textbox(label=\"Prompt\")\n",
" button = gr.Button(\"Generate prompt\")\n",
" button.click(image_to_prompt, inputs=[image, mode], outputs=prompt)\n",
"\n",
"def analyze_tab():\n",
" with gr.Column():\n",
" with gr.Row():\n",
" image = gr.Image(type='pil', label=\"Image\")\n",
" with gr.Row():\n",
" medium = gr.Label(label=\"Medium\", num_top_classes=5)\n",
" artist = gr.Label(label=\"Artist\", num_top_classes=5) \n",
" movement = gr.Label(label=\"Movement\", num_top_classes=5)\n",
" trending = gr.Label(label=\"Trending\", num_top_classes=5)\n",
" flavor = gr.Label(label=\"Flavor\", num_top_classes=5)\n",
" button = gr.Button(\"Analyze\")\n",
" button.click(image_analysis, inputs=image, outputs=[medium, artist, movement, trending, flavor])\n",
"\n",
"with gr.Blocks() as ui:\n",
" with gr.Tab(\"Prompt\"):\n",
" prompt_tab()\n",
" with gr.Tab(\"Analyze\"):\n",
" analyze_tab()\n",
"\n",
"ui.launch(show_api=False, debug=False)\n"
]
},
{
@@ -198,10 +214,9 @@
"from tqdm import tqdm\n",
"\n",
"folder_path = \"/content/my_images\" #@param {type:\"string\"}\n",
"prompt_mode = 'best' #@param [\"best\",\"fast\"]\n",
"prompt_mode = 'best' #@param [\"best\",\"fast\",\"classic\",\"negative\"]\n",
"output_mode = 'rename' #@param [\"desc.csv\",\"rename\"]\n",
"max_filename_len = 128 #@param {type:\"integer\"}\n",
"best_max_flavors = 16 #@param {type:\"integer\"}\n",
"\n",
"\n",
"def sanitize_for_filename(prompt: str, max_len: int) -> str:\n",
@@ -218,7 +233,7 @@
" clear_output(wait=True)\n",
"\n",
" image = Image.open(os.path.join(folder_path, file)).convert('RGB')\n",
" prompt = inference(image, prompt_mode, best_max_flavors=best_max_flavors)\n",
" prompt = image_to_prompt(image, prompt_mode)\n",
" prompts.append(prompt)\n",
"\n",
" print(prompt)\n",

clip_interrogator/__init__.py (2 changes)

@@ -1,4 +1,4 @@
from .clip_interrogator import Interrogator, Config
__version__ = '0.3.5'
__version__ = '0.4.0'
__author__ = 'pharmapsychotic'

clip_interrogator/clip_interrogator.py (175 changes)

@@ -5,6 +5,7 @@ import numpy as np
import open_clip
import os
import pickle
import requests
import time
import torch
@@ -21,6 +22,23 @@ BLIP_MODELS = {
'large': 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
}
CACHE_URLS_VITL = [
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_artists.pkl',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_flavors.pkl',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_mediums.pkl',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_movements.pkl',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_trendings.pkl',
]
CACHE_URLS_VITH = [
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_artists.pkl',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_flavors.pkl',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_mediums.pkl',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_movements.pkl',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_trendings.pkl',
]
@dataclass
class Config:
# models can optionally be passed in directly
@@ -40,13 +58,15 @@ class Config:
clip_model_path: str = None
# interrogator settings
cache_path: str = 'cache'
chunk_size: int = 2048
cache_path: str = 'cache' # path to store cached text embeddings
download_cache: bool = True # when true, cached embeds are downloaded from huggingface
chunk_size: int = 2048 # batch size for CLIP, use smaller for lower VRAM
data_path: str = os.path.join(os.path.dirname(__file__), 'data')
device: str = ("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
flavor_intermediate_count: int = 2048
quiet: bool = False # when quiet progress bars are not shown
class Interrogator():
def __init__(self, config: Config):
self.config = config
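A short sketch of how the fields annotated above might be set, e.g. for lower VRAM or fully offline use. The field names are taken from this hunk; the particular values are illustrative assumptions, not recommendations.

```python
from clip_interrogator import Config, Interrogator

config = Config(clip_model_name="ViT-H-14/laion2b_s32b_b79k")
config.cache_path = "cache"        # where cached text embeddings are stored
config.download_cache = False      # skip the huggingface download; embeddings get computed locally
config.chunk_size = 1024           # smaller CLIP batch size for lower VRAM
config.flavor_intermediate_count = 1024
config.quiet = True                # hide progress bars

ci = Interrogator(config)
```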
@@ -72,6 +92,21 @@ class Interrogator():
self.load_clip_model()
def download_cache(self, clip_model_name: str):
if clip_model_name == 'ViT-L-14/openai':
cache_urls = CACHE_URLS_VITL
elif clip_model_name == 'ViT-H-14/laion2b_s32b_b79k':
cache_urls = CACHE_URLS_VITH
else:
# text embeddings will be precomputed and cached locally
return
os.makedirs(self.config.cache_path, exist_ok=True)
for url in cache_urls:
filepath = os.path.join(self.config.cache_path, url.split('/')[-1])
if not os.path.exists(filepath):
_download_file(url, filepath, quiet=self.config.quiet)
def load_clip_model(self):
start_time = time.time()
config = self.config
@@ -105,16 +140,58 @@ class Interrogator():
artists = [f"by {a}" for a in raw_artists]
artists.extend([f"inspired by {a}" for a in raw_artists])
self.download_cache(config.clip_model_name)
self.artists = LabelTable(artists, "artists", self.clip_model, self.tokenize, config)
self.flavors = LabelTable(_load_list(config.data_path, 'flavors.txt'), "flavors", self.clip_model, self.tokenize, config)
self.mediums = LabelTable(_load_list(config.data_path, 'mediums.txt'), "mediums", self.clip_model, self.tokenize, config)
self.movements = LabelTable(_load_list(config.data_path, 'movements.txt'), "movements", self.clip_model, self.tokenize, config)
self.trendings = LabelTable(trending_list, "trendings", self.clip_model, self.tokenize, config)
self.negative = LabelTable(_load_list(config.data_path, 'negative.txt'), "negative", self.clip_model, self.tokenize, config)
end_time = time.time()
if not config.quiet:
print(f"Loaded CLIP model and data in {end_time-start_time:.2f} seconds.")
def chain(
self,
image_features: torch.Tensor,
phrases: List[str],
best_prompt: str="",
best_sim: float=0,
max_count: int=32,
desc="Chaining",
reverse: bool=False
) -> str:
phrases = set(phrases)
if not best_prompt:
best_prompt = self.rank_top(image_features, [f for f in phrases], reverse=reverse)
best_sim = self.similarity(image_features, best_prompt)
phrases.remove(best_prompt)
def check(addition: str) -> bool:
nonlocal best_prompt, best_sim
prompt = best_prompt + ", " + addition
sim = self.similarity(image_features, prompt)
if reverse:
sim = -sim
if sim > best_sim:
best_sim = sim
best_prompt = prompt
return True
return False
for _ in tqdm(range(max_count), desc=desc, disable=self.config.quiet):
best = self.rank_top(image_features, [f"{best_prompt}, {f}" for f in phrases], reverse=reverse)
flave = best[len(best_prompt)+2:]
if not check(flave):
break
if _prompt_at_max_len(best_prompt, self.tokenize):
break
phrases.remove(flave)
return best_prompt
def generate_caption(self, pil_image: Image) -> str:
if self.config.blip_offload:
self.blip_model = self.blip_model.to(self.device)
@@ -145,6 +222,8 @@ class Interrogator():
return image_features
def interrogate_classic(self, image: Image, max_flavors: int=3) -> str:
"""Classic mode creates a prompt in a standard format first describing the image,
then listing the artist, trending, movement, and flavor text modifiers."""
caption = self.generate_caption(image)
image_features = self.image_to_features(image)
@@ -162,69 +241,43 @@ class Interrogator():
return _truncate_to_fit(prompt, self.tokenize)
def interrogate_fast(self, image: Image, max_flavors: int = 32) -> str:
"""Fast mode simply adds the top ranked terms after a caption. It generally results in
better similarity between generated prompt and image than classic mode, but the prompts
are less readable."""
caption = self.generate_caption(image)
image_features = self.image_to_features(image)
merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config)
tops = merged.rank(image_features, max_flavors)
return _truncate_to_fit(caption + ", " + ", ".join(tops), self.tokenize)
def interrogate_negative(self, image: Image, max_flavors: int = 32) -> str:
"""Negative mode chains together the most dissimilar terms to the image. It can be used
to help build a negative prompt to pair with the regular positive prompt and often
improve the results of generated images particularly with Stable Diffusion 2."""
image_features = self.image_to_features(image)
flaves = self.flavors.rank(image_features, self.config.flavor_intermediate_count, reverse=True)
flaves = flaves + self.negative.labels
return self.chain(image_features, flaves, max_count=max_flavors, reverse=True, desc="Negative chain")
def interrogate(self, image: Image, max_flavors: int=32) -> str:
caption = self.generate_caption(image)
image_features = self.image_to_features(image)
flaves = self.flavors.rank(image_features, self.config.flavor_intermediate_count)
best_medium = self.mediums.rank(image_features, 1)[0]
best_artist = self.artists.rank(image_features, 1)[0]
best_trending = self.trendings.rank(image_features, 1)[0]
best_movement = self.movements.rank(image_features, 1)[0]
merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config)
flaves = merged.rank(image_features, self.config.flavor_intermediate_count)
best_prompt = caption
best_sim = self.similarity(image_features, best_prompt)
def check(addition: str) -> bool:
nonlocal best_prompt, best_sim
prompt = best_prompt + ", " + addition
sim = self.similarity(image_features, prompt)
if sim > best_sim:
best_sim = sim
best_prompt = prompt
return True
return False
def check_multi_batch(opts: List[str]):
nonlocal best_prompt, best_sim
prompts = []
for i in range(2**len(opts)):
prompt = best_prompt
for bit in range(len(opts)):
if i & (1 << bit):
prompt += ", " + opts[bit]
prompts.append(prompt)
t = LabelTable(prompts, None, self.clip_model, self.tokenize, self.config)
best_prompt = t.rank(image_features, 1)[0]
best_sim = self.similarity(image_features, best_prompt)
check_multi_batch([best_medium, best_artist, best_trending, best_movement])
extended_flavors = set(flaves)
for _ in tqdm(range(max_flavors), desc="Flavor chain", disable=self.config.quiet):
best = self.rank_top(image_features, [f"{best_prompt}, {f}" for f in extended_flavors])
flave = best[len(best_prompt)+2:]
if not check(flave):
break
if _prompt_at_max_len(best_prompt, self.tokenize):
break
extended_flavors.remove(flave)
return best_prompt
return self.chain(image_features, flaves, best_prompt, best_sim, max_count=max_flavors, desc="Flavor chain")
def rank_top(self, image_features: torch.Tensor, text_array: List[str]) -> str:
def rank_top(self, image_features: torch.Tensor, text_array: List[str], reverse: bool=False) -> str:
text_tokens = self.tokenize([text for text in text_array]).to(self.device)
with torch.no_grad(), torch.cuda.amp.autocast():
text_features = self.clip_model.encode_text(text_tokens)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = text_features @ image_features.T
if reverse:
similarity = -similarity
return text_array[similarity.argmax().item()]
def similarity(self, image_features: torch.Tensor, text: str) -> float:
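The docstrings above describe the prompt modes; a minimal sketch, assuming the method names in this commit, that pairs best mode with the new negative mode to produce a prompt/negative-prompt pair (as the interrogate_negative docstring suggests, e.g. for Stable Diffusion 2). The image path is a placeholder.

```python
from PIL import Image
from clip_interrogator import Config, Interrogator

ci = Interrogator(Config(clip_model_name="ViT-H-14/laion2b_s32b_b79k"))
image = Image.open("photo.jpg").convert("RGB")  # placeholder path

positive_prompt = ci.interrogate(image)           # best mode: caption + chained top terms
negative_prompt = ci.interrogate_negative(image)  # chains the most dissimilar terms plus negative.txt labels

print("Prompt:", positive_prompt)
print("Negative prompt:", negative_prompt)
```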
@@ -235,6 +288,14 @@ class Interrogator():
similarity = text_features @ image_features.T
return similarity[0][0].item()
def similarities(self, image_features: torch.Tensor, text_array: List[str]) -> List[float]:
text_tokens = self.tokenize([text for text in text_array]).to(self.device)
with torch.no_grad(), torch.cuda.amp.autocast():
text_features = self.clip_model.encode_text(text_tokens)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = text_features @ image_features.T
return similarity.T[0].tolist()
class LabelTable():
def __init__(self, labels:List[str], desc:str, clip_model, tokenize, config: Config):
@@ -286,17 +347,19 @@ class LabelTable():
if self.device == 'cpu' or self.device == torch.device('cpu'):
self.embeds = [e.astype(np.float32) for e in self.embeds]
def _rank(self, image_features: torch.Tensor, text_embeds: torch.Tensor, top_count: int=1) -> str:
def _rank(self, image_features: torch.Tensor, text_embeds: torch.Tensor, top_count: int=1, reverse: bool=False) -> str:
top_count = min(top_count, len(text_embeds))
text_embeds = torch.stack([torch.from_numpy(t) for t in text_embeds]).to(self.device)
with torch.cuda.amp.autocast():
similarity = image_features @ text_embeds.T
if reverse:
similarity = -similarity
_, top_labels = similarity.float().cpu().topk(top_count, dim=-1)
return [top_labels[0][i].numpy() for i in range(top_count)]
def rank(self, image_features: torch.Tensor, top_count: int=1) -> List[str]:
def rank(self, image_features: torch.Tensor, top_count: int=1, reverse: bool=False) -> List[str]:
if len(self.labels) <= self.chunk_size:
tops = self._rank(image_features, self.embeds, top_count=top_count)
tops = self._rank(image_features, self.embeds, top_count=top_count, reverse=reverse)
return [self.labels[i] for i in tops]
num_chunks = int(math.ceil(len(self.labels)/self.chunk_size))
@@ -306,7 +369,7 @@ class LabelTable():
for chunk_idx in tqdm(range(num_chunks), disable=self.config.quiet):
start = chunk_idx*self.chunk_size
stop = min(start+self.chunk_size, len(self.embeds))
tops = self._rank(image_features, self.embeds[start:stop], top_count=keep_per_chunk)
tops = self._rank(image_features, self.embeds[start:stop], top_count=keep_per_chunk, reverse=reverse)
top_labels.extend([self.labels[start+i] for i in tops])
top_embeds.extend([self.embeds[start+i] for i in tops])
@@ -314,6 +377,18 @@
return [top_labels[i] for i in tops]
def _download_file(url: str, filepath: str, chunk_size: int = 64*1024, quiet: bool = False):
r = requests.get(url, stream=True)
file_size = int(r.headers.get("Content-Length", 0))
filename = url.split("/")[-1]
progress = tqdm(total=file_size, unit="B", unit_scale=True, desc=filename, disable=quiet)
with open(filepath, "wb") as f:
for chunk in r.iter_content(chunk_size=chunk_size):
if chunk:
f.write(chunk)
progress.update(len(chunk))
progress.close()
def _load_list(data_path: str, filename: str) -> List[str]:
with open(os.path.join(data_path, filename), 'r', encoding='utf-8', errors='replace') as f:
items = [line.strip() for line in f.readlines()]

clip_interrogator/data/negative.txt (41 changes)

@@ -0,0 +1,41 @@
3d
b&w
bad anatomy
bad art
blur
blurry
cartoon
childish
close up
deformed
disconnected limbs
disfigured
disgusting
extra limb
extra limbs
floating limbs
grain
illustration
kitsch
long body
long neck
low quality
low-res
malformed hands
mangled
missing limb
mutated
mutation
mutilated
noisy
old
out of focus
over saturation
oversaturated
poorly drawn
poorly drawn face
poorly drawn hands
render
surreal
ugly
weird colors

run_gradio.py (84 changes)

@@ -3,7 +3,7 @@ import argparse
import gradio as gr
import open_clip
import torch
from clip_interrogator import Interrogator, Config
from clip_interrogator import Config, Interrogator
parser = argparse.ArgumentParser()
parser.add_argument('-s', '--share', action='store_true', help='Create a public link')
@@ -14,40 +14,76 @@ if not torch.cuda.is_available():
ci = Interrogator(Config(cache_path="cache", clip_model_path="cache"))
def inference(image, mode, clip_model_name, blip_max_length, blip_num_beams):
def image_analysis(image, clip_model_name):
if clip_model_name != ci.config.clip_model_name:
ci.config.clip_model_name = clip_model_name
ci.load_clip_model()
image = image.convert('RGB')
image_features = ci.image_to_features(image)
top_mediums = ci.mediums.rank(image_features, 5)
top_artists = ci.artists.rank(image_features, 5)
top_movements = ci.movements.rank(image_features, 5)
top_trendings = ci.trendings.rank(image_features, 5)
top_flavors = ci.flavors.rank(image_features, 5)
medium_ranks = {medium: sim for medium, sim in zip(top_mediums, ci.similarities(image_features, top_mediums))}
artist_ranks = {artist: sim for artist, sim in zip(top_artists, ci.similarities(image_features, top_artists))}
movement_ranks = {movement: sim for movement, sim in zip(top_movements, ci.similarities(image_features, top_movements))}
trending_ranks = {trending: sim for trending, sim in zip(top_trendings, ci.similarities(image_features, top_trendings))}
flavor_ranks = {flavor: sim for flavor, sim in zip(top_flavors, ci.similarities(image_features, top_flavors))}
return medium_ranks, artist_ranks, movement_ranks, trending_ranks, flavor_ranks
def image_to_prompt(image, mode, clip_model_name):
if clip_model_name != ci.config.clip_model_name:
ci.config.clip_model_name = clip_model_name
ci.load_clip_model()
ci.config.blip_max_length = int(blip_max_length)
ci.config.blip_num_beams = int(blip_num_beams)
image = image.convert('RGB')
if mode == 'best':
return ci.interrogate(image)
elif mode == 'classic':
return ci.interrogate_classic(image)
else:
elif mode == 'fast':
return ci.interrogate_fast(image)
elif mode == 'negative':
return ci.interrogate_negative(image)
models = ['/'.join(x) for x in open_clip.list_pretrained()]
inputs = [
gr.inputs.Image(type='pil'),
gr.Radio(['best', 'classic', 'fast'], label='Mode', value='best'),
gr.Dropdown(models, value='ViT-L-14/openai', label='CLIP Model'),
gr.Number(value=32, label='Caption Max Length'),
gr.Number(value=64, label='Caption Num Beams'),
]
outputs = [
gr.outputs.Textbox(label="Output"),
]
io = gr.Interface(
inference,
inputs,
outputs,
title="🕵 CLIP Interrogator 🕵",
allow_flagging=False,
)
io.launch(share=args.share)
def prompt_tab():
with gr.Column():
with gr.Row():
image = gr.Image(type='pil', label="Image")
with gr.Column():
mode = gr.Radio(['best', 'fast', 'classic', 'negative'], label='Mode', value='best')
model = gr.Dropdown(models, value='ViT-L-14/openai', label='CLIP Model')
prompt = gr.Textbox(label="Prompt")
button = gr.Button("Generate prompt")
button.click(image_to_prompt, inputs=[image, mode, model], outputs=prompt)
def analyze_tab():
with gr.Column():
with gr.Row():
image = gr.Image(type='pil', label="Image")
model = gr.Dropdown(models, value='ViT-L-14/openai', label='CLIP Model')
with gr.Row():
medium = gr.Label(label="Medium", num_top_classes=5)
artist = gr.Label(label="Artist", num_top_classes=5)
movement = gr.Label(label="Movement", num_top_classes=5)
trending = gr.Label(label="Trending", num_top_classes=5)
flavor = gr.Label(label="Flavor", num_top_classes=5)
button = gr.Button("Analyze")
button.click(image_analysis, inputs=[image, model], outputs=[medium, artist, movement, trending, flavor])
with gr.Blocks() as ui:
gr.Markdown("# <center>🕵 CLIP Interrogator 🕵</center>")
with gr.Tab("Prompt"):
prompt_tab()
with gr.Tab("Analyze"):
analyze_tab()
ui.launch(show_api=False, debug=True, share=args.share)

setup.py (2 changes)

@@ -5,7 +5,7 @@ from setuptools import setup, find_packages
setup(
name="clip-interrogator",
version="0.3.5",
version="0.4.0",
license='MIT',
author='pharmapsychotic',
author_email='me@pharmapsychotic.com',
