{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "3jm8RYrLqvzz" }, "source": [ "# CLIP Interrogator 2.1 ViT-H special edition!\n", "\n", "Want to figure out what a good prompt might be to create new images like an existing one? The CLIP Interrogator is here to get you answers!\n", "\n", "This version is specialized for producing nice prompts for use with **[Stable Diffusion 2.0](https://stability.ai/blog/stable-diffusion-v2-release)** using the **ViT-H-14** OpenCLIP model!\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "id": "aP9FjmWxtLKJ" }, "outputs": [], "source": [ "#@title Check GPU\n", "!nvidia-smi -L" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "cellView": "form", "id": "xpPKQR40qvz2" }, "outputs": [], "source": [ "#@title Setup\n", "import os, subprocess\n", "\n", "def setup():\n", " install_cmds = [\n", " ['pip', 'install', 'ftfy', 'gradio', 'regex', 'tqdm', 'transformers==4.21.2', 'timm', 'fairscale', 'requests'],\n", " ['pip', 'install', 'open_clip_torch'],\n", " ['pip', 'install', '-e', 'git+https://github.com/pharmapsychotic/BLIP.git@lib#egg=blip'],\n", " ['git', 'clone', '-b', 'open-clip', 'https://github.com/pharmapsychotic/clip-interrogator.git']\n", " ]\n", " for cmd in install_cmds:\n", " print(subprocess.run(cmd, stdout=subprocess.PIPE).stdout.decode('utf-8'))\n", "\n", "setup()\n", "\n", "# download cache files\n", "print(\"Download preprocessed cache files...\")\n", "CACHE_URLS = [\n", " 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_artists.pkl',\n", " 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_flavors.pkl',\n", " 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_mediums.pkl',\n", " 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_movements.pkl',\n", " 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_trendings.pkl',\n", "]\n", "os.makedirs('cache', exist_ok=True)\n", "for url in CACHE_URLS:\n", " print(subprocess.run(['wget', url, '-P', 'cache'], stdout=subprocess.PIPE).stdout.decode('utf-8'))\n", "\n", "import sys\n", "sys.path.append('src/blip')\n", "sys.path.append('clip-interrogator')\n", "\n", "import gradio as gr\n", "from clip_interrogator import Config, Interrogator\n", "\n", "config = Config()\n", "config.blip_offload = True\n", "config.chunk_size = 2048\n", "config.flavor_intermediate_count = 512\n", "config.blip_num_beams = 64\n", "\n", "ci = Interrogator(config)\n", "\n", "def inference(image, mode, best_max_flavors):\n", " image = image.convert('RGB')\n", " if mode == 'best':\n", " return ci.interrogate(image, max_flavors=int(best_max_flavors))\n", " elif mode == 'classic':\n", " return ci.interrogate_classic(image)\n", " else:\n", " return ci.interrogate_fast(image)\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "cellView": "form", "colab": { "base_uri": "https://localhost:8080/", "height": 677 }, "id": "Pf6qkFG6MPRj", "outputId": "8d542b56-8be7-453d-bf27-d0540a774c7d" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Colab notebook detected. To show errors in colab notebook, set `debug=True` in `launch()`\n", "\n", "Using Embedded Colab Mode (NEW). If you have issues, please use share=True and file an issue at https://github.com/gradio-app/gradio/\n", "Note: opening the browser inspector may crash Embedded Colab Mode.\n", "\n", "To create a public link, set `share=True` in `launch()`.\n" ] }, { "data": { "application/javascript": "(async (port, path, width, height, cache, element) => {\n if (!google.colab.kernel.accessAllowed && !cache) {\n return;\n }\n element.appendChild(document.createTextNode(''));\n const url = await google.colab.kernel.proxyPort(port, {cache});\n\n const external_link = document.createElement('div');\n external_link.innerHTML = `\n
\n Running on \n https://localhost:${port}${path}\n \n
\n `;\n element.appendChild(external_link);\n\n const iframe = document.createElement('iframe');\n iframe.src = new URL(path, url).toString();\n iframe.height = height;\n iframe.allow = \"autoplay; camera; microphone; clipboard-read; clipboard-write;\"\n iframe.width = width;\n iframe.style.border = 0;\n element.appendChild(iframe);\n })(7860, \"/\", \"100%\", 500, false, window.element)", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "(, 'http://127.0.0.1:7860/', None)" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#@title Image to prompt! 🖼️ -> 📝\n", " \n", "inputs = [\n", " gr.inputs.Image(type='pil'),\n", " gr.Radio(['best', 'classic', 'fast'], label='', value='best'),\n", " gr.Number(value=4, label='best mode max flavors'),\n", "]\n", "outputs = [\n", " gr.outputs.Textbox(label=\"Output\"),\n", "]\n", "\n", "io = gr.Interface(\n", " inference, \n", " inputs, \n", " outputs, \n", " allow_flagging=False,\n", ")\n", "io.launch()\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "cellView": "form", "id": "OGmvkzITN4Hz" }, "outputs": [], "source": [ "#@title Batch process a folder of images 📁 -> 📝\n", "\n", "#@markdown This will generate prompts for every image in a folder and either save results \n", "#@markdown to a desc.csv file in the same folder or rename the files to contain their prompts.\n", "#@markdown The renamed files work well for [DreamBooth extension](https://github.com/d8ahazard/sd_dreambooth_extension)\n", "#@markdown in the [Stable Diffusion Web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui).\n", "#@markdown You can use the generated csv in the [Stable Diffusion Finetuning](https://colab.research.google.com/drive/1vrh_MUSaAMaC5tsLWDxkFILKJ790Z4Bl?usp=sharing)\n", "\n", "import csv\n", "import os\n", "from IPython.display import clear_output, display\n", "from PIL import Image\n", "from tqdm import tqdm\n", "\n", "folder_path = \"/content/my_images\" #@param {type:\"string\"}\n", "prompt_mode = 'best' #@param [\"best\",\"classic\", \"fast\"]\n", "output_mode = 'rename' #@param [\"desc.csv\",\"rename\"]\n", "max_filename_len = 128 #@param {type:\"integer\"}\n", "best_max_flavors = 4 #@param {type:\"integer\"}\n", "\n", "\n", "def sanitize_for_filename(prompt: str, max_len: int) -> str:\n", " name = \"\".join(c for c in prompt if (c.isalnum() or c in \",._-! \"))\n", " name = name.strip()[:(max_len-4)] # extra space for extension\n", " return name\n", "\n", "ci.config.quiet = True\n", "\n", "files = [f for f in os.listdir(folder_path) if f.endswith('.jpg') or f.endswith('.png')] if os.path.exists(folder_path) else []\n", "prompts = []\n", "for idx, file in enumerate(tqdm(files, desc='Generating prompts')):\n", " if idx > 0 and idx % 100 == 0:\n", " clear_output(wait=True)\n", "\n", " image = Image.open(os.path.join(folder_path, file)).convert('RGB')\n", " prompt = inference(image, prompt_mode, best_max_flavors=best_max_flavors)\n", " prompts.append(prompt)\n", "\n", " print(prompt)\n", " thumb = image.copy()\n", " thumb.thumbnail([256, 256])\n", " display(thumb)\n", "\n", " if output_mode == 'rename':\n", " name = sanitize_for_filename(prompt, max_filename_len)\n", " ext = os.path.splitext(file)[1]\n", " filename = name + ext\n", " idx = 1\n", " while os.path.exists(os.path.join(folder_path, filename)):\n", " print(f'File {filename} already exists, trying {idx+1}...')\n", " filename = f\"{name}_{idx}{ext}\"\n", " idx += 1\n", " os.rename(os.path.join(folder_path, file), os.path.join(folder_path, filename))\n", "\n", "if len(prompts):\n", " if output_mode == 'desc.csv':\n", " csv_path = os.path.join(folder_path, 'desc.csv')\n", " with open(csv_path, 'w', encoding='utf-8', newline='') as f:\n", " w = csv.writer(f, quoting=csv.QUOTE_MINIMAL)\n", " w.writerow(['image', 'prompt'])\n", " for file, prompt in zip(files, prompts):\n", " w.writerow([file, prompt])\n", "\n", " print(f\"\\n\\n\\n\\nGenerated {len(prompts)} prompts and saved to {csv_path}, enjoy!\")\n", " else:\n", " print(f\"\\n\\n\\n\\nGenerated {len(prompts)} prompts and renamed your files, enjoy!\")\n", "else:\n", " print(f\"Sorry, I couldn't find any images in {folder_path}\")\n" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "provenance": [] }, "kernelspec": { "display_name": "Python 3.8.10 ('venv': venv)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" }, "orig_nbformat": 4, "vscode": { "interpreter": { "hash": "f7a8d9541664ade9cff251487a19c76f2dd1b4c864d158f07ee26d1b0fd5c9a1" } } }, "nbformat": 4, "nbformat_minor": 0 }