{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "3jm8RYrLqvzz"
},
"source": [
"# CLIP Interrogator 2.3 [negative prompt experiment!]\n",
"\n",
"This experimental version of CLIP Interrogator supports finding good \"negative\" prompts for Stable Diffusion 2. Note this is very *WIP* and more work needs to be done building out the dataset to support this (and perhaps a reverse BLIP) so for many images it may struggle to find a well aligned negative prompt. Alignments are displayed to help see how well it did.\n",
"\n",
"
\n",
"\n",
"For Stable Diffusion 1.X choose the **ViT-L** model and for Stable Diffusion 2.0+ choose the **ViT-H** CLIP Model.\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"cellView": "form",
"id": "aP9FjmWxtLKJ"
},
"outputs": [],
"source": [
"#@title Check GPU\n",
"!nvidia-smi -L"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"cellView": "form",
"id": "xpPKQR40qvz2"
},
"outputs": [],
"source": [
"#@title Setup\n",
"import os, subprocess\n",
"\n",
"def setup():\n",
" install_cmds = [\n",
" ['pip', 'install', 'gradio'],\n",
" ['pip', 'install', 'open_clip_torch'],\n",
" ['pip', 'install', '-e', 'git+https://github.com/pharmapsychotic/BLIP.git@lib#egg=blip'],\n",
" ['git', 'clone', '-b', 'negative', 'https://github.com/pharmapsychotic/clip-interrogator.git']\n",
" ]\n",
" for cmd in install_cmds:\n",
" print(subprocess.run(cmd, stdout=subprocess.PIPE).stdout.decode('utf-8'))\n",
"\n",
"setup()\n",
"\n",
"\n",
"clip_model_name = 'ViT-H-14/laion2b_s32b_b79k' #@param [\"ViT-L-14/openai\", \"ViT-H-14/laion2b_s32b_b79k\"]\n",
"\n",
"\n",
"print(\"Download preprocessed cache files...\")\n",
"CACHE_URLS = [\n",
" 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_artists.pkl',\n",
" 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_flavors.pkl',\n",
" 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_mediums.pkl',\n",
" 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_movements.pkl',\n",
" 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_trendings.pkl',\n",
"] if clip_model_name == 'ViT-L-14/openai' else [\n",
" 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_artists.pkl',\n",
" 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_flavors.pkl',\n",
" 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_mediums.pkl',\n",
" 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_movements.pkl',\n",
" 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_trendings.pkl',\n",
"]\n",
"os.makedirs('cache', exist_ok=True)\n",
"for url in CACHE_URLS:\n",
" print(subprocess.run(['wget', url, '-P', 'cache'], stdout=subprocess.PIPE).stdout.decode('utf-8'))\n",
"\n",
"\n",
"import gradio as gr\n",
"from clip_interrogator import Config, Interrogator\n",
"\n",
"config = Config()\n",
"config.blip_num_beams = 64\n",
"config.blip_offload = False\n",
"config.clip_model_name = clip_model_name\n",
"ci = Interrogator(config)\n",
"\n",
"def inference(image, mode):\n",
" ci.config.chunk_size = 2048 if ci.config.clip_model_name == \"ViT-L-14/openai\" else 1024\n",
" ci.config.flavor_intermediate_count = 2048 if ci.config.clip_model_name == \"ViT-L-14/openai\" else 1024\n",
" image = image.convert('RGB')\n",
" prompt = \"\"\n",
" if mode == 'best':\n",
" prompt = ci.interrogate(image)\n",
" elif mode == 'classic':\n",
" prompt = ci.interrogate_classic(image)\n",
" elif mode == 'fast':\n",
" prompt = ci.interrogate_fast(image)\n",
" elif mode == 'negative':\n",
" image_features = ci.image_to_features(image)\n",
" flaves = ci.flavors.rank(image_features, ci.config.flavor_intermediate_count, reverse=True)\n",
" flaves = flaves + ci.negative.labels\n",
" prompt = ci.chain(image_features, flaves, max_count=32, reverse=True, desc=\"Negative chain\")\n",
" sim = ci.similarity(ci.image_to_features(image), prompt)\n",
" return prompt, sim\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"cellView": "form",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 677
},
"id": "Pf6qkFG6MPRj",
"outputId": "8d542b56-8be7-453d-bf27-d0540a774c7d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Colab notebook detected. To show errors in colab notebook, set `debug=True` in `launch()`\n",
"\n",
"Using Embedded Colab Mode (NEW). If you have issues, please use share=True and file an issue at https://github.com/gradio-app/gradio/\n",
"Note: opening the browser inspector may crash Embedded Colab Mode.\n",
"\n",
"To create a public link, set `share=True` in `launch()`.\n"
]
},
{
"data": {
"application/javascript": "(async (port, path, width, height, cache, element) => {\n if (!google.colab.kernel.accessAllowed && !cache) {\n return;\n }\n element.appendChild(document.createTextNode(''));\n const url = await google.colab.kernel.proxyPort(port, {cache});\n\n const external_link = document.createElement('div');\n external_link.innerHTML = `\n