{
  "cells": [
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "3jm8RYrLqvzz"
      },
      "source": [
        "# CLIP Interrogator 2.3 [negative prompt experiment!]\n",
        "\n",
        "This experimental version of CLIP Interrogator supports finding good \"negative\" prompts for Stable Diffusion 2. Note this is very *WIP* and more work needs to be done building out the dataset to support this (and perhaps a reverse BLIP) so for many images it may struggle to find a well aligned negative prompt. Alignments are displayed to help see how well it did.\n",
        "\n",
        "<br>\n",
        "\n",
        "For Stable Diffusion 1.X choose the **ViT-L** model and for Stable Diffusion 2.0+ choose the **ViT-H** CLIP Model.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "cellView": "form",
        "id": "aP9FjmWxtLKJ"
      },
      "outputs": [],
      "source": [
        "#@title Check GPU\n",
        "!nvidia-smi -L"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "cellView": "form",
        "id": "xpPKQR40qvz2"
      },
      "outputs": [],
      "source": [
        "#@title Setup\n",
        "import os, subprocess\n",
        "\n",
        "def setup():\n",
        "    install_cmds = [\n",
        "        ['pip', 'install', 'gradio'],\n",
        "        ['pip', 'install', 'open_clip_torch'],\n",
        "        ['pip', 'install', '-e', 'git+https://github.com/pharmapsychotic/BLIP.git@lib#egg=blip'],\n",
        "        ['git', 'clone', '-b', 'negative', 'https://github.com/pharmapsychotic/clip-interrogator.git']\n",
        "    ]\n",
        "    for cmd in install_cmds:\n",
        "        print(subprocess.run(cmd, stdout=subprocess.PIPE).stdout.decode('utf-8'))\n",
        "\n",
        "setup()\n",
        "\n",
        "\n",
        "clip_model_name = 'ViT-H-14/laion2b_s32b_b79k' #@param [\"ViT-L-14/openai\", \"ViT-H-14/laion2b_s32b_b79k\"]\n",
        "\n",
        "\n",
        "print(\"Download preprocessed cache files...\")\n",
        "CACHE_URLS = [\n",
        "    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_artists.pkl',\n",
        "    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_flavors.pkl',\n",
        "    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_mediums.pkl',\n",
        "    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_movements.pkl',\n",
        "    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_trendings.pkl',\n",
        "] if clip_model_name == 'ViT-L-14/openai' else [\n",
        "    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_artists.pkl',\n",
        "    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_flavors.pkl',\n",
        "    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_mediums.pkl',\n",
        "    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_movements.pkl',\n",
        "    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_trendings.pkl',\n",
        "]\n",
        "os.makedirs('cache', exist_ok=True)\n",
        "for url in CACHE_URLS:\n",
        "    print(subprocess.run(['wget', url, '-P', 'cache'], stdout=subprocess.PIPE).stdout.decode('utf-8'))\n",
        "\n",
        "\n",
        "import sys\n",
        "sys.path.append('src/blip')\n",
        "sys.path.append('clip-interrogator')\n",
        "\n",
        "import gradio as gr\n",
        "from clip_interrogator import Config, Interrogator\n",
        "\n",
        "config = Config()\n",
        "config.blip_num_beams = 64\n",
        "config.blip_offload = False\n",
        "config.clip_model_name = clip_model_name\n",
        "ci = Interrogator(config)\n",
        "\n",
        "def inference(image, mode):\n",
        "    ci.config.chunk_size = 2048 if ci.config.clip_model_name == \"ViT-L-14/openai\" else 1024\n",
        "    ci.config.flavor_intermediate_count = 2048 if ci.config.clip_model_name == \"ViT-L-14/openai\" else 1024\n",
        "    image = image.convert('RGB')\n",
        "    prompt = \"\"\n",
        "    if mode == 'best':\n",
        "        prompt = ci.interrogate(image)\n",
        "    elif mode == 'classic':\n",
        "        prompt = ci.interrogate_classic(image)\n",
        "    elif mode == 'fast':\n",
        "        prompt = ci.interrogate_fast(image)\n",
        "    elif mode == 'negative':\n",
        "        image_features = ci.image_to_features(image)\n",
        "        flaves = ci.flavors.rank(image_features, ci.config.flavor_intermediate_count, reverse=True)\n",
        "        flaves = flaves + ci.negative.labels\n",
        "        prompt = ci.chain(image_features, flaves, max_count=32, reverse=True, desc=\"Negative chain\")\n",
        "    sim = ci.similarity(ci.image_to_features(image), prompt)\n",
        "    return prompt, sim\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "cellView": "form",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 677
        },
        "id": "Pf6qkFG6MPRj",
        "outputId": "8d542b56-8be7-453d-bf27-d0540a774c7d"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Colab notebook detected. To show errors in colab notebook, set `debug=True` in `launch()`\n",
            "\n",
            "Using Embedded Colab Mode (NEW). If you have issues, please use share=True and file an issue at https://github.com/gradio-app/gradio/\n",
            "Note: opening the browser inspector may crash Embedded Colab Mode.\n",
            "\n",
            "To create a public link, set `share=True` in `launch()`.\n"
          ]
        },
        {
          "data": {
            "application/javascript": "(async (port, path, width, height, cache, element) => {\n                        if (!google.colab.kernel.accessAllowed && !cache) {\n                            return;\n                        }\n                        element.appendChild(document.createTextNode(''));\n                        const url = await google.colab.kernel.proxyPort(port, {cache});\n\n                        const external_link = document.createElement('div');\n                        external_link.innerHTML = `\n                            <div style=\"font-family: monospace; margin-bottom: 0.5rem\">\n                                Running on <a href=${new URL(path, url).toString()} target=\"_blank\">\n                                    https://localhost:${port}${path}\n                                </a>\n                            </div>\n                        `;\n                        element.appendChild(external_link);\n\n                        const iframe = document.createElement('iframe');\n                        iframe.src = new URL(path, url).toString();\n                        iframe.height = height;\n                        iframe.allow = \"autoplay; camera; microphone; clipboard-read; clipboard-write;\"\n                        iframe.width = width;\n                        iframe.style.border = 0;\n                        element.appendChild(iframe);\n                    })(7860, \"/\", \"100%\", 500, false, window.element)",
            "text/plain": [
              "<IPython.core.display.Javascript object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/plain": [
              "(<gradio.routes.App at 0x7f894e553710>, 'http://127.0.0.1:7860/', None)"
            ]
          },
          "execution_count": 4,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "#@title Image to prompt! 🖼️ -> 📝\n",
        "   \n",
        "inputs = [\n",
        "    gr.inputs.Image(type='pil'),\n",
        "    gr.Radio(['best', 'fast', 'negative'], label='Mode', value='best'),\n",
        "]\n",
        "outputs = [\n",
        "    gr.outputs.Textbox(label=\"Output\"),\n",
        "    gr.Number(label=\"Alignment\"),\n",
        "]\n",
        "\n",
        "io = gr.Interface(\n",
        "    inference, \n",
        "    inputs, \n",
        "    outputs, \n",
        "    allow_flagging=False,\n",
        ")\n",
        "io.launch(debug=False)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "cellView": "form",
        "id": "OGmvkzITN4Hz"
      },
      "outputs": [],
      "source": [
        "#@title Batch process a folder of images 📁 -> 📝\n",
        "\n",
        "#@markdown This will generate prompts for every image in a folder and either save results \n",
        "#@markdown to a desc.csv file in the same folder or rename the files to contain their prompts.\n",
        "#@markdown The renamed files work well for [DreamBooth extension](https://github.com/d8ahazard/sd_dreambooth_extension)\n",
        "#@markdown in the [Stable Diffusion Web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui).\n",
        "#@markdown You can use the generated csv in the [Stable Diffusion Finetuning](https://colab.research.google.com/drive/1vrh_MUSaAMaC5tsLWDxkFILKJ790Z4Bl?usp=sharing)\n",
        "\n",
        "import csv\n",
        "import os\n",
        "from IPython.display import clear_output, display\n",
        "from PIL import Image\n",
        "from tqdm import tqdm\n",
        "\n",
        "folder_path = \"/content/my_images\" #@param {type:\"string\"}\n",
        "prompt_mode = 'best' #@param [\"best\",\"fast\"]\n",
        "output_mode = 'rename' #@param [\"desc.csv\",\"rename\"]\n",
        "max_filename_len = 128 #@param {type:\"integer\"}\n",
        "best_max_flavors = 16 #@param {type:\"integer\"}\n",
        "\n",
        "\n",
        "def sanitize_for_filename(prompt: str, max_len: int) -> str:\n",
        "    name = \"\".join(c for c in prompt if (c.isalnum() or c in \",._-! \"))\n",
        "    name = name.strip()[:(max_len-4)] # extra space for extension\n",
        "    return name\n",
        "\n",
        "ci.config.quiet = True\n",
        "\n",
        "files = [f for f in os.listdir(folder_path) if f.endswith('.jpg') or f.endswith('.png')] if os.path.exists(folder_path) else []\n",
        "prompts = []\n",
        "for idx, file in enumerate(tqdm(files, desc='Generating prompts')):\n",
        "    if idx > 0 and idx % 100 == 0:\n",
        "        clear_output(wait=True)\n",
        "\n",
        "    image = Image.open(os.path.join(folder_path, file)).convert('RGB')\n",
        "    prompt = inference(image, prompt_mode, best_max_flavors=best_max_flavors)\n",
        "    prompts.append(prompt)\n",
        "\n",
        "    print(prompt)\n",
        "    thumb = image.copy()\n",
        "    thumb.thumbnail([256, 256])\n",
        "    display(thumb)\n",
        "\n",
        "    if output_mode == 'rename':\n",
        "        name = sanitize_for_filename(prompt, max_filename_len)\n",
        "        ext = os.path.splitext(file)[1]\n",
        "        filename = name + ext\n",
        "        idx = 1\n",
        "        while os.path.exists(os.path.join(folder_path, filename)):\n",
        "            print(f'File {filename} already exists, trying {idx+1}...')\n",
        "            filename = f\"{name}_{idx}{ext}\"\n",
        "            idx += 1\n",
        "        os.rename(os.path.join(folder_path, file), os.path.join(folder_path, filename))\n",
        "\n",
        "if len(prompts):\n",
        "    if output_mode == 'desc.csv':\n",
        "        csv_path = os.path.join(folder_path, 'desc.csv')\n",
        "        with open(csv_path, 'w', encoding='utf-8', newline='') as f:\n",
        "            w = csv.writer(f, quoting=csv.QUOTE_MINIMAL)\n",
        "            w.writerow(['image', 'prompt'])\n",
        "            for file, prompt in zip(files, prompts):\n",
        "                w.writerow([file, prompt])\n",
        "\n",
        "        print(f\"\\n\\n\\n\\nGenerated {len(prompts)} prompts and saved to {csv_path}, enjoy!\")\n",
        "    else:\n",
        "        print(f\"\\n\\n\\n\\nGenerated {len(prompts)} prompts and renamed your files, enjoy!\")\n",
        "else:\n",
        "    print(f\"Sorry, I couldn't find any images in {folder_path}\")\n"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3.7.15 ('py37')",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.15 (default, Nov 24 2022, 18:44:54) [MSC v.1916 64 bit (AMD64)]"
    },
    "orig_nbformat": 4,
    "vscode": {
      "interpreter": {
        "hash": "1f51d5616d3bc2b87a82685314c5be1ec9a49b6e0cb1f707bfa2acb6c45f3e5f"
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}