diff --git a/README.md b/README.md index 0ffda8a..9d7f0e1 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,7 @@ Install with PIP pip3 install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu117 # install clip-interrogator -pip install clip-interrogator==0.3.5 +pip install clip-interrogator==0.4.0 ``` You can then use it in your script diff --git a/clip_interrogator.ipynb b/clip_interrogator.ipynb old mode 100644 new mode 100755 index 99bf7fe..96dec9e --- a/clip_interrogator.ipynb +++ b/clip_interrogator.ipynb @@ -1,12 +1,13 @@ { "cells": [ { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "3jm8RYrLqvzz" }, "source": [ - "# CLIP Interrogator 2.2 by [@pharmapsychotic](https://twitter.com/pharmapsychotic) \n", + "# CLIP Interrogator 2.3 by [@pharmapsychotic](https://twitter.com/pharmapsychotic) \n", "\n", "Want to figure out what a good prompt might be to create new images like an existing one? The CLIP Interrogator is here to get you answers!\n", "\n", @@ -56,7 +57,6 @@ " ['pip', 'install', 'gradio'],\n", " ['pip', 'install', 'open_clip_torch'],\n", " ['pip', 'install', 'clip-interrogator'],\n", - " ['pip', 'install', 'git+https://github.com/pharmapsychotic/BLIP.git'],\n", " ]\n", " for cmd in install_cmds:\n", " print(subprocess.run(cmd, stdout=subprocess.PIPE).stdout.decode('utf-8'))\n", @@ -67,25 +67,6 @@ "clip_model_name = 'ViT-L-14/openai' #@param [\"ViT-L-14/openai\", \"ViT-H-14/laion2b_s32b_b79k\"]\n", "\n", "\n", - "print(\"Download preprocessed cache files...\")\n", - "CACHE_URLS = [\n", - " 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_artists.pkl',\n", - " 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_flavors.pkl',\n", - " 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_mediums.pkl',\n", - " 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_movements.pkl',\n", - " 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_trendings.pkl',\n", - "] if clip_model_name == 'ViT-L-14/openai' else [\n", - " 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_artists.pkl',\n", - " 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_flavors.pkl',\n", - " 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_mediums.pkl',\n", - " 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_movements.pkl',\n", - " 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_trendings.pkl',\n", - "]\n", - "os.makedirs('cache', exist_ok=True)\n", - "for url in CACHE_URLS:\n", - " print(subprocess.run(['wget', url, '-P', 'cache'], stdout=subprocess.PIPE).stdout.decode('utf-8'))\n", - "\n", - "\n", "import gradio as gr\n", "from clip_interrogator import Config, Interrogator\n", "\n", @@ -95,16 +76,37 @@ "config.clip_model_name = clip_model_name\n", "ci = Interrogator(config)\n", "\n", - "def inference(image, mode, best_max_flavors=32):\n", + "def image_analysis(image):\n", + " image = image.convert('RGB')\n", + " image_features = ci.image_to_features(image)\n", + "\n", + " top_mediums = ci.mediums.rank(image_features, 5)\n", + " top_artists = ci.artists.rank(image_features, 5)\n", + " top_movements = ci.movements.rank(image_features, 5)\n", + " top_trendings = ci.trendings.rank(image_features, 5)\n", + " top_flavors = ci.flavors.rank(image_features, 5)\n", + "\n", + " medium_ranks = {medium: sim for medium, sim in zip(top_mediums, ci.similarities(image_features, top_mediums))}\n", + " artist_ranks = {artist: sim for artist, sim in zip(top_artists, ci.similarities(image_features, top_artists))}\n", + " movement_ranks = {movement: sim for movement, sim in zip(top_movements, ci.similarities(image_features, top_movements))}\n", + " trending_ranks = {trending: sim for trending, sim in zip(top_trendings, ci.similarities(image_features, top_trendings))}\n", + " flavor_ranks = {flavor: sim for flavor, sim in zip(top_flavors, ci.similarities(image_features, top_flavors))}\n", + " \n", + " return medium_ranks, artist_ranks, movement_ranks, trending_ranks, flavor_ranks\n", + "\n", + "def image_to_prompt(image, mode):\n", " ci.config.chunk_size = 2048 if ci.config.clip_model_name == \"ViT-L-14/openai\" else 1024\n", " ci.config.flavor_intermediate_count = 2048 if ci.config.clip_model_name == \"ViT-L-14/openai\" else 1024\n", " image = image.convert('RGB')\n", " if mode == 'best':\n", - " return ci.interrogate(image, max_flavors=int(best_max_flavors))\n", + " return ci.interrogate(image)\n", " elif mode == 'classic':\n", " return ci.interrogate_classic(image)\n", - " else:\n", - " return ci.interrogate_fast(image)\n" + " elif mode == 'fast':\n", + " return ci.interrogate_fast(image)\n", + " elif mode == 'negative':\n", + " return ci.interrogate_negative(image)\n", + " " ] }, { @@ -156,22 +158,36 @@ "source": [ "#@title Image to prompt! 🖼️ -> 📝\n", " \n", - "inputs = [\n", - " gr.inputs.Image(type='pil'),\n", - " gr.Radio(['best', 'fast'], label='', value='best'),\n", - " gr.Number(value=16, label='best mode max flavors'),\n", - "]\n", - "outputs = [\n", - " gr.outputs.Textbox(label=\"Output\"),\n", - "]\n", - "\n", - "io = gr.Interface(\n", - " inference, \n", - " inputs, \n", - " outputs, \n", - " allow_flagging=False,\n", - ")\n", - "io.launch(debug=False)\n" + "def prompt_tab():\n", + " with gr.Column():\n", + " with gr.Row():\n", + " image = gr.Image(type='pil', label=\"Image\")\n", + " with gr.Column():\n", + " mode = gr.Radio(['best', 'fast', 'classic', 'negative'], label='Mode', value='best')\n", + " prompt = gr.Textbox(label=\"Prompt\")\n", + " button = gr.Button(\"Generate prompt\")\n", + " button.click(image_to_prompt, inputs=[image, mode], outputs=prompt)\n", + "\n", + "def analyze_tab():\n", + " with gr.Column():\n", + " with gr.Row():\n", + " image = gr.Image(type='pil', label=\"Image\")\n", + " with gr.Row():\n", + " medium = gr.Label(label=\"Medium\", num_top_classes=5)\n", + " artist = gr.Label(label=\"Artist\", num_top_classes=5) \n", + " movement = gr.Label(label=\"Movement\", num_top_classes=5)\n", + " trending = gr.Label(label=\"Trending\", num_top_classes=5)\n", + " flavor = gr.Label(label=\"Flavor\", num_top_classes=5)\n", + " button = gr.Button(\"Analyze\")\n", + " button.click(image_analysis, inputs=image, outputs=[medium, artist, movement, trending, flavor])\n", + "\n", + "with gr.Blocks() as ui:\n", + " with gr.Tab(\"Prompt\"):\n", + " prompt_tab()\n", + " with gr.Tab(\"Analyze\"):\n", + " analyze_tab()\n", + "\n", + "ui.launch(show_api=False, debug=False)\n" ] }, { @@ -198,10 +214,9 @@ "from tqdm import tqdm\n", "\n", "folder_path = \"/content/my_images\" #@param {type:\"string\"}\n", - "prompt_mode = 'best' #@param [\"best\",\"fast\"]\n", + "prompt_mode = 'best' #@param [\"best\",\"fast\",\"classic\",\"negative\"]\n", "output_mode = 'rename' #@param [\"desc.csv\",\"rename\"]\n", "max_filename_len = 128 #@param {type:\"integer\"}\n", - "best_max_flavors = 16 #@param {type:\"integer\"}\n", "\n", "\n", "def sanitize_for_filename(prompt: str, max_len: int) -> str:\n", @@ -218,7 +233,7 @@ " clear_output(wait=True)\n", "\n", " image = Image.open(os.path.join(folder_path, file)).convert('RGB')\n", - " prompt = inference(image, prompt_mode, best_max_flavors=best_max_flavors)\n", + " prompt = image_to_prompt(image, prompt_mode)\n", " prompts.append(prompt)\n", "\n", " print(prompt)\n", diff --git a/clip_interrogator/__init__.py b/clip_interrogator/__init__.py index 7e92186..04bd77c 100644 --- a/clip_interrogator/__init__.py +++ b/clip_interrogator/__init__.py @@ -1,4 +1,4 @@ from .clip_interrogator import Interrogator, Config -__version__ = '0.3.5' +__version__ = '0.4.0' __author__ = 'pharmapsychotic' \ No newline at end of file diff --git a/clip_interrogator/clip_interrogator.py b/clip_interrogator/clip_interrogator.py index a634436..db212d2 100644 --- a/clip_interrogator/clip_interrogator.py +++ b/clip_interrogator/clip_interrogator.py @@ -5,6 +5,7 @@ import numpy as np import open_clip import os import pickle +import requests import time import torch @@ -21,6 +22,23 @@ BLIP_MODELS = { 'large': 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth' } +CACHE_URLS_VITL = [ + 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_artists.pkl', + 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_flavors.pkl', + 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_mediums.pkl', + 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_movements.pkl', + 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_trendings.pkl', +] + +CACHE_URLS_VITH = [ + 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_artists.pkl', + 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_flavors.pkl', + 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_mediums.pkl', + 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_movements.pkl', + 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_trendings.pkl', +] + + @dataclass class Config: # models can optionally be passed in directly @@ -40,13 +58,15 @@ class Config: clip_model_path: str = None # interrogator settings - cache_path: str = 'cache' - chunk_size: int = 2048 + cache_path: str = 'cache' # path to store cached text embeddings + download_cache: bool = True # when true, cached embeds are downloaded from huggingface + chunk_size: int = 2048 # batch size for CLIP, use smaller for lower VRAM data_path: str = os.path.join(os.path.dirname(__file__), 'data') device: str = ("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu") flavor_intermediate_count: int = 2048 quiet: bool = False # when quiet progress bars are not shown + class Interrogator(): def __init__(self, config: Config): self.config = config @@ -72,6 +92,21 @@ class Interrogator(): self.load_clip_model() + def download_cache(self, clip_model_name: str): + if clip_model_name == 'ViT-L-14/openai': + cache_urls = CACHE_URLS_VITL + elif clip_model_name == 'ViT-H-14/laion2b_s32b_b79k': + cache_urls = CACHE_URLS_VITH + else: + # text embeddings will be precomputed and cached locally + return + + os.makedirs(self.config.cache_path, exist_ok=True) + for url in cache_urls: + filepath = os.path.join(self.config.cache_path, url.split('/')[-1]) + if not os.path.exists(filepath): + _download_file(url, filepath, quiet=self.config.quiet) + def load_clip_model(self): start_time = time.time() config = self.config @@ -105,16 +140,58 @@ class Interrogator(): artists = [f"by {a}" for a in raw_artists] artists.extend([f"inspired by {a}" for a in raw_artists]) + self.download_cache(config.clip_model_name) + self.artists = LabelTable(artists, "artists", self.clip_model, self.tokenize, config) self.flavors = LabelTable(_load_list(config.data_path, 'flavors.txt'), "flavors", self.clip_model, self.tokenize, config) self.mediums = LabelTable(_load_list(config.data_path, 'mediums.txt'), "mediums", self.clip_model, self.tokenize, config) self.movements = LabelTable(_load_list(config.data_path, 'movements.txt'), "movements", self.clip_model, self.tokenize, config) self.trendings = LabelTable(trending_list, "trendings", self.clip_model, self.tokenize, config) + self.negative = LabelTable(_load_list(config.data_path, 'negative.txt'), "negative", self.clip_model, self.tokenize, config) end_time = time.time() if not config.quiet: print(f"Loaded CLIP model and data in {end_time-start_time:.2f} seconds.") + def chain( + self, + image_features: torch.Tensor, + phrases: List[str], + best_prompt: str="", + best_sim: float=0, + max_count: int=32, + desc="Chaining", + reverse: bool=False + ) -> str: + phrases = set(phrases) + if not best_prompt: + best_prompt = self.rank_top(image_features, [f for f in phrases], reverse=reverse) + best_sim = self.similarity(image_features, best_prompt) + phrases.remove(best_prompt) + + def check(addition: str) -> bool: + nonlocal best_prompt, best_sim + prompt = best_prompt + ", " + addition + sim = self.similarity(image_features, prompt) + if reverse: + sim = -sim + if sim > best_sim: + best_sim = sim + best_prompt = prompt + return True + return False + + for _ in tqdm(range(max_count), desc=desc, disable=self.config.quiet): + best = self.rank_top(image_features, [f"{best_prompt}, {f}" for f in phrases], reverse=reverse) + flave = best[len(best_prompt)+2:] + if not check(flave): + break + if _prompt_at_max_len(best_prompt, self.tokenize): + break + phrases.remove(flave) + + return best_prompt + def generate_caption(self, pil_image: Image) -> str: if self.config.blip_offload: self.blip_model = self.blip_model.to(self.device) @@ -145,6 +222,8 @@ class Interrogator(): return image_features def interrogate_classic(self, image: Image, max_flavors: int=3) -> str: + """Classic mode creates a prompt in a standard format first describing the image, + then listing the artist, trending, movement, and flavor text modifiers.""" caption = self.generate_caption(image) image_features = self.image_to_features(image) @@ -162,69 +241,43 @@ class Interrogator(): return _truncate_to_fit(prompt, self.tokenize) def interrogate_fast(self, image: Image, max_flavors: int = 32) -> str: + """Fast mode simply adds the top ranked terms after a caption. It generally results in + better similarity between generated prompt and image than classic mode, but the prompts + are less readable.""" caption = self.generate_caption(image) image_features = self.image_to_features(image) merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config) tops = merged.rank(image_features, max_flavors) return _truncate_to_fit(caption + ", " + ", ".join(tops), self.tokenize) + def interrogate_negative(self, image: Image, max_flavors: int = 32) -> str: + """Negative mode chains together the most dissimilar terms to the image. It can be used + to help build a negative prompt to pair with the regular positive prompt and often + improve the results of generated images particularly with Stable Diffusion 2.""" + image_features = self.image_to_features(image) + flaves = self.flavors.rank(image_features, self.config.flavor_intermediate_count, reverse=True) + flaves = flaves + self.negative.labels + return self.chain(image_features, flaves, max_count=max_flavors, reverse=True, desc="Negative chain") + def interrogate(self, image: Image, max_flavors: int=32) -> str: caption = self.generate_caption(image) image_features = self.image_to_features(image) - flaves = self.flavors.rank(image_features, self.config.flavor_intermediate_count) - best_medium = self.mediums.rank(image_features, 1)[0] - best_artist = self.artists.rank(image_features, 1)[0] - best_trending = self.trendings.rank(image_features, 1)[0] - best_movement = self.movements.rank(image_features, 1)[0] + merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config) + flaves = merged.rank(image_features, self.config.flavor_intermediate_count) best_prompt = caption best_sim = self.similarity(image_features, best_prompt) - def check(addition: str) -> bool: - nonlocal best_prompt, best_sim - prompt = best_prompt + ", " + addition - sim = self.similarity(image_features, prompt) - if sim > best_sim: - best_sim = sim - best_prompt = prompt - return True - return False - - def check_multi_batch(opts: List[str]): - nonlocal best_prompt, best_sim - prompts = [] - for i in range(2**len(opts)): - prompt = best_prompt - for bit in range(len(opts)): - if i & (1 << bit): - prompt += ", " + opts[bit] - prompts.append(prompt) - - t = LabelTable(prompts, None, self.clip_model, self.tokenize, self.config) - best_prompt = t.rank(image_features, 1)[0] - best_sim = self.similarity(image_features, best_prompt) - - check_multi_batch([best_medium, best_artist, best_trending, best_movement]) - - extended_flavors = set(flaves) - for _ in tqdm(range(max_flavors), desc="Flavor chain", disable=self.config.quiet): - best = self.rank_top(image_features, [f"{best_prompt}, {f}" for f in extended_flavors]) - flave = best[len(best_prompt)+2:] - if not check(flave): - break - if _prompt_at_max_len(best_prompt, self.tokenize): - break - extended_flavors.remove(flave) - - return best_prompt + return self.chain(image_features, flaves, best_prompt, best_sim, max_count=max_flavors, desc="Flavor chain") - def rank_top(self, image_features: torch.Tensor, text_array: List[str]) -> str: + def rank_top(self, image_features: torch.Tensor, text_array: List[str], reverse: bool=False) -> str: text_tokens = self.tokenize([text for text in text_array]).to(self.device) with torch.no_grad(), torch.cuda.amp.autocast(): text_features = self.clip_model.encode_text(text_tokens) - text_features /= text_features.norm(dim=-1, keepdim=True) similarity = text_features @ image_features.T + if reverse: + similarity = -similarity return text_array[similarity.argmax().item()] def similarity(self, image_features: torch.Tensor, text: str) -> float: @@ -235,6 +288,14 @@ class Interrogator(): similarity = text_features @ image_features.T return similarity[0][0].item() + def similarities(self, image_features: torch.Tensor, text_array: List[str]) -> List[float]: + text_tokens = self.tokenize([text for text in text_array]).to(self.device) + with torch.no_grad(), torch.cuda.amp.autocast(): + text_features = self.clip_model.encode_text(text_tokens) + text_features /= text_features.norm(dim=-1, keepdim=True) + similarity = text_features @ image_features.T + return similarity.T[0].tolist() + class LabelTable(): def __init__(self, labels:List[str], desc:str, clip_model, tokenize, config: Config): @@ -286,17 +347,19 @@ class LabelTable(): if self.device == 'cpu' or self.device == torch.device('cpu'): self.embeds = [e.astype(np.float32) for e in self.embeds] - def _rank(self, image_features: torch.Tensor, text_embeds: torch.Tensor, top_count: int=1) -> str: + def _rank(self, image_features: torch.Tensor, text_embeds: torch.Tensor, top_count: int=1, reverse: bool=False) -> str: top_count = min(top_count, len(text_embeds)) text_embeds = torch.stack([torch.from_numpy(t) for t in text_embeds]).to(self.device) with torch.cuda.amp.autocast(): similarity = image_features @ text_embeds.T + if reverse: + similarity = -similarity _, top_labels = similarity.float().cpu().topk(top_count, dim=-1) return [top_labels[0][i].numpy() for i in range(top_count)] - def rank(self, image_features: torch.Tensor, top_count: int=1) -> List[str]: + def rank(self, image_features: torch.Tensor, top_count: int=1, reverse: bool=False) -> List[str]: if len(self.labels) <= self.chunk_size: - tops = self._rank(image_features, self.embeds, top_count=top_count) + tops = self._rank(image_features, self.embeds, top_count=top_count, reverse=reverse) return [self.labels[i] for i in tops] num_chunks = int(math.ceil(len(self.labels)/self.chunk_size)) @@ -306,7 +369,7 @@ class LabelTable(): for chunk_idx in tqdm(range(num_chunks), disable=self.config.quiet): start = chunk_idx*self.chunk_size stop = min(start+self.chunk_size, len(self.embeds)) - tops = self._rank(image_features, self.embeds[start:stop], top_count=keep_per_chunk) + tops = self._rank(image_features, self.embeds[start:stop], top_count=keep_per_chunk, reverse=reverse) top_labels.extend([self.labels[start+i] for i in tops]) top_embeds.extend([self.embeds[start+i] for i in tops]) @@ -314,6 +377,18 @@ class LabelTable(): return [top_labels[i] for i in tops] +def _download_file(url: str, filepath: str, chunk_size: int = 64*1024, quiet: bool = False): + r = requests.get(url, stream=True) + file_size = int(r.headers.get("Content-Length", 0)) + filename = url.split("/")[-1] + progress = tqdm(total=file_size, unit="B", unit_scale=True, desc=filename, disable=quiet) + with open(filepath, "wb") as f: + for chunk in r.iter_content(chunk_size=chunk_size): + if chunk: + f.write(chunk) + progress.update(len(chunk)) + progress.close() + def _load_list(data_path: str, filename: str) -> List[str]: with open(os.path.join(data_path, filename), 'r', encoding='utf-8', errors='replace') as f: items = [line.strip() for line in f.readlines()] diff --git a/clip_interrogator/data/negative.txt b/clip_interrogator/data/negative.txt new file mode 100644 index 0000000..7d39d47 --- /dev/null +++ b/clip_interrogator/data/negative.txt @@ -0,0 +1,41 @@ +3d +b&w +bad anatomy +bad art +blur +blurry +cartoon +childish +close up +deformed +disconnected limbs +disfigured +disgusting +extra limb +extra limbs +floating limbs +grain +illustration +kitsch +long body +long neck +low quality +low-res +malformed hands +mangled +missing limb +mutated +mutation +mutilated +noisy +old +out of focus +over saturation +oversaturated +poorly drawn +poorly drawn face +poorly drawn hands +render +surreal +ugly +weird colors \ No newline at end of file diff --git a/run_gradio.py b/run_gradio.py index 9fc685f..c8f1597 100755 --- a/run_gradio.py +++ b/run_gradio.py @@ -3,7 +3,7 @@ import argparse import gradio as gr import open_clip import torch -from clip_interrogator import Interrogator, Config +from clip_interrogator import Config, Interrogator parser = argparse.ArgumentParser() parser.add_argument('-s', '--share', action='store_true', help='Create a public link') @@ -14,40 +14,76 @@ if not torch.cuda.is_available(): ci = Interrogator(Config(cache_path="cache", clip_model_path="cache")) -def inference(image, mode, clip_model_name, blip_max_length, blip_num_beams): +def image_analysis(image, clip_model_name): + if clip_model_name != ci.config.clip_model_name: + ci.config.clip_model_name = clip_model_name + ci.load_clip_model() + + image = image.convert('RGB') + image_features = ci.image_to_features(image) + + top_mediums = ci.mediums.rank(image_features, 5) + top_artists = ci.artists.rank(image_features, 5) + top_movements = ci.movements.rank(image_features, 5) + top_trendings = ci.trendings.rank(image_features, 5) + top_flavors = ci.flavors.rank(image_features, 5) + + medium_ranks = {medium: sim for medium, sim in zip(top_mediums, ci.similarities(image_features, top_mediums))} + artist_ranks = {artist: sim for artist, sim in zip(top_artists, ci.similarities(image_features, top_artists))} + movement_ranks = {movement: sim for movement, sim in zip(top_movements, ci.similarities(image_features, top_movements))} + trending_ranks = {trending: sim for trending, sim in zip(top_trendings, ci.similarities(image_features, top_trendings))} + flavor_ranks = {flavor: sim for flavor, sim in zip(top_flavors, ci.similarities(image_features, top_flavors))} + + return medium_ranks, artist_ranks, movement_ranks, trending_ranks, flavor_ranks + +def image_to_prompt(image, mode, clip_model_name): if clip_model_name != ci.config.clip_model_name: ci.config.clip_model_name = clip_model_name ci.load_clip_model() - ci.config.blip_max_length = int(blip_max_length) - ci.config.blip_num_beams = int(blip_num_beams) image = image.convert('RGB') if mode == 'best': return ci.interrogate(image) elif mode == 'classic': return ci.interrogate_classic(image) - else: + elif mode == 'fast': return ci.interrogate_fast(image) + elif mode == 'negative': + return ci.interrogate_negative(image) + models = ['/'.join(x) for x in open_clip.list_pretrained()] -inputs = [ - gr.inputs.Image(type='pil'), - gr.Radio(['best', 'classic', 'fast'], label='Mode', value='best'), - gr.Dropdown(models, value='ViT-L-14/openai', label='CLIP Model'), - gr.Number(value=32, label='Caption Max Length'), - gr.Number(value=64, label='Caption Num Beams'), -] -outputs = [ - gr.outputs.Textbox(label="Output"), -] - -io = gr.Interface( - inference, - inputs, - outputs, - title="🕵️♂️ CLIP Interrogator 🕵️♂️", - allow_flagging=False, -) -io.launch(share=args.share) +def prompt_tab(): + with gr.Column(): + with gr.Row(): + image = gr.Image(type='pil', label="Image") + with gr.Column(): + mode = gr.Radio(['best', 'fast', 'classic', 'negative'], label='Mode', value='best') + model = gr.Dropdown(models, value='ViT-L-14/openai', label='CLIP Model') + prompt = gr.Textbox(label="Prompt") + button = gr.Button("Generate prompt") + button.click(image_to_prompt, inputs=[image, mode, model], outputs=prompt) + +def analyze_tab(): + with gr.Column(): + with gr.Row(): + image = gr.Image(type='pil', label="Image") + model = gr.Dropdown(models, value='ViT-L-14/openai', label='CLIP Model') + with gr.Row(): + medium = gr.Label(label="Medium", num_top_classes=5) + artist = gr.Label(label="Artist", num_top_classes=5) + movement = gr.Label(label="Movement", num_top_classes=5) + trending = gr.Label(label="Trending", num_top_classes=5) + flavor = gr.Label(label="Flavor", num_top_classes=5) + button = gr.Button("Analyze") + button.click(image_analysis, inputs=[image, model], outputs=[medium, artist, movement, trending, flavor]) + +with gr.Blocks() as ui: + gr.Markdown("#