diff --git a/clip_interrogator.ipynb b/clip_interrogator.ipynb
old mode 100644
new mode 100755
index 9c81f51..96dec9e
--- a/clip_interrogator.ipynb
+++ b/clip_interrogator.ipynb
@@ -7,13 +7,24 @@
"id": "3jm8RYrLqvzz"
},
"source": [
- "# CLIP Interrogator 2.3 [negative prompt experiment!]\n",
+ "# CLIP Interrogator 2.3 by [@pharmapsychotic](https://twitter.com/pharmapsychotic) \n",
"\n",
- "This experimental version of CLIP Interrogator supports finding good \"negative\" prompts for Stable Diffusion 2. Note this is very *WIP* and more work needs to be done building out the dataset to support this (and perhaps a reverse BLIP) so for many images it may struggle to find a well aligned negative prompt. Alignments are displayed to help see how well it did.\n",
+ "Want to figure out what a good prompt might be to create new images like an existing one? The CLIP Interrogator is here to get you answers!\n",
"\n",
"
\n",
"\n",
- "For Stable Diffusion 1.X choose the **ViT-L** model and for Stable Diffusion 2.0+ choose the **ViT-H** CLIP Model.\n"
+ "For Stable Diffusion 1.X choose the **ViT-L** model and for Stable Diffusion 2.0+ choose the **ViT-H** CLIP Model.\n",
+ "\n",
+ "This version is specialized for producing nice prompts for use with Stable Diffusion and achieves higher alignment between generated text prompt and source image. You can try out the old [version 1](https://colab.research.google.com/github/pharmapsychotic/clip-interrogator/blob/v1/clip_interrogator.ipynb) to see how different CLIP models ranks terms. \n",
+ "\n",
+ "You can also run this on HuggingFace and Replicate
\n",
+ "[](https://huggingface.co/spaces/pharma/CLIP-Interrogator) [](https://replicate.com/pharmapsychotic/clip-interrogator)\n",
+ "\n",
+ "
\n",
+ "\n",
+ "If this notebook is helpful to you please consider buying me a coffee via [ko-fi](https://ko-fi.com/pharmapsychotic) or following me on [twitter](https://twitter.com/pharmapsychotic) for more cool Ai stuff. 🙂\n",
+ "\n",
+ "And if you're looking for more Ai art tools check out my [Ai generative art tools list](https://pharmapsychotic.com/tools.html).\n"
]
},
{
@@ -45,8 +56,7 @@
" install_cmds = [\n",
" ['pip', 'install', 'gradio'],\n",
" ['pip', 'install', 'open_clip_torch'],\n",
- " ['pip', 'install', 'git+https://github.com/pharmapsychotic/BLIP.git'],\n",
- " ['git', 'clone', '-b', 'negative', 'https://github.com/pharmapsychotic/clip-interrogator.git']\n",
+ " ['pip', 'install', 'clip-interrogator'],\n",
" ]\n",
" for cmd in install_cmds:\n",
" print(subprocess.run(cmd, stdout=subprocess.PIPE).stdout.decode('utf-8'))\n",
@@ -54,29 +64,8 @@
"setup()\n",
"\n",
"\n",
- "clip_model_name = 'ViT-H-14/laion2b_s32b_b79k' #@param [\"ViT-L-14/openai\", \"ViT-H-14/laion2b_s32b_b79k\"]\n",
- "\n",
- "\n",
- "print(\"Download preprocessed cache files...\")\n",
- "CACHE_URLS = [\n",
- " 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_artists.pkl',\n",
- " 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_flavors.pkl',\n",
- " 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_mediums.pkl',\n",
- " 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_movements.pkl',\n",
- " 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_trendings.pkl',\n",
- "] if clip_model_name == 'ViT-L-14/openai' else [\n",
- " 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_artists.pkl',\n",
- " 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_flavors.pkl',\n",
- " 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_mediums.pkl',\n",
- " 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_movements.pkl',\n",
- " 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_trendings.pkl',\n",
- "]\n",
- "os.makedirs('cache', exist_ok=True)\n",
- "for url in CACHE_URLS:\n",
- " print(subprocess.run(['wget', url, '-P', 'cache'], stdout=subprocess.PIPE).stdout.decode('utf-8'))\n",
+ "clip_model_name = 'ViT-L-14/openai' #@param [\"ViT-L-14/openai\", \"ViT-H-14/laion2b_s32b_b79k\"]\n",
"\n",
- "import sys\n",
- "sys.path.append('clip-interrogator')\n",
"\n",
"import gradio as gr\n",
"from clip_interrogator import Config, Interrogator\n",
@@ -87,24 +76,37 @@
"config.clip_model_name = clip_model_name\n",
"ci = Interrogator(config)\n",
"\n",
- "def inference(image, mode):\n",
+ "def image_analysis(image):\n",
+ " image = image.convert('RGB')\n",
+ " image_features = ci.image_to_features(image)\n",
+ "\n",
+ " top_mediums = ci.mediums.rank(image_features, 5)\n",
+ " top_artists = ci.artists.rank(image_features, 5)\n",
+ " top_movements = ci.movements.rank(image_features, 5)\n",
+ " top_trendings = ci.trendings.rank(image_features, 5)\n",
+ " top_flavors = ci.flavors.rank(image_features, 5)\n",
+ "\n",
+ " medium_ranks = {medium: sim for medium, sim in zip(top_mediums, ci.similarities(image_features, top_mediums))}\n",
+ " artist_ranks = {artist: sim for artist, sim in zip(top_artists, ci.similarities(image_features, top_artists))}\n",
+ " movement_ranks = {movement: sim for movement, sim in zip(top_movements, ci.similarities(image_features, top_movements))}\n",
+ " trending_ranks = {trending: sim for trending, sim in zip(top_trendings, ci.similarities(image_features, top_trendings))}\n",
+ " flavor_ranks = {flavor: sim for flavor, sim in zip(top_flavors, ci.similarities(image_features, top_flavors))}\n",
+ " \n",
+ " return medium_ranks, artist_ranks, movement_ranks, trending_ranks, flavor_ranks\n",
+ "\n",
+ "def image_to_prompt(image, mode):\n",
" ci.config.chunk_size = 2048 if ci.config.clip_model_name == \"ViT-L-14/openai\" else 1024\n",
" ci.config.flavor_intermediate_count = 2048 if ci.config.clip_model_name == \"ViT-L-14/openai\" else 1024\n",
" image = image.convert('RGB')\n",
- " prompt = \"\"\n",
" if mode == 'best':\n",
- " prompt = ci.interrogate(image)\n",
+ " return ci.interrogate(image)\n",
" elif mode == 'classic':\n",
- " prompt = ci.interrogate_classic(image)\n",
+ " return ci.interrogate_classic(image)\n",
" elif mode == 'fast':\n",
- " prompt = ci.interrogate_fast(image)\n",
+ " return ci.interrogate_fast(image)\n",
" elif mode == 'negative':\n",
- " image_features = ci.image_to_features(image)\n",
- " flaves = ci.flavors.rank(image_features, ci.config.flavor_intermediate_count, reverse=True)\n",
- " flaves = flaves + ci.negative.labels\n",
- " prompt = ci.chain(image_features, flaves, max_count=32, reverse=True, desc=\"Negative chain\")\n",
- " sim = ci.similarity(ci.image_to_features(image), prompt)\n",
- " return prompt, sim"
+ " return ci.interrogate_negative(image)\n",
+ " "
]
},
{
@@ -156,22 +158,36 @@
"source": [
"#@title Image to prompt! 🖼️ -> 📝\n",
" \n",
- "inputs = [\n",
- " gr.inputs.Image(type='pil'),\n",
- " gr.Radio(['best', 'fast', 'negative'], label='Mode', value='best'),\n",
- "]\n",
- "outputs = [\n",
- " gr.outputs.Textbox(label=\"Output\"),\n",
- " gr.Number(label=\"Alignment\"),\n",
- "]\n",
- "\n",
- "io = gr.Interface(\n",
- " inference, \n",
- " inputs, \n",
- " outputs, \n",
- " allow_flagging=False,\n",
- ")\n",
- "io.launch(debug=False)\n"
+ "def prompt_tab():\n",
+ " with gr.Column():\n",
+ " with gr.Row():\n",
+ " image = gr.Image(type='pil', label=\"Image\")\n",
+ " with gr.Column():\n",
+ " mode = gr.Radio(['best', 'fast', 'classic', 'negative'], label='Mode', value='best')\n",
+ " prompt = gr.Textbox(label=\"Prompt\")\n",
+ " button = gr.Button(\"Generate prompt\")\n",
+ " button.click(image_to_prompt, inputs=[image, mode], outputs=prompt)\n",
+ "\n",
+ "def analyze_tab():\n",
+ " with gr.Column():\n",
+ " with gr.Row():\n",
+ " image = gr.Image(type='pil', label=\"Image\")\n",
+ " with gr.Row():\n",
+ " medium = gr.Label(label=\"Medium\", num_top_classes=5)\n",
+ " artist = gr.Label(label=\"Artist\", num_top_classes=5) \n",
+ " movement = gr.Label(label=\"Movement\", num_top_classes=5)\n",
+ " trending = gr.Label(label=\"Trending\", num_top_classes=5)\n",
+ " flavor = gr.Label(label=\"Flavor\", num_top_classes=5)\n",
+ " button = gr.Button(\"Analyze\")\n",
+ " button.click(image_analysis, inputs=image, outputs=[medium, artist, movement, trending, flavor])\n",
+ "\n",
+ "with gr.Blocks() as ui:\n",
+ " with gr.Tab(\"Prompt\"):\n",
+ " prompt_tab()\n",
+ " with gr.Tab(\"Analyze\"):\n",
+ " analyze_tab()\n",
+ "\n",
+ "ui.launch(show_api=False, debug=False)\n"
]
},
{
@@ -198,10 +214,9 @@
"from tqdm import tqdm\n",
"\n",
"folder_path = \"/content/my_images\" #@param {type:\"string\"}\n",
- "prompt_mode = 'best' #@param [\"best\",\"fast\"]\n",
+ "prompt_mode = 'best' #@param [\"best\",\"fast\",\"classic\",\"negative\"]\n",
"output_mode = 'rename' #@param [\"desc.csv\",\"rename\"]\n",
"max_filename_len = 128 #@param {type:\"integer\"}\n",
- "best_max_flavors = 16 #@param {type:\"integer\"}\n",
"\n",
"\n",
"def sanitize_for_filename(prompt: str, max_len: int) -> str:\n",
@@ -218,7 +233,7 @@
" clear_output(wait=True)\n",
"\n",
" image = Image.open(os.path.join(folder_path, file)).convert('RGB')\n",
- " prompt = inference(image, prompt_mode, best_max_flavors=best_max_flavors)\n",
+ " prompt = image_to_prompt(image, prompt_mode)\n",
" prompts.append(prompt)\n",
"\n",
" print(prompt)\n",
@@ -261,7 +276,7 @@
"provenance": []
},
"kernelspec": {
- "display_name": "ci",
+ "display_name": "Python 3.7.15 ('py37')",
"language": "python",
"name": "python3"
},
@@ -275,12 +290,12 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.8.10 (default, Nov 14 2022, 12:59:47) \n[GCC 9.4.0]"
+ "version": "3.7.15 (default, Nov 24 2022, 18:44:54) [MSC v.1916 64 bit (AMD64)]"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
- "hash": "90daa5087f97972f35e673cab20894a33c1e0ca77092ccdd163e60b53596983a"
+ "hash": "1f51d5616d3bc2b87a82685314c5be1ec9a49b6e0cb1f707bfa2acb6c45f3e5f"
}
}
},
diff --git a/clip_interrogator/clip_interrogator.py b/clip_interrogator/clip_interrogator.py
index 0944dc8..db212d2 100644
--- a/clip_interrogator/clip_interrogator.py
+++ b/clip_interrogator/clip_interrogator.py
@@ -5,6 +5,7 @@ import numpy as np
import open_clip
import os
import pickle
+import requests
import time
import torch
@@ -21,6 +22,23 @@ BLIP_MODELS = {
'large': 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
}
+CACHE_URLS_VITL = [
+ 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_artists.pkl',
+ 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_flavors.pkl',
+ 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_mediums.pkl',
+ 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_movements.pkl',
+ 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_trendings.pkl',
+]
+
+CACHE_URLS_VITH = [
+ 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_artists.pkl',
+ 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_flavors.pkl',
+ 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_mediums.pkl',
+ 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_movements.pkl',
+ 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_trendings.pkl',
+]
+
+
@dataclass
class Config:
# models can optionally be passed in directly
@@ -40,13 +58,15 @@ class Config:
clip_model_path: str = None
# interrogator settings
- cache_path: str = 'cache'
- chunk_size: int = 2048
+ cache_path: str = 'cache' # path to store cached text embeddings
+ download_cache: bool = True # when true, cached embeds are downloaded from huggingface
+ chunk_size: int = 2048 # batch size for CLIP, use smaller for lower VRAM
data_path: str = os.path.join(os.path.dirname(__file__), 'data')
device: str = ("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
flavor_intermediate_count: int = 2048
quiet: bool = False # when quiet progress bars are not shown
+
class Interrogator():
def __init__(self, config: Config):
self.config = config
@@ -72,6 +92,21 @@ class Interrogator():
self.load_clip_model()
+ def download_cache(self, clip_model_name: str):
+ if clip_model_name == 'ViT-L-14/openai':
+ cache_urls = CACHE_URLS_VITL
+ elif clip_model_name == 'ViT-H-14/laion2b_s32b_b79k':
+ cache_urls = CACHE_URLS_VITH
+ else:
+ # text embeddings will be precomputed and cached locally
+ return
+
+ os.makedirs(self.config.cache_path, exist_ok=True)
+ for url in cache_urls:
+ filepath = os.path.join(self.config.cache_path, url.split('/')[-1])
+ if not os.path.exists(filepath):
+ _download_file(url, filepath, quiet=self.config.quiet)
+
def load_clip_model(self):
start_time = time.time()
config = self.config
@@ -105,6 +140,8 @@ class Interrogator():
artists = [f"by {a}" for a in raw_artists]
artists.extend([f"inspired by {a}" for a in raw_artists])
+ self.download_cache(config.clip_model_name)
+
self.artists = LabelTable(artists, "artists", self.clip_model, self.tokenize, config)
self.flavors = LabelTable(_load_list(config.data_path, 'flavors.txt'), "flavors", self.clip_model, self.tokenize, config)
self.mediums = LabelTable(_load_list(config.data_path, 'mediums.txt'), "mediums", self.clip_model, self.tokenize, config)
@@ -185,6 +222,8 @@ class Interrogator():
return image_features
def interrogate_classic(self, image: Image, max_flavors: int=3) -> str:
+ """Classic mode creates a prompt in a standard format first describing the image,
+ then listing the artist, trending, movement, and flavor text modifiers."""
caption = self.generate_caption(image)
image_features = self.image_to_features(image)
@@ -202,58 +241,40 @@ class Interrogator():
return _truncate_to_fit(prompt, self.tokenize)
def interrogate_fast(self, image: Image, max_flavors: int = 32) -> str:
+ """Fast mode simply adds the top ranked terms after a caption. It generally results in
+ better similarity between generated prompt and image than classic mode, but the prompts
+ are less readable."""
caption = self.generate_caption(image)
image_features = self.image_to_features(image)
merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config)
tops = merged.rank(image_features, max_flavors)
return _truncate_to_fit(caption + ", " + ", ".join(tops), self.tokenize)
+ def interrogate_negative(self, image: Image, max_flavors: int = 32) -> str:
+ """Negative mode chains together the most dissimilar terms to the image. It can be used
+ to help build a negative prompt to pair with the regular positive prompt and often
+ improve the results of generated images particularly with Stable Diffusion 2."""
+ image_features = self.image_to_features(image)
+ flaves = self.flavors.rank(image_features, self.config.flavor_intermediate_count, reverse=True)
+ flaves = flaves + self.negative.labels
+ return self.chain(image_features, flaves, max_count=max_flavors, reverse=True, desc="Negative chain")
+
def interrogate(self, image: Image, max_flavors: int=32) -> str:
caption = self.generate_caption(image)
image_features = self.image_to_features(image)
- flaves = self.flavors.rank(image_features, self.config.flavor_intermediate_count)
- best_medium = self.mediums.rank(image_features, 1)[0]
- best_artist = self.artists.rank(image_features, 1)[0]
- best_trending = self.trendings.rank(image_features, 1)[0]
- best_movement = self.movements.rank(image_features, 1)[0]
+ merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config)
+ flaves = merged.rank(image_features, self.config.flavor_intermediate_count)
best_prompt = caption
best_sim = self.similarity(image_features, best_prompt)
- def check(addition: str) -> bool:
- nonlocal best_prompt, best_sim
- prompt = best_prompt + ", " + addition
- sim = self.similarity(image_features, prompt)
- if sim > best_sim:
- best_sim = sim
- best_prompt = prompt
- return True
- return False
-
- def check_multi_batch(opts: List[str]):
- nonlocal best_prompt, best_sim
- prompts = []
- for i in range(2**len(opts)):
- prompt = best_prompt
- for bit in range(len(opts)):
- if i & (1 << bit):
- prompt += ", " + opts[bit]
- prompts.append(prompt)
-
- t = LabelTable(prompts, None, self.clip_model, self.tokenize, self.config)
- best_prompt = t.rank(image_features, 1)[0]
- best_sim = self.similarity(image_features, best_prompt)
-
- check_multi_batch([best_medium, best_artist, best_trending, best_movement])
-
return self.chain(image_features, flaves, best_prompt, best_sim, max_count=max_flavors, desc="Flavor chain")
def rank_top(self, image_features: torch.Tensor, text_array: List[str], reverse: bool=False) -> str:
text_tokens = self.tokenize([text for text in text_array]).to(self.device)
with torch.no_grad(), torch.cuda.amp.autocast():
text_features = self.clip_model.encode_text(text_tokens)
- text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = text_features @ image_features.T
if reverse:
similarity = -similarity
@@ -267,6 +288,14 @@ class Interrogator():
similarity = text_features @ image_features.T
return similarity[0][0].item()
+ def similarities(self, image_features: torch.Tensor, text_array: List[str]) -> List[float]:
+ text_tokens = self.tokenize([text for text in text_array]).to(self.device)
+ with torch.no_grad(), torch.cuda.amp.autocast():
+ text_features = self.clip_model.encode_text(text_tokens)
+ text_features /= text_features.norm(dim=-1, keepdim=True)
+ similarity = text_features @ image_features.T
+ return similarity.T[0].tolist()
+
class LabelTable():
def __init__(self, labels:List[str], desc:str, clip_model, tokenize, config: Config):
@@ -348,6 +377,18 @@ class LabelTable():
return [top_labels[i] for i in tops]
+def _download_file(url: str, filepath: str, chunk_size: int = 64*1024, quiet: bool = False):
+ r = requests.get(url, stream=True)
+ file_size = int(r.headers.get("Content-Length", 0))
+ filename = url.split("/")[-1]
+ progress = tqdm(total=file_size, unit="B", unit_scale=True, desc=filename, disable=quiet)
+ with open(filepath, "wb") as f:
+ for chunk in r.iter_content(chunk_size=chunk_size):
+ if chunk:
+ f.write(chunk)
+ progress.update(len(chunk))
+ progress.close()
+
def _load_list(data_path: str, filename: str) -> List[str]:
with open(os.path.join(data_path, filename), 'r', encoding='utf-8', errors='replace') as f:
items = [line.strip() for line in f.readlines()]
diff --git a/run_gradio.py b/run_gradio.py
index 9fc685f..c8f1597 100755
--- a/run_gradio.py
+++ b/run_gradio.py
@@ -3,7 +3,7 @@ import argparse
import gradio as gr
import open_clip
import torch
-from clip_interrogator import Interrogator, Config
+from clip_interrogator import Config, Interrogator
parser = argparse.ArgumentParser()
parser.add_argument('-s', '--share', action='store_true', help='Create a public link')
@@ -14,40 +14,76 @@ if not torch.cuda.is_available():
ci = Interrogator(Config(cache_path="cache", clip_model_path="cache"))
-def inference(image, mode, clip_model_name, blip_max_length, blip_num_beams):
+def image_analysis(image, clip_model_name):
+ if clip_model_name != ci.config.clip_model_name:
+ ci.config.clip_model_name = clip_model_name
+ ci.load_clip_model()
+
+ image = image.convert('RGB')
+ image_features = ci.image_to_features(image)
+
+ top_mediums = ci.mediums.rank(image_features, 5)
+ top_artists = ci.artists.rank(image_features, 5)
+ top_movements = ci.movements.rank(image_features, 5)
+ top_trendings = ci.trendings.rank(image_features, 5)
+ top_flavors = ci.flavors.rank(image_features, 5)
+
+ medium_ranks = {medium: sim for medium, sim in zip(top_mediums, ci.similarities(image_features, top_mediums))}
+ artist_ranks = {artist: sim for artist, sim in zip(top_artists, ci.similarities(image_features, top_artists))}
+ movement_ranks = {movement: sim for movement, sim in zip(top_movements, ci.similarities(image_features, top_movements))}
+ trending_ranks = {trending: sim for trending, sim in zip(top_trendings, ci.similarities(image_features, top_trendings))}
+ flavor_ranks = {flavor: sim for flavor, sim in zip(top_flavors, ci.similarities(image_features, top_flavors))}
+
+ return medium_ranks, artist_ranks, movement_ranks, trending_ranks, flavor_ranks
+
+def image_to_prompt(image, mode, clip_model_name):
if clip_model_name != ci.config.clip_model_name:
ci.config.clip_model_name = clip_model_name
ci.load_clip_model()
- ci.config.blip_max_length = int(blip_max_length)
- ci.config.blip_num_beams = int(blip_num_beams)
image = image.convert('RGB')
if mode == 'best':
return ci.interrogate(image)
elif mode == 'classic':
return ci.interrogate_classic(image)
- else:
+ elif mode == 'fast':
return ci.interrogate_fast(image)
+ elif mode == 'negative':
+ return ci.interrogate_negative(image)
+
models = ['/'.join(x) for x in open_clip.list_pretrained()]
-inputs = [
- gr.inputs.Image(type='pil'),
- gr.Radio(['best', 'classic', 'fast'], label='Mode', value='best'),
- gr.Dropdown(models, value='ViT-L-14/openai', label='CLIP Model'),
- gr.Number(value=32, label='Caption Max Length'),
- gr.Number(value=64, label='Caption Num Beams'),
-]
-outputs = [
- gr.outputs.Textbox(label="Output"),
-]
-
-io = gr.Interface(
- inference,
- inputs,
- outputs,
- title="🕵️♂️ CLIP Interrogator 🕵️♂️",
- allow_flagging=False,
-)
-io.launch(share=args.share)
+def prompt_tab():
+ with gr.Column():
+ with gr.Row():
+ image = gr.Image(type='pil', label="Image")
+ with gr.Column():
+ mode = gr.Radio(['best', 'fast', 'classic', 'negative'], label='Mode', value='best')
+ model = gr.Dropdown(models, value='ViT-L-14/openai', label='CLIP Model')
+ prompt = gr.Textbox(label="Prompt")
+ button = gr.Button("Generate prompt")
+ button.click(image_to_prompt, inputs=[image, mode, model], outputs=prompt)
+
+def analyze_tab():
+ with gr.Column():
+ with gr.Row():
+ image = gr.Image(type='pil', label="Image")
+ model = gr.Dropdown(models, value='ViT-L-14/openai', label='CLIP Model')
+ with gr.Row():
+ medium = gr.Label(label="Medium", num_top_classes=5)
+ artist = gr.Label(label="Artist", num_top_classes=5)
+ movement = gr.Label(label="Movement", num_top_classes=5)
+ trending = gr.Label(label="Trending", num_top_classes=5)
+ flavor = gr.Label(label="Flavor", num_top_classes=5)
+ button = gr.Button("Analyze")
+ button.click(image_analysis, inputs=[image, model], outputs=[medium, artist, movement, trending, flavor])
+
+with gr.Blocks() as ui:
+ gr.Markdown("#