{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "3jm8RYrLqvzz" }, "source": [ "# CLIP Interrogator 2.1 ViTH special edition!\n", "\n", "
\n", "\n", "Want to figure out what a good prompt might be to create new images like an existing one? The CLIP Interrogator is here to get you answers!\n", "\n", "
\n", "\n", "This version is specialized for producing nice prompts for use with **Stable Diffusion 2.0** using the ViT-H-14 OpenCLIP model!" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "id": "aP9FjmWxtLKJ" }, "outputs": [], "source": [ "#@title Check GPU\n", "!nvidia-smi -L" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "cellView": "form", "id": "xpPKQR40qvz2" }, "outputs": [], "source": [
   "#@title Setup\n",
   "import os, subprocess, sys\n",
   "\n",
   "def setup():\n",
   "    \"\"\"Install the Python dependencies and fetch the clip-interrogator sources.\n",
   "\n",
   "    pip is invoked through the running interpreter so packages land in the kernel's\n",
   "    environment. Each command's stdout and stderr are printed together so failures are\n",
   "    visible in the notebook, and a failing command raises CalledProcessError instead of\n",
   "    being silently ignored.\n",
   "    \"\"\"\n",
   "    pip = [sys.executable, '-m', 'pip', 'install']\n",
   "    install_cmds = [\n",
   "        pip + ['ftfy', 'gradio', 'regex', 'tqdm', 'transformers==4.21.2', 'timm', 'fairscale', 'requests'],\n",
   "        pip + ['open_clip_torch'],\n",
   "        pip + ['-e', 'git+https://github.com/pharmapsychotic/BLIP.git@lib#egg=blip'],\n",
   "    ]\n",
   "    # `git clone` fails when the target directory already exists, so skip it on re-runs.\n",
   "    if not os.path.isdir('clip-interrogator'):\n",
   "        install_cmds.append(['git', 'clone', '-b', 'open-clip', 'https://github.com/pharmapsychotic/clip-interrogator.git'])\n",
   "    for cmd in install_cmds:\n",
   "        result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)\n",
   "        print(result.stdout.decode('utf-8', errors='replace'))\n",
   "        result.check_returncode()\n",
   "\n",
   "setup()\n",
   "\n",
   "# download cache files\n",
   "print(\"Download preprocessed cache files...\")\n",
   "CACHE_URLS = [\n",
   "    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_artists.pkl',\n",
   "    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_flavors.pkl',\n",
   "    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_mediums.pkl',\n",
   "    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_movements.pkl',\n",
   "    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_trendings.pkl',\n",
   "]\n",
   "os.makedirs('cache', exist_ok=True)\n",
   "for url in CACHE_URLS:\n",
   "    # -nc: skip files already in cache/ (otherwise wget saves re-runs as *.pkl.1, *.pkl.2, ...)\n",
   "    print(subprocess.run(['wget', '-nc', url, '-P', 'cache'], stdout=subprocess.PIPE).stdout.decode('utf-8'))\n",
   "\n",
   "# Make BLIP (installed editable above, expected under src/blip) and the cloned\n",
   "# clip-interrogator repo importable; the membership check keeps re-runs from piling up entries.\n",
   "for extra_path in ('src/blip', 'clip-interrogator'):\n",
   "    if extra_path not in sys.path:\n",
   "        sys.path.append(extra_path)\n",
   "\n",
   "import gradio as gr\n",
   "from clip_interrogator import Config, Interrogator\n",
   "\n",
   "ci = Interrogator(Config())\n",
   "\n",
   "def inference(image, mode):\n",
   "    \"\"\"Return a prompt for a PIL image using the selected interrogation mode.\n",
   "\n",
   "    mode: 'best' -> ci.interrogate, 'classic' -> ci.interrogate_classic;\n",
   "    any other value falls back to ci.interrogate_fast.\n",
   "    \"\"\"\n",
   "    image = image.convert('RGB')\n",
   "    if mode == 'best':\n",
   "        return ci.interrogate(image)\n",
   "    elif mode == 'classic':\n",
   "        return ci.interrogate_classic(image)\n",
   "    else:\n",
   "        return ci.interrogate_fast(image)\n"
  ]
 },
 { "cell_type": "code", "execution_count": 3, "metadata": { "cellView": "form", "colab": { "base_uri": "https://localhost:8080/", "height": 677 }, "id": "Pf6qkFG6MPRj", "outputId": "8d542b56-8be7-453d-bf27-d0540a774c7d" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Colab notebook detected. To show errors in colab notebook, set `debug=True` in `launch()`\n", "\n", "Using Embedded Colab Mode (NEW). If you have issues, please use share=True and file an issue at https://github.com/gradio-app/gradio/\n", "Note: opening the browser inspector may crash Embedded Colab Mode.\n", "\n", "To create a public link, set `share=True` in `launch()`.\n" ] }, { "data": { "application/javascript": "(async (port, path, width, height, cache, element) => {\n if (!google.colab.kernel.accessAllowed && !cache) {\n return;\n }\n element.appendChild(document.createTextNode(''));\n const url = await google.colab.kernel.proxyPort(port, {cache});\n\n const external_link = document.createElement('div');\n external_link.innerHTML = `\n
\n Running on \n https://localhost:${port}${path}\n \n
\n `;\n element.appendChild(external_link);\n\n const iframe = document.createElement('iframe');\n iframe.src = new URL(path, url).toString();\n iframe.height = height;\n iframe.allow = \"autoplay; camera; microphone; clipboard-read; clipboard-write;\"\n iframe.width = width;\n iframe.style.border = 0;\n element.appendChild(iframe);\n })(7860, \"/\", \"100%\", 500, false, window.element)", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "(, '', None)" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [
   "#@title Image to prompt! 🖼️ -> 📝\n",
   "\n",
   "# Top-level components (gr.Image / gr.Radio / gr.Textbox) rather than the deprecated\n",
   "# gr.inputs / gr.outputs namespaces, which newer gradio releases removed.\n",
   "inputs = [\n",
   "    gr.Image(type='pil'),\n",
   "    gr.Radio(['best', 'classic', 'fast'], label='', value='best'),\n",
   "]\n",
   "outputs = [\n",
   "    gr.Textbox(label=\"Output\"),\n",
   "]\n",
   "\n",
   "io = gr.Interface(\n",
   "    inference,\n",
   "    inputs,\n",
   "    outputs,\n",
   "    allow_flagging=\"never\",  # gradio expects a mode string here, not a bool\n",
   ")\n",
   "io.launch()\n"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": 3,
  "metadata": {
   "cellView": "form",
   "id": "OGmvkzITN4Hz"
  },
  "outputs": [],
  "source": [
   "#@title Batch process a folder of images 📁 -> 📝\n",
   "\n",
   "#@markdown This will generate prompts for every image in a folder and either save results \n",
   "#@markdown to a desc.csv file in the same folder or rename the files to contain their prompts.\n",
   "#@markdown The renamed files work well for [DreamBooth extension](https://github.com/d8ahazard/sd_dreambooth_extension)\n",
   "#@markdown in the [Stable Diffusion Web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui).\n",
   "#@markdown You can use the generated csv in the [Stable Diffusion Finetuning](https://colab.research.google.com/drive/1vrh_MUSaAMaC5tsLWDxkFILKJ790Z4Bl?usp=sharing)\n",
   "\n",
   "import csv\n",
   "import os\n",
   "from IPython.display import clear_output, display\n",
   "from PIL import Image\n",
   "from tqdm import tqdm\n",
   "\n",
   "folder_path = \"/content/my_images\" #@param {type:\"string\"}\n",
   "prompt_mode = 'best' #@param [\"best\",\"classic\", \"fast\"]\n",
   "output_mode = 'rename' #@param [\"desc.csv\",\"rename\"]\n",
   "max_filename_len = 128 #@param {type:\"integer\"}\n",
   "\n",
   "# Matched case-insensitively, so e.g. photo.JPG and photo.jpeg are picked up too.\n",
   "IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')\n",
   "\n",
   "\n",
   "def sanitize_for_filename(prompt: str, max_len: int) -> str:\n",
   "    \"\"\"Reduce a prompt to a filename stem of at most max_len - 4 characters.\n",
   "\n",
   "    Keeps only alphanumerics and the characters ',._-! ', trims surrounding whitespace,\n",
   "    and reserves 4 characters for the file extension. May return an empty string.\n",
   "    \"\"\"\n",
   "    name = \"\".join(c for c in prompt if (c.isalnum() or c in \",._-! \"))\n",
   "    # Re-strip after truncating so a cut never leaves a trailing space before the extension.\n",
   "    name = name.strip()[:(max_len-4)].rstrip()\n",
   "    return name\n",
   "\n",
   "ci.config.quiet = True\n",
   "\n",
   "files = [f for f in os.listdir(folder_path) if f.lower().endswith(IMAGE_EXTENSIONS)] if os.path.isdir(folder_path) else []\n",
   "prompts = []\n",
   "for idx, file in enumerate(tqdm(files, desc='Generating prompts')):\n",
   "    if idx > 0 and idx % 100 == 0:\n",
   "        clear_output(wait=True)\n",
   "\n",
   "    # Decode inside `with` so the file handle is closed before the file is renamed below.\n",
   "    with Image.open(os.path.join(folder_path, file)) as img:\n",
   "        image = img.convert('RGB')\n",
   "    prompt = inference(image, prompt_mode)\n",
   "    prompts.append(prompt)\n",
   "\n",
   "    print(prompt)\n",
   "    thumb = image.copy()\n",
   "    thumb.thumbnail([256, 256])\n",
   "    display(thumb)\n",
   "\n",
   "    if output_mode == 'rename':\n",
   "        # Fall back to the original stem if the prompt has no filename-safe characters.\n",
   "        name = sanitize_for_filename(prompt, max_filename_len) or os.path.splitext(file)[0]\n",
   "        ext = os.path.splitext(file)[1]\n",
   "        filename = name + ext\n",
   "        suffix = 1  # separate counter so the loop's idx is not clobbered\n",
   "        # A file whose name already matches its prompt (e.g. on a re-run) is left as-is.\n",
   "        while filename != file and os.path.exists(os.path.join(folder_path, filename)):\n",
   "            next_filename = f\"{name}_{suffix}{ext}\"\n",
   "            print(f'File {filename} already exists, trying {next_filename}...')\n",
   "            filename = next_filename\n",
   "            suffix += 1\n",
   "        os.rename(os.path.join(folder_path, file), os.path.join(folder_path, filename))\n",
   "\n",
   "if len(prompts):\n",
   "    if output_mode == 'desc.csv':\n",
   "        csv_path = os.path.join(folder_path, 'desc.csv')\n",
   "        with open(csv_path, 'w', encoding='utf-8', newline='') as f:\n",
   "            w = csv.writer(f, quoting=csv.QUOTE_MINIMAL)\n",
   "            w.writerow(['image', 'prompt'])\n",
   "            for file, prompt in zip(files, prompts):\n",
   "                w.writerow([file, prompt])\n",
   "\n",
   "        print(f\"\\n\\n\\n\\nGenerated {len(prompts)} prompts and saved to {csv_path}, enjoy!\")\n",
   "    else:\n",
   "        print(f\"\\n\\n\\n\\nGenerated {len(prompts)} prompts and renamed your files, enjoy!\")\n",
   "else:\n",
   "    print(f\"Sorry, I couldn't find any images in {folder_path}\")\n"
  ]
 }
 ],
 "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "provenance": [] }, "kernelspec": { "display_name": "Python 3.8.10 ('venv': venv)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" }, "orig_nbformat": 4, "vscode": { "interpreter": { "hash": "f7a8d9541664ade9cff251487a19c76f2dd1b4c864d158f07ee26d1b0fd5c9a1" } } },
 "nbformat": 4,
 "nbformat_minor": 0
}