diff --git a/clip_interrogator.ipynb b/clip_interrogator.ipynb
index 5819587..1dba6e6 100644
--- a/clip_interrogator.ipynb
+++ b/clip_interrogator.ipynb
@@ -6,11 +6,21 @@
"id": "3jm8RYrLqvzz"
},
"source": [
- "# CLIP Interrogator 2.1 ViT-H special edition!\n",
+ "# CLIP Interrogator 2.1 by [@pharmapsychotic](https://twitter.com/pharmapsychotic) \n",
+ "\n",
+ "
\n",
"\n",
"Want to figure out what a good prompt might be to create new images like an existing one? The CLIP Interrogator is here to get you answers!\n",
"\n",
- "This version is specialized for producing nice prompts for use with **[Stable Diffusion 2.0](https://stability.ai/blog/stable-diffusion-v2-release)** using the **ViT-H-14** OpenCLIP model!\n"
+ "
\n",
+ "\n",
+ "This version is specialized for producing nice prompts for use with Stable Diffusion and achieves higher alignment between generated text prompt and source image. You can try out the old [version 1](https://colab.research.google.com/github/pharmapsychotic/clip-interrogator/blob/v1/clip_interrogator.ipynb) to see how different CLIP models ranks terms. \n",
+ "\n",
+ "
\n",
+ "\n",
+ "If this notebook is helpful to you please consider buying me a coffee via [ko-fi](https://ko-fi.com/pharmapsychotic) or following me on [twitter](https://twitter.com/pharmapsychotic) for more cool Ai stuff. 🙂\n",
+ "\n",
+ "And if you're looking for more Ai art tools check out my [Ai generative art tools list](https://pharmapsychotic.com/tools.html).\n"
]
},
{
@@ -43,7 +53,7 @@
" ['pip', 'install', 'ftfy', 'gradio', 'regex', 'tqdm', 'transformers==4.21.2', 'timm', 'fairscale', 'requests'],\n",
" ['pip', 'install', 'open_clip_torch'],\n",
" ['pip', 'install', '-e', 'git+https://github.com/pharmapsychotic/BLIP.git@lib#egg=blip'],\n",
- " ['git', 'clone', '-b', 'open-clip', 'https://github.com/pharmapsychotic/clip-interrogator.git']\n",
+ " ['git', 'clone', 'https://github.com/pharmapsychotic/clip-interrogator.git']\n",
" ]\n",
" for cmd in install_cmds:\n",
" print(subprocess.run(cmd, stdout=subprocess.PIPE).stdout.decode('utf-8'))\n",
@@ -71,14 +81,17 @@
"from clip_interrogator import Config, Interrogator\n",
"\n",
"config = Config()\n",
- "config.blip_offload = True\n",
- "config.chunk_size = 2048\n",
- "config.flavor_intermediate_count = 512\n",
"config.blip_num_beams = 64\n",
+ "config.blip_offload = False\n",
+ "config.chunk_size = 2048\n",
+ "config.flavor_intermediate_count = 2048\n",
"\n",
"ci = Interrogator(config)\n",
"\n",
- "def inference(image, mode, best_max_flavors):\n",
+ "def inference(image, mode, clip_model_name, best_max_flavors=32):\n",
+ " if clip_model_name != ci.config.clip_model_name:\n",
+ " ci.config.clip_model_name = clip_model_name\n",
+ " ci.load_clip_model()\n",
" image = image.convert('RGB')\n",
" if mode == 'best':\n",
" return ci.interrogate(image, max_flavors=int(best_max_flavors))\n",
@@ -140,7 +153,8 @@
"inputs = [\n",
" gr.inputs.Image(type='pil'),\n",
" gr.Radio(['best', 'classic', 'fast'], label='', value='best'),\n",
- " gr.Number(value=4, label='best mode max flavors'),\n",
+ " gr.Dropdown([\"ViT-L-14/openai\", \"ViT-H-14/laion2b_s32b_b79k\"], value='ViT-L-14/openai', label='CLIP Model'),\n",
+ " gr.Number(value=16, label='best mode max flavors'),\n",
"]\n",
"outputs = [\n",
" gr.outputs.Textbox(label=\"Output\"),\n",
@@ -179,10 +193,10 @@
"from tqdm import tqdm\n",
"\n",
"folder_path = \"/content/my_images\" #@param {type:\"string\"}\n",
- "prompt_mode = 'best' #@param [\"best\",\"classic\", \"fast\"]\n",
+ "prompt_mode = 'best' #@param [\"best\",\"fast\"]\n",
"output_mode = 'rename' #@param [\"desc.csv\",\"rename\"]\n",
"max_filename_len = 128 #@param {type:\"integer\"}\n",
- "best_max_flavors = 4 #@param {type:\"integer\"}\n",
+ "best_max_flavors = 16 #@param {type:\"integer\"}\n",
"\n",
"\n",
"def sanitize_for_filename(prompt: str, max_len: int) -> str:\n",
@@ -242,7 +256,7 @@
"provenance": []
},
"kernelspec": {
- "display_name": "Python 3.8.10 ('venv': venv)",
+ "display_name": "Python 3.8.10 ('ci')",
"language": "python",
"name": "python3"
},
@@ -261,7 +275,7 @@
"orig_nbformat": 4,
"vscode": {
"interpreter": {
- "hash": "f7a8d9541664ade9cff251487a19c76f2dd1b4c864d158f07ee26d1b0fd5c9a1"
+ "hash": "90daa5087f97972f35e673cab20894a33c1e0ca77092ccdd163e60b53596983a"
}
}
},
diff --git a/predict.py b/predict.py
index a8bc923..3a7d0f9 100644
--- a/predict.py
+++ b/predict.py
@@ -21,17 +21,22 @@ class Predictor(BasePredictor):
image: Path = Input(description="Input image"),
clip_model_name: str = Input(
default="ViT-L-14/openai",
- choices=[
- "ViT-L-14/openai",
- "ViT-H-14/laion2b_s32b_b79k",
- ],
- description="Choose a clip model.",
+ choices=["ViT-L-14/openai", "ViT-H-14/laion2b_s32b_b79k"],
+ description="Choose ViT-L for Stable Diffusion 1, and ViT-H for Stable Diffusion 2",
+ ),
+ mode: str = Input(
+ default="best",
+ choices=["best", "fast"],
+ description="Prompt mode (best takes 10-20 seconds, fast takes 1-2 seconds).",
),
) -> str:
"""Run a single prediction on the model"""
image = Image.open(str(image)).convert("RGB")
self.switch_model(clip_model_name)
- return self.ci.interrogate(image)
+ if mode == "best":
+ return self.ci.interrogate(image)
+ else:
+ return self.ci.interrogate_fast(image)
def switch_model(self, clip_model_name: str):
if clip_model_name != self.ci.config.clip_model_name: