@@ -7,13 +7,24 @@
 "id": "3jm8RYrLqvzz"
 },
 "source": [
-"# CLIP Interrogator 2.3 [negative prompt experiment!] \n",
+"# CLIP Interrogator 2.3 by [@pharmapsychotic](https://twitter.com/pharmapsychotic) \n",
 "\n",
-"This experimental version of CLIP Interrogator supports finding good \"negative\" prompts for Stable Diffusion 2. Note this is very *WIP* and more work needs to be done building out the dataset to support this (and perhaps a reverse BLIP) so for many images it may struggle to find a well aligned negative prompt. Alignments are displayed to help see how well it did. \n",
+"Want to figure out what a good prompt might be to create new images like an existing one? The CLIP Interrogator is here to get you answers! \n",
 "\n",
 "<br>\n",
 "\n",
-"For Stable Diffusion 1.X choose the **ViT-L** model and for Stable Diffusion 2.0+ choose the **ViT-H** CLIP Model.\n"
+"For Stable Diffusion 1.X choose the **ViT-L** model and for Stable Diffusion 2.0+ choose the **ViT-H** CLIP Model.\n",
+"\n",
+"This version is specialized for producing nice prompts for use with Stable Diffusion and achieves higher alignment between generated text prompt and source image. You can try out the old [version 1](https://colab.research.google.com/github/pharmapsychotic/clip-interrogator/blob/v1/clip_interrogator.ipynb) to see how different CLIP models rank terms. \n",
+"\n",
+"You can also run this on HuggingFace and Replicate<br>\n",
+"[HuggingFace Spaces](https://huggingface.co/spaces/pharma/CLIP-Interrogator) [Replicate](https://replicate.com/pharmapsychotic/clip-interrogator)\n",
+"\n",
+"<br>\n",
+"\n",
+"If this notebook is helpful to you please consider buying me a coffee via [ko-fi](https://ko-fi.com/pharmapsychotic) or following me on [twitter](https://twitter.com/pharmapsychotic) for more cool Ai stuff. 🙂\n",
+"\n",
+"And if you're looking for more Ai art tools check out my [Ai generative art tools list](https://pharmapsychotic.com/tools.html).\n"
 ]
 },
 {
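The markdown cell added above pairs each Stable Diffusion family with one of the two CLIP model choices. As a minimal sketch of that mapping (the dictionary and its keys are illustrative; the model strings are the notebook's own `#@param` options):

```python
# Choose the CLIP model that matches the Stable Diffusion version you
# plan to prompt. The two strings are the notebook's #@param choices.
SD_TO_CLIP = {
    "sd-1.x": "ViT-L-14/openai",             # Stable Diffusion 1.X
    "sd-2.x": "ViT-H-14/laion2b_s32b_b79k",  # Stable Diffusion 2.0+
}

clip_model_name = SD_TO_CLIP["sd-1.x"]
```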
@@ -45,8 +56,7 @@
 "    install_cmds = [\n",
 "        ['pip', 'install', 'gradio'],\n",
 "        ['pip', 'install', 'open_clip_torch'],\n",
-"        ['pip', 'install', 'git+https://github.com/pharmapsychotic/BLIP.git'],\n",
-"        ['git', 'clone', '-b', 'negative', 'https://github.com/pharmapsychotic/clip-interrogator.git']\n",
+"        ['pip', 'install', 'clip-interrogator'],\n",
 "    ]\n",
 "    for cmd in install_cmds:\n",
 "        print(subprocess.run(cmd, stdout=subprocess.PIPE).stdout.decode('utf-8'))\n",
@@ -54,29 +64,8 @@
 "setup()\n",
 "\n",
 "\n",
-"clip_model_name = 'ViT-H-14/laion2b_s32b_b79k' #@param [\"ViT-L-14/openai\", \"ViT-H-14/laion2b_s32b_b79k\"]\n",
+"clip_model_name = 'ViT-L-14/openai' #@param [\"ViT-L-14/openai\", \"ViT-H-14/laion2b_s32b_b79k\"]\n",
 "\n",
-"\n",
-"print(\"Download preprocessed cache files...\")\n",
-"CACHE_URLS = [\n",
-"    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_artists.pkl',\n",
-"    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_flavors.pkl',\n",
-"    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_mediums.pkl',\n",
-"    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_movements.pkl',\n",
-"    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_trendings.pkl',\n",
-"] if clip_model_name == 'ViT-L-14/openai' else [\n",
-"    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_artists.pkl',\n",
-"    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_flavors.pkl',\n",
-"    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_mediums.pkl',\n",
-"    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_movements.pkl',\n",
-"    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_trendings.pkl',\n",
-"]\n",
-"os.makedirs('cache', exist_ok=True)\n",
-"for url in CACHE_URLS:\n",
-"    print(subprocess.run(['wget', url, '-P', 'cache'], stdout=subprocess.PIPE).stdout.decode('utf-8'))\n",
-"\n",
-"import sys\n",
-"sys.path.append('clip-interrogator')\n",
 "\n",
 "import gradio as gr\n",
 "from clip_interrogator import Config, Interrogator\n",
@@ -87,24 +76,37 @@
 "config.clip_model_name = clip_model_name\n",
 "ci = Interrogator(config)\n",
 "\n",
-"def inference(image, mode):\n",
+"def image_analysis(image):\n",
+"    image = image.convert('RGB')\n",
+"    image_features = ci.image_to_features(image)\n",
+"\n",
+"    top_mediums = ci.mediums.rank(image_features, 5)\n",
+"    top_artists = ci.artists.rank(image_features, 5)\n",
+"    top_movements = ci.movements.rank(image_features, 5)\n",
+"    top_trendings = ci.trendings.rank(image_features, 5)\n",
+"    top_flavors = ci.flavors.rank(image_features, 5)\n",
+"\n",
+"    medium_ranks = {medium: sim for medium, sim in zip(top_mediums, ci.similarities(image_features, top_mediums))}\n",
+"    artist_ranks = {artist: sim for artist, sim in zip(top_artists, ci.similarities(image_features, top_artists))}\n",
+"    movement_ranks = {movement: sim for movement, sim in zip(top_movements, ci.similarities(image_features, top_movements))}\n",
+"    trending_ranks = {trending: sim for trending, sim in zip(top_trendings, ci.similarities(image_features, top_trendings))}\n",
+"    flavor_ranks = {flavor: sim for flavor, sim in zip(top_flavors, ci.similarities(image_features, top_flavors))}\n",
+"    \n",
+"    return medium_ranks, artist_ranks, movement_ranks, trending_ranks, flavor_ranks\n",
+"\n",
+"def image_to_prompt(image, mode):\n",
 "    ci.config.chunk_size = 2048 if ci.config.clip_model_name == \"ViT-L-14/openai\" else 1024\n",
 "    ci.config.flavor_intermediate_count = 2048 if ci.config.clip_model_name == \"ViT-L-14/openai\" else 1024\n",
 "    image = image.convert('RGB')\n",
-"    prompt = \"\"\n",
 "    if mode == 'best':\n",
-"        prompt = ci.interrogate(image)\n",
+"        return ci.interrogate(image)\n",
 "    elif mode == 'classic':\n",
-"        prompt = ci.interrogate_classic(image)\n",
+"        return ci.interrogate_classic(image)\n",
 "    elif mode == 'fast':\n",
-"        prompt = ci.interrogate_fast(image)\n",
+"        return ci.interrogate_fast(image)\n",
 "    elif mode == 'negative':\n",
-"        image_features = ci.image_to_features(image)\n",
-"        flaves = ci.flavors.rank(image_features, ci.config.flavor_intermediate_count, reverse=True)\n",
-"        flaves = flaves + ci.negative.labels\n",
-"        prompt = ci.chain(image_features, flaves, max_count=32, reverse=True, desc=\"Negative chain\")\n",
-"    sim = ci.similarity(ci.image_to_features(image), prompt)\n",
-"    return prompt, sim"
+"        return ci.interrogate_negative(image)\n",
+"    "
 ]
 },
 {
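A quick way to sanity-check the two helpers this hunk introduces, `image_to_prompt` and `image_analysis` (the image path is a placeholder):

```python
from PIL import Image

img = Image.open("example.jpg")  # placeholder path

# Single prompt string in any of the four modes.
print(image_to_prompt(img, "fast"))

# Five dicts of {label: similarity}, one per category.
mediums, artists, movements, trendings, flavors = image_analysis(img)
print(mediums)
```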
@@ -156,22 +158,36 @@
 "source": [
 "#@title Image to prompt! 🖼️ -> 📝\n",
 " \n",
-"inputs = [\n",
-"    gr.inputs.Image(type='pil'),\n",
-"    gr.Radio(['best', 'fast', 'negative'], label='Mode', value='best'),\n",
-"]\n",
-"outputs = [\n",
-"    gr.outputs.Textbox(label=\"Output\"),\n",
-"    gr.Number(label=\"Alignment\"),\n",
-"]\n",
-"\n",
-"io = gr.Interface(\n",
-"    inference, \n",
-"    inputs, \n",
-"    outputs, \n",
-"    allow_flagging=False,\n",
-")\n",
-"io.launch(debug=False)\n"
+"def prompt_tab():\n",
+"    with gr.Column():\n",
+"        with gr.Row():\n",
+"            image = gr.Image(type='pil', label=\"Image\")\n",
+"            with gr.Column():\n",
+"                mode = gr.Radio(['best', 'fast', 'classic', 'negative'], label='Mode', value='best')\n",
+"        prompt = gr.Textbox(label=\"Prompt\")\n",
+"    button = gr.Button(\"Generate prompt\")\n",
+"    button.click(image_to_prompt, inputs=[image, mode], outputs=prompt)\n",
+"\n",
+"def analyze_tab():\n",
+"    with gr.Column():\n",
+"        with gr.Row():\n",
+"            image = gr.Image(type='pil', label=\"Image\")\n",
+"        with gr.Row():\n",
+"            medium = gr.Label(label=\"Medium\", num_top_classes=5)\n",
+"            artist = gr.Label(label=\"Artist\", num_top_classes=5)\n",
+"            movement = gr.Label(label=\"Movement\", num_top_classes=5)\n",
+"            trending = gr.Label(label=\"Trending\", num_top_classes=5)\n",
+"            flavor = gr.Label(label=\"Flavor\", num_top_classes=5)\n",
+"    button = gr.Button(\"Analyze\")\n",
+"    button.click(image_analysis, inputs=image, outputs=[medium, artist, movement, trending, flavor])\n",
+"\n",
+"with gr.Blocks() as ui:\n",
+"    with gr.Tab(\"Prompt\"):\n",
+"        prompt_tab()\n",
+"    with gr.Tab(\"Analyze\"):\n",
+"        analyze_tab()\n",
+"\n",
+"ui.launch(show_api=False, debug=False)\n"
 ]
 },
 {
@@ -198,10 +214,9 @@
 "from tqdm import tqdm\n",
 "\n",
 "folder_path = \"/content/my_images\" #@param {type:\"string\"}\n",
-"prompt_mode = 'best' #@param [\"best\",\"fast\"]\n",
+"prompt_mode = 'best' #@param [\"best\",\"fast\",\"classic\",\"negative\"]\n",
 "output_mode = 'rename' #@param [\"desc.csv\",\"rename\"]\n",
 "max_filename_len = 128 #@param {type:\"integer\"}\n",
-"best_max_flavors = 16 #@param {type:\"integer\"}\n",
 "\n",
 "\n",
 "def sanitize_for_filename(prompt: str, max_len: int) -> str:\n",
@@ -218,7 +233,7 @@
 "    clear_output(wait=True)\n",
 "\n",
 "    image = Image.open(os.path.join(folder_path, file)).convert('RGB')\n",
-"    prompt = inference(image, prompt_mode, best_max_flavors=best_max_flavors)\n",
+"    prompt = image_to_prompt(image, prompt_mode)\n",
 "    prompts.append(prompt)\n",
 "\n",
 "    print(prompt)\n",
@@ -261,7 +276,7 @@
 "provenance": []
 },
 "kernelspec": {
-"display_name": "ci",
+"display_name": "Python 3.7.15 ('py37')",
 "language": "python",
 "name": "python3"
 },
@@ -275,12 +290,12 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.8.10 (default, Nov 14 2022, 12:59:47) \n[GCC 9.4.0]"
+"version": "3.7.15 (default, Nov 24 2022, 18:44:54) [MSC v.1916 64 bit (AMD64)]"
 },
 "orig_nbformat": 4,
 "vscode": {
 "interpreter": {
-"hash": "90daa5087f97972f35e673cab20894a33c1e0ca77092ccdd163e60b53596983a"
+"hash": "1f51d5616d3bc2b87a82685314c5be1ec9a49b6e0cb1f707bfa2acb6c45f3e5f"
 }
 }
 }
 },