{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "3jm8RYrLqvzz"
   },
   "source": [
    "# CLIP Interrogator 2.1 ViT-H special edition!\n",
    "\n",
    "Want to figure out what a good prompt might be to create new images like an existing one? The CLIP Interrogator is here to get you answers!\n",
    "\n",
    "This version is specialized for producing nice prompts for use with **[Stable Diffusion 2.0](https://stability.ai/blog/stable-diffusion-v2-release)** using the **ViT-H-14** OpenCLIP model!\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "cellView": "form",
    "id": "aP9FjmWxtLKJ"
   },
   "outputs": [],
   "source": [
    "#@title Check GPU\n",
    "!nvidia-smi -L"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "cellView": "form",
    "id": "xpPKQR40qvz2"
   },
   "outputs": [],
   "source": [
    "#@title Setup\n",
    "import os\n",
    "import subprocess\n",
    "import sys\n",
    "\n",
    "\n",
    "def run_cmd(cmd):\n",
    "    \"\"\"Run `cmd`, print its combined stdout/stderr and raise on failure.\n",
    "\n",
    "    stderr is merged into stdout because a child process's stderr does not\n",
    "    appear in the cell output, which previously hid failed installs until a\n",
    "    confusing ImportError further down.\n",
    "\n",
    "    Raises:\n",
    "        subprocess.CalledProcessError: if `cmd` exits with a non-zero status.\n",
    "    \"\"\"\n",
    "    result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)\n",
    "    print(result.stdout.decode('utf-8', errors='replace'))\n",
    "    result.check_returncode()\n",
    "\n",
    "\n",
    "def setup():\n",
    "    \"\"\"Install the Python dependencies and clone clip-interrogator.\n",
    "\n",
    "    pip is run as `sys.executable -m pip` so packages go into the running\n",
    "    kernel's environment, and `--src src` pins where the editable BLIP\n",
    "    checkout lands (the `src/blip` path added to sys.path below).\n",
    "    The clone is skipped when the directory already exists so the cell can\n",
    "    be re-run without `git clone` failing.\n",
    "    \"\"\"\n",
    "    pip_install = [sys.executable, '-m', 'pip', 'install']\n",
    "    install_cmds = [\n",
    "        pip_install + ['ftfy', 'gradio', 'regex', 'tqdm', 'transformers==4.21.2', 'timm', 'fairscale', 'requests'],\n",
    "        pip_install + ['open_clip_torch'],\n",
    "        pip_install + ['--src', 'src', '-e', 'git+https://github.com/pharmapsychotic/BLIP.git@lib#egg=blip'],\n",
    "    ]\n",
    "    for cmd in install_cmds:\n",
    "        run_cmd(cmd)\n",
    "    if not os.path.isdir('clip-interrogator'):\n",
    "        run_cmd(['git', 'clone', '-b', 'open-clip', 'https://github.com/pharmapsychotic/clip-interrogator.git'])\n",
    "\n",
    "\n",
    "setup()\n",
    "\n",
    "# download cache files\n",
    "print(\"Download preprocessed cache files...\")\n",
    "# NOTE(review): presumably Interrogator reads these from ./cache by default - confirm against Config.\n",
    "CACHE_DIR = 'cache'\n",
    "CACHE_URLS = [\n",
    "    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_artists.pkl',\n",
    "    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_flavors.pkl',\n",
    "    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_mediums.pkl',\n",
    "    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_movements.pkl',\n",
    "    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_trendings.pkl',\n",
    "]\n",
    "os.makedirs(CACHE_DIR, exist_ok=True)\n",
    "for url in CACHE_URLS:\n",
    "    # Skip files already present: re-running `wget -P` would otherwise\n",
    "    # download them again as duplicate `*.pkl.1` copies.\n",
    "    if os.path.exists(os.path.join(CACHE_DIR, os.path.basename(url))):\n",
    "        continue\n",
    "    # -nv: one status line per file instead of a dot progress bar, which\n",
    "    # would flood the cell output now that stderr is shown.\n",
    "    run_cmd(['wget', '-nv', url, '-P', CACHE_DIR])\n",
    "\n",
    "for path in ('src/blip', 'clip-interrogator'):\n",
    "    if path not in sys.path:  # avoid duplicate entries when the cell is re-run\n",
    "        sys.path.append(path)\n",
    "\n",
    "import gradio as gr\n",
    "from clip_interrogator import Config, Interrogator\n",
    "\n",
    "# NOTE(review): tuning values chosen for this ViT-H-14 edition; their exact\n",
    "# meaning is defined by clip_interrogator.Config - confirm there.\n",
    "config = Config()\n",
    "config.blip_offload = True\n",
    "config.chunk_size = 2048\n",
    "config.flavor_intermediate_count = 512\n",
    "config.blip_num_beams = 64\n",
    "\n",
    "ci = Interrogator(config)\n",
    "\n",
    "\n",
    "def inference(image, mode, best_max_flavors):\n",
    "    \"\"\"Return a prompt describing `image` using the requested mode.\n",
    "\n",
    "    Args:\n",
    "        image: object with a `.convert('RGB')` method (presumably a PIL\n",
    "            image supplied by the Gradio UI - confirm).\n",
    "        mode: 'best' uses `ci.interrogate` limited to `best_max_flavors`;\n",
    "            'classic' uses `ci.interrogate_classic`; any other value\n",
    "            falls through to `ci.interrogate_fast`.\n",
    "        best_max_flavors: cast to int; only used when mode == 'best'.\n",
    "    \"\"\"\n",
    "    image = image.convert('RGB')\n",
    "    if mode == 'best':\n",
    "        return ci.interrogate(image, max_flavors=int(best_max_flavors))\n",
    "    elif mode == 'classic':\n",
    "        return ci.interrogate_classic(image)\n",
    "    else:\n",
    "        return ci.interrogate_fast(image)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 677
    },
    "id": "Pf6qkFG6MPRj",
    "outputId": "8d542b56-8be7-453d-bf27-d0540a774c7d"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Colab notebook detected. To show errors in colab notebook, set `debug=True` in `launch()`\n",
      "\n",
      "Using Embedded Colab Mode (NEW). 
If you have issues, please use share=True and file an issue at https://github.com/gradio-app/gradio/\n", "Note: opening the browser inspector may crash Embedded Colab Mode.\n", "\n", "To create a public link, set `share=True` in `launch()`.\n" ] }, { "data": { "application/javascript": "(async (port, path, width, height, cache, element) => {\n if (!google.colab.kernel.accessAllowed && !cache) {\n return;\n }\n element.appendChild(document.createTextNode(''));\n const url = await google.colab.kernel.proxyPort(port, {cache});\n\n const external_link = document.createElement('div');\n external_link.innerHTML = `\n