diff --git a/README.md b/README.md index 86e0502..1e4a08d 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,12 @@ Run Version 2 on Colab, HuggingFace, and Replicate!
+For **Stable Diffusion 2.0** prompting use the `ViT-H` version: + +[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pharmapsychotic/clip-interrogator/blob/open-clip/clip_interrogator.ipynb) [![Generic badge](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue.svg)](https://huggingface.co/spaces/fffiloni/CLIP-Interrogator-2) + +
+ Version 1 still available in Colab for comparing different CLIP models [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pharmapsychotic/clip-interrogator/blob/v1/clip_interrogator.ipynb) @@ -30,7 +36,6 @@ source ci_env/bin/activate Install with PIP ``` -pip install -e git+https://github.com/openai/CLIP.git@main#egg=clip pip install -e git+https://github.com/pharmapsychotic/BLIP.git@lib#egg=blip pip install clip-interrogator ``` @@ -40,6 +45,10 @@ You can then use it in your script from PIL import Image from clip_interrogator import Interrogator, Config image = Image.open(image_path).convert('RGB') -ci = Interrogator(Config(clip_model_name="ViT-L/14")) +ci = Interrogator(Config(clip_model_name="ViT-L-14/openai")) print(ci.interrogate(image)) ``` + +CLIP Interrogator uses OpenCLIP which supports many different pretrained CLIP models. For the best prompts for +Stable Diffusion 1.X use `ViT-L-14/openai` for clip_model_name. For Stable Diffusion 2.0 use `ViT-H-14/laion2b_s32b_b79k` + diff --git a/clip_interrogator.ipynb b/clip_interrogator.ipynb index 7329580..1dba6e6 100644 --- a/clip_interrogator.ipynb +++ b/clip_interrogator.ipynb @@ -46,12 +46,12 @@ "outputs": [], "source": [ "#@title Setup\n", - "import subprocess\n", + "import os, subprocess\n", "\n", "def setup():\n", " install_cmds = [\n", " ['pip', 'install', 'ftfy', 'gradio', 'regex', 'tqdm', 'transformers==4.21.2', 'timm', 'fairscale', 'requests'],\n", - " ['pip', 'install', '-e', 'git+https://github.com/openai/CLIP.git@main#egg=clip'],\n", + " ['pip', 'install', 'open_clip_torch'],\n", " ['pip', 'install', '-e', 'git+https://github.com/pharmapsychotic/BLIP.git@lib#egg=blip'],\n", " ['git', 'clone', 'https://github.com/pharmapsychotic/clip-interrogator.git']\n", " ]\n", @@ -60,20 +60,41 @@ "\n", "setup()\n", "\n", + "# download cache files\n", + "print(\"Download preprocessed cache files...\")\n", + "CACHE_URLS = [\n", + " 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_artists.pkl',\n", + " 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_flavors.pkl',\n", + " 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_mediums.pkl',\n", + " 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_movements.pkl',\n", + " 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_trendings.pkl',\n", + "]\n", + "os.makedirs('cache', exist_ok=True)\n", + "for url in CACHE_URLS:\n", + " print(subprocess.run(['wget', url, '-P', 'cache'], stdout=subprocess.PIPE).stdout.decode('utf-8'))\n", + "\n", "import sys\n", "sys.path.append('src/blip')\n", - "sys.path.append('src/clip')\n", "sys.path.append('clip-interrogator')\n", "\n", "import gradio as gr\n", "from clip_interrogator import Config, Interrogator\n", "\n", - "ci = Interrogator(Config())\n", + "config = Config()\n", + "config.blip_num_beams = 64\n", + "config.blip_offload = False\n", + "config.chunk_size = 2048\n", + "config.flavor_intermediate_count = 2048\n", + "\n", + "ci = Interrogator(config)\n", "\n", - "def inference(image, mode):\n", + "def inference(image, mode, clip_model_name, best_max_flavors=32):\n", + " if clip_model_name != ci.config.clip_model_name:\n", + " ci.config.clip_model_name = clip_model_name\n", + " ci.load_clip_model()\n", " image = image.convert('RGB')\n", " if mode == 'best':\n", - " return ci.interrogate(image)\n", + " return ci.interrogate(image, max_flavors=int(best_max_flavors))\n", " elif mode == 'classic':\n", " return ci.interrogate_classic(image)\n", " else:\n", @@ -132,6 +153,8 @@ "inputs = [\n", " gr.inputs.Image(type='pil'),\n", " gr.Radio(['best', 'classic', 'fast'], label='', value='best'),\n", + " gr.Dropdown([\"ViT-L-14/openai\", \"ViT-H-14/laion2b_s32b_b79k\"], value='ViT-L-14/openai', label='CLIP Model'),\n", + " gr.Number(value=16, label='best mode max flavors'),\n", "]\n", "outputs = [\n", " gr.outputs.Textbox(label=\"Output\"),\n", @@ -170,9 +193,10 @@ "from tqdm import tqdm\n", "\n", "folder_path = \"/content/my_images\" #@param {type:\"string\"}\n", - "prompt_mode = 'best' #@param [\"best\",\"classic\", \"fast\"]\n", + "prompt_mode = 'best' #@param [\"best\",\"fast\"]\n", "output_mode = 'rename' #@param [\"desc.csv\",\"rename\"]\n", "max_filename_len = 128 #@param {type:\"integer\"}\n", + "best_max_flavors = 16 #@param {type:\"integer\"}\n", "\n", "\n", "def sanitize_for_filename(prompt: str, max_len: int) -> str:\n", @@ -189,7 +213,7 @@ " clear_output(wait=True)\n", "\n", " image = Image.open(os.path.join(folder_path, file)).convert('RGB')\n", - " prompt = inference(image, prompt_mode)\n", + " prompt = inference(image, prompt_mode, best_max_flavors=best_max_flavors)\n", " prompts.append(prompt)\n", "\n", " print(prompt)\n", @@ -232,7 +256,7 @@ "provenance": [] }, "kernelspec": { - "display_name": "Python 3.9.5 ('venv': venv)", + "display_name": "Python 3.8.10 ('ci')", "language": "python", "name": "python3" }, @@ -246,12 +270,12 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.5" + "version": "3.8.10" }, "orig_nbformat": 4, "vscode": { "interpreter": { - "hash": "10f7ca63a88f18f789e6adaf7a045f1bcd3706c5534a32f168d622925241605d" + "hash": "90daa5087f97972f35e673cab20894a33c1e0ca77092ccdd163e60b53596983a" } } }, diff --git a/clip_interrogator/clip_interrogator.py b/clip_interrogator/clip_interrogator.py index fb26215..31094ba 100644 --- a/clip_interrogator/clip_interrogator.py +++ b/clip_interrogator/clip_interrogator.py @@ -1,10 +1,11 @@ -import clip import hashlib import inspect import math import numpy as np +import open_clip import os import pickle +import time import torch from dataclasses import dataclass @@ -28,9 +29,11 @@ class Config: blip_max_length: int = 32 blip_model_url: str = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth' blip_num_beams: int = 8 + blip_offload: bool = False # clip settings - clip_model_name: str = 'ViT-L/14' + clip_model_name: str = 'ViT-H-14/laion2b_s32b_b79k' + clip_model_path: str = None # interrogator settings cache_path: str = 'cache' @@ -64,14 +67,30 @@ class Interrogator(): else: self.blip_model = config.blip_model + self.load_clip_model() + + def load_clip_model(self): + start_time = time.time() + config = self.config + if config.clip_model is None: if not config.quiet: print("Loading CLIP model...") - self.clip_model, self.clip_preprocess = clip.load(config.clip_model_name, device=config.device) - self.clip_model.to(config.device).eval() + + clip_model_name, clip_model_pretrained_name = config.clip_model_name.split('/', 2) + self.clip_model, _, self.clip_preprocess = open_clip.create_model_and_transforms( + clip_model_name, + pretrained=clip_model_pretrained_name, + precision='fp16', + device=config.device, + jit=False, + cache_dir=config.clip_model_path + ) + self.clip_model.half().to(config.device).eval() else: self.clip_model = config.clip_model self.clip_preprocess = config.clip_preprocess + self.tokenize = open_clip.get_tokenizer(clip_model_name) sites = ['Artstation', 'behance', 'cg society', 'cgsociety', 'deviantart', 'dribble', 'flickr', 'instagram', 'pexels', 'pinterest', 'pixabay', 'pixiv', 'polycount', 'reddit', 'shutterstock', 'tumblr', 'unsplash', 'zbrush central'] trending_list = [site for site in sites] @@ -83,13 +102,19 @@ class Interrogator(): artists = [f"by {a}" for a in raw_artists] artists.extend([f"inspired by {a}" for a in raw_artists]) - self.artists = LabelTable(artists, "artists", self.clip_model, config) - self.flavors = LabelTable(_load_list(config.data_path, 'flavors.txt'), "flavors", self.clip_model, config) - self.mediums = LabelTable(_load_list(config.data_path, 'mediums.txt'), "mediums", self.clip_model, config) - self.movements = LabelTable(_load_list(config.data_path, 'movements.txt'), "movements", self.clip_model, config) - self.trendings = LabelTable(trending_list, "trendings", self.clip_model, config) + self.artists = LabelTable(artists, "artists", self.clip_model, self.tokenize, config) + self.flavors = LabelTable(_load_list(config.data_path, 'flavors.txt'), "flavors", self.clip_model, self.tokenize, config) + self.mediums = LabelTable(_load_list(config.data_path, 'mediums.txt'), "mediums", self.clip_model, self.tokenize, config) + self.movements = LabelTable(_load_list(config.data_path, 'movements.txt'), "movements", self.clip_model, self.tokenize, config) + self.trendings = LabelTable(trending_list, "trendings", self.clip_model, self.tokenize, config) + + end_time = time.time() + if not config.quiet: + print(f"Loaded CLIP model and data in {end_time-start_time:.2f} seconds.") def generate_caption(self, pil_image: Image) -> str: + if self.config.blip_offload: + self.blip_model = self.blip_model.to(self.device) size = self.config.blip_image_eval_size gpu_image = transforms.Compose([ transforms.Resize((size, size), interpolation=InterpolationMode.BICUBIC), @@ -105,13 +130,15 @@ class Interrogator(): max_length=self.config.blip_max_length, min_length=5 ) + if self.config.blip_offload: + self.blip_model = self.blip_model.to("cpu") return caption[0] def image_to_features(self, image: Image) -> torch.Tensor: images = self.clip_preprocess(image).unsqueeze(0).to(self.device) - with torch.no_grad(): - image_features = self.clip_model.encode_image(images).float() - image_features /= image_features.norm(dim=-1, keepdim=True) + with torch.no_grad(), torch.cuda.amp.autocast(): + image_features = self.clip_model.encode_image(images) + image_features /= image_features.norm(dim=-1, keepdim=True) return image_features def interrogate_classic(self, image: Image, max_flavors: int=3) -> str: @@ -129,14 +156,14 @@ class Interrogator(): else: prompt = f"{caption}, {medium} {artist}, {trending}, {movement}, {flaves}" - return _truncate_to_fit(prompt) + return _truncate_to_fit(prompt, self.tokenize) def interrogate_fast(self, image: Image, max_flavors: int = 32) -> str: caption = self.generate_caption(image) image_features = self.image_to_features(image) merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config) tops = merged.rank(image_features, max_flavors) - return _truncate_to_fit(caption + ", " + ", ".join(tops)) + return _truncate_to_fit(caption + ", " + ", ".join(tops), self.tokenize) def interrogate(self, image: Image, max_flavors: int=32) -> str: caption = self.generate_caption(image) @@ -171,7 +198,7 @@ class Interrogator(): prompt += ", " + opts[bit] prompts.append(prompt) - t = LabelTable(prompts, None, self.clip_model, self.config) + t = LabelTable(prompts, None, self.clip_model, self.tokenize, self.config) best_prompt = t.rank(image_features, 1)[0] best_sim = self.similarity(image_features, best_prompt) @@ -179,47 +206,41 @@ class Interrogator(): extended_flavors = set(flaves) for _ in tqdm(range(max_flavors), desc="Flavor chain", disable=self.config.quiet): - try: - best = self.rank_top(image_features, [f"{best_prompt}, {f}" for f in extended_flavors]) - flave = best[len(best_prompt)+2:] - if not check(flave): - break - extended_flavors.remove(flave) - except: - # exceeded max prompt length + best = self.rank_top(image_features, [f"{best_prompt}, {f}" for f in extended_flavors]) + flave = best[len(best_prompt)+2:] + if not check(flave): break + if _prompt_at_max_len(best_prompt, self.tokenize): + break + extended_flavors.remove(flave) return best_prompt - def rank_top(self, image_features, text_array: List[str]) -> str: - text_tokens = clip.tokenize([text for text in text_array]).to(self.device) - with torch.no_grad(): - text_features = self.clip_model.encode_text(text_tokens).float() - text_features /= text_features.norm(dim=-1, keepdim=True) - - similarity = torch.zeros((1, len(text_array)), device=self.device) - for i in range(image_features.shape[0]): - similarity += (image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1) + def rank_top(self, image_features: torch.Tensor, text_array: List[str]) -> str: + text_tokens = self.tokenize([text for text in text_array]).to(self.device) + with torch.no_grad(), torch.cuda.amp.autocast(): + text_features = self.clip_model.encode_text(text_tokens) + text_features /= text_features.norm(dim=-1, keepdim=True) + similarity = text_features @ image_features.T + return text_array[similarity.argmax().item()] - _, top_labels = similarity.cpu().topk(1, dim=-1) - return text_array[top_labels[0][0].numpy()] - - def similarity(self, image_features, text) -> np.float32: - text_tokens = clip.tokenize([text]).to(self.device) - with torch.no_grad(): - text_features = self.clip_model.encode_text(text_tokens).float() - text_features /= text_features.norm(dim=-1, keepdim=True) - similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T - return similarity[0][0] + def similarity(self, image_features: torch.Tensor, text: str) -> float: + text_tokens = self.tokenize([text]).to(self.device) + with torch.no_grad(), torch.cuda.amp.autocast(): + text_features = self.clip_model.encode_text(text_tokens) + text_features /= text_features.norm(dim=-1, keepdim=True) + similarity = text_features @ image_features.T + return similarity[0][0].item() class LabelTable(): - def __init__(self, labels:List[str], desc:str, clip_model, config: Config): + def __init__(self, labels:List[str], desc:str, clip_model, tokenize, config: Config): self.chunk_size = config.chunk_size self.config = config self.device = config.device self.embeds = [] self.labels = labels + self.tokenize = tokenize hash = hashlib.sha256(",".join(labels).encode()).hexdigest() @@ -239,11 +260,11 @@ class LabelTable(): self.embeds = [] chunks = np.array_split(self.labels, max(1, len(self.labels)/config.chunk_size)) for chunk in tqdm(chunks, desc=f"Preprocessing {desc}" if desc else None, disable=self.config.quiet): - text_tokens = clip.tokenize(chunk).to(self.device) - with torch.no_grad(): - text_features = clip_model.encode_text(text_tokens).float() - text_features /= text_features.norm(dim=-1, keepdim=True) - text_features = text_features.half().cpu().numpy() + text_tokens = self.tokenize(chunk).to(self.device) + with torch.no_grad(), torch.cuda.amp.autocast(): + text_features = clip_model.encode_text(text_tokens) + text_features /= text_features.norm(dim=-1, keepdim=True) + text_features = text_features.half().cpu().numpy() for i in range(text_features.shape[0]): self.embeds.append(text_features[i]) @@ -256,16 +277,15 @@ class LabelTable(): "model": config.clip_model_name }, f) - def _rank(self, image_features, text_embeds, top_count=1): + def _rank(self, image_features: torch.Tensor, text_embeds: torch.Tensor, top_count: int=1) -> str: top_count = min(top_count, len(text_embeds)) - similarity = torch.zeros((1, len(text_embeds))).to(self.device) - text_embeds = torch.stack([torch.from_numpy(t) for t in text_embeds]).float().to(self.device) - for i in range(image_features.shape[0]): - similarity += (image_features[i].unsqueeze(0) @ text_embeds.T).softmax(dim=-1) - _, top_labels = similarity.cpu().topk(top_count, dim=-1) + text_embeds = torch.stack([torch.from_numpy(t) for t in text_embeds]).to(self.device) + with torch.cuda.amp.autocast(): + similarity = image_features @ text_embeds.T + _, top_labels = similarity.float().cpu().topk(top_count, dim=-1) return [top_labels[0][i].numpy() for i in range(top_count)] - def rank(self, image_features, top_count=1) -> List[str]: + def rank(self, image_features: torch.Tensor, top_count: int=1) -> List[str]: if len(self.labels) <= self.chunk_size: tops = self._rank(image_features, self.embeds, top_count=top_count) return [self.labels[i] for i in tops] @@ -285,23 +305,27 @@ class LabelTable(): return [top_labels[i] for i in tops] -def _load_list(data_path, filename) -> List[str]: +def _load_list(data_path: str, filename: str) -> List[str]: with open(os.path.join(data_path, filename), 'r', encoding='utf-8', errors='replace') as f: items = [line.strip() for line in f.readlines()] return items def _merge_tables(tables: List[LabelTable], config: Config) -> LabelTable: - m = LabelTable([], None, None, config) + m = LabelTable([], None, None, None, config) for table in tables: m.labels.extend(table.labels) m.embeds.extend(table.embeds) return m -def _truncate_to_fit(text: str) -> str: - while True: - try: - _ = clip.tokenize([text]) - return text - except: - text = ",".join(text.split(",")[:-1]) - \ No newline at end of file +def _prompt_at_max_len(text: str, tokenize) -> bool: + tokens = tokenize([text]) + return tokens[0][-1] != 0 + +def _truncate_to_fit(text: str, tokenize) -> str: + parts = text.split(', ') + new_text = parts[0] + for part in parts[1:]: + if _prompt_at_max_len(new_text + part, tokenize): + break + new_text += ', ' + part + return new_text diff --git a/cog.yaml b/cog.yaml index c9a4bde..20c411a 100644 --- a/cog.yaml +++ b/cog.yaml @@ -1,6 +1,6 @@ build: gpu: true - cuda: "11.3" + cuda: "11.6" python_version: "3.8" system_packages: - "libgl1-mesa-glx" @@ -10,11 +10,12 @@ build: - "fairscale==0.4.12" - "transformers==4.21.2" - "ftfy==6.1.1" - - "torch==1.11.0 --extra-index-url=https://download.pytorch.org/whl/cu113" - - "torchvision==0.12.0 --extra-index-url=https://download.pytorch.org/whl/cu113" + - "torch==1.13.0 --extra-index-url=https://download.pytorch.org/whl/cu116" + - "torchvision==0.14.0 --extra-index-url=https://download.pytorch.org/whl/cu116" + - "open_clip_torch==2.7.0" + - "timm==0.4.12" + - "pycocoevalcap==1.2" run: - - pip install -e git+https://github.com/pharmapsychotic/BLIP.git@lib#egg=blip - - pip install -e git+https://github.com/openai/CLIP.git@main#egg=clip - - mkdir -p /root/.cache/clip && wget --output-document "/root/.cache/clip/ViT-L-14.pt" "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt" + - git clone https://github.com/salesforce/BLIP /root/blip predict: "predict.py:Predictor" diff --git a/predict.py b/predict.py index c085797..3a7d0f9 100644 --- a/predict.py +++ b/predict.py @@ -2,17 +2,43 @@ import sys from PIL import Image from cog import BasePredictor, Input, Path -sys.path.extend(["src/clip", "src/blip"]) +sys.path.append('/root/blip') from clip_interrogator import Interrogator, Config class Predictor(BasePredictor): def setup(self): - config = Config(device="cuda:0", clip_model_name='ViT-L/14') - self.ci = Interrogator(config) + self.ci = Interrogator(Config( + blip_model_url='cache/model_large_caption.pth', + clip_model_name="ViT-L-14/openai", + clip_model_path='cache', + device='cuda:0', + )) - def predict(self, image: Path = Input(description="Input image")) -> str: + def predict( + self, + image: Path = Input(description="Input image"), + clip_model_name: str = Input( + default="ViT-L-14/openai", + choices=["ViT-L-14/openai", "ViT-H-14/laion2b_s32b_b79k"], + description="Choose ViT-L for Stable Diffusion 1, and ViT-H for Stable Diffusion 2", + ), + mode: str = Input( + default="best", + choices=["best", "fast"], + description="Prompt mode (best takes 10-20 seconds, fast takes 1-2 seconds).", + ), + ) -> str: """Run a single prediction on the model""" image = Image.open(str(image)).convert("RGB") - return self.ci.interrogate(image) + self.switch_model(clip_model_name) + if mode == "best": + return self.ci.interrogate(image) + else: + return self.ci.interrogate_fast(image) + + def switch_model(self, clip_model_name: str): + if clip_model_name != self.ci.config.clip_model_name: + self.ci.config.clip_model_name = clip_model_name + self.ci.load_clip_model() diff --git a/requirements.txt b/requirements.txt index 8bfd9cd..735e90b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,5 @@ torch torchvision Pillow requests -tqdm \ No newline at end of file +tqdm +open_clip_torch \ No newline at end of file diff --git a/run_cli.py b/run_cli.py index 453abc6..efdd18b 100755 --- a/run_cli.py +++ b/run_cli.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 import argparse -import clip import csv +import open_clip import os import requests import torch @@ -19,7 +19,7 @@ def inference(ci, image, mode): def main(): parser = argparse.ArgumentParser() - parser.add_argument('-c', '--clip', default='ViT-L/14', help='name of CLIP model to use') + parser.add_argument('-c', '--clip', default='ViT-L-14/openai', help='name of CLIP model to use') parser.add_argument('-f', '--folder', help='path to folder of images') parser.add_argument('-i', '--image', help='image file or url') parser.add_argument('-m', '--mode', default='best', help='best, classic, or fast') @@ -34,9 +34,10 @@ def main(): exit(1) # validate clip model name - if args.clip not in clip.available_models(): + models = ['/'.join(x) for x in open_clip.list_pretrained()] + if args.clip not in models: print(f"Could not find CLIP model {args.clip}!") - print(f" available models: {clip.available_models()}") + print(f" available models: {models}") exit(1) # generate a nice prompt diff --git a/run_gradio.py b/run_gradio.py index 3d92498..cada8ce 100755 --- a/run_gradio.py +++ b/run_gradio.py @@ -1,14 +1,19 @@ #!/usr/bin/env python3 -import clip +import argparse import gradio as gr +import open_clip from clip_interrogator import Interrogator, Config -ci = Interrogator(Config()) +parser = argparse.ArgumentParser() +parser.add_argument('-s', '--share', action='store_true', help='Create a public link') +args = parser.parse_args() + +ci = Interrogator(Config(cache_path="cache", clip_model_path="cache")) def inference(image, mode, clip_model_name, blip_max_length, blip_num_beams): - global ci if clip_model_name != ci.config.clip_model_name: - ci = Interrogator(Config(clip_model_name=clip_model_name)) + ci.config.clip_model_name = clip_model_name + ci.load_clip_model() ci.config.blip_max_length = int(blip_max_length) ci.config.blip_num_beams = int(blip_num_beams) @@ -19,11 +24,13 @@ def inference(image, mode, clip_model_name, blip_max_length, blip_num_beams): return ci.interrogate_classic(image) else: return ci.interrogate_fast(image) - + +models = ['/'.join(x) for x in open_clip.list_pretrained()] + inputs = [ gr.inputs.Image(type='pil'), gr.Radio(['best', 'classic', 'fast'], label='Mode', value='best'), - gr.Dropdown(clip.available_models(), value='ViT-L/14', label='CLIP Model'), + gr.Dropdown(models, value='ViT-L-14/openai', label='CLIP Model'), gr.Number(value=32, label='Caption Max Length'), gr.Number(value=64, label='Caption Num Beams'), ] @@ -38,4 +45,5 @@ io = gr.Interface( title="🕵️‍♂️ CLIP Interrogator 🕵️‍♂️", allow_flagging=False, ) -io.launch() +io.launch(share=args.share) +