Compare commits


No commits in common. 'main' and 'replicate' have entirely different histories.

14 files changed:

  1. .github/workflows/python-publish.yml (39 changes)
  2. .gitignore (2 changes)
  3. MANIFEST.in (1 change)
  4. README.md (55 changes)
  5. clip_interrogator.ipynb (214 changes)
  6. clip_interrogator/__init__.py (4 changes)
  7. clip_interrogator/clip_interrogator.py (535 changes)
  8. clip_interrogator/data/negative.txt (41 changes)
  9. cog.yaml (20 changes)
  10. predict.py (41 changes)
  11. requirements.txt (8 changes)
  12. run_cli.py (23 changes)
  13. run_gradio.py (118 changes)
  14. setup.py (4 changes)

.github/workflows/python-publish.yml (39 changes)

@@ -1,39 +0,0 @@
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries

# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.

name: Upload Python Package

on:
  release:
    types: [published]

permissions:
  contents: read

jobs:
  deploy:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python
        uses: actions/setup-python@v3
        with:
          python-version: '3.x'
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install build
      - name: Build package
        run: python -m build
      - name: Publish package
        uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
        with:
          user: pharmapsychotic
          password: ${{ secrets.PYPI_API_TOKEN }}

.gitignore (2 changes)

@@ -1,9 +1,7 @@
 *.pyc
 .cog/
 .vscode/
-bench/
 cache/
-ci_env/
 clip-interrogator/
 clip_interrogator.egg-info/
 dist/

MANIFEST.in (1 change)

@@ -2,5 +2,4 @@ include clip_interrogator/data/artists.txt
 include clip_interrogator/data/flavors.txt
 include clip_interrogator/data/mediums.txt
 include clip_interrogator/data/movements.txt
-include clip_interrogator/data/negative.txt
 include requirements.txt

README.md (55 changes)

@@ -4,17 +4,12 @@
 ## Run it!
-🆕 Now available as a [Stable Diffusion Web UI Extension](https://github.com/pharmapsychotic/clip-interrogator-ext)! 🆕
-<br>
 Run Version 2 on Colab, HuggingFace, and Replicate!
-[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pharmapsychotic/clip-interrogator/blob/main/clip_interrogator.ipynb) [![Generic badge](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue.svg)](https://huggingface.co/spaces/pharma/CLIP-Interrogator) [![Replicate](https://replicate.com/pharmapsychotic/clip-interrogator/badge)](https://replicate.com/pharmapsychotic/clip-interrogator) [![Lambda](https://img.shields.io/badge/%CE%BB-Lambda-blue)](https://cloud.lambdalabs.com/demos/ml/CLIP-Interrogator)
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pharmapsychotic/clip-interrogator/blob/main/clip_interrogator.ipynb) [![Generic badge](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue.svg)](https://huggingface.co/spaces/pharma/CLIP-Interrogator) [![Replicate](https://replicate.com/cjwbw/clip-interrogator/badge)](https://replicate.com/cjwbw/clip-interrogator)
 <br>
 Version 1 still available in Colab for comparing different CLIP models
 [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pharmapsychotic/clip-interrogator/blob/v1/clip_interrogator.ipynb)

@@ -30,57 +25,21 @@ The **CLIP Interrogator** is a prompt engineering tool that combines OpenAI's [C
 Create and activate a Python virtual environment
 ```bash
 python3 -m venv ci_env
-(for linux ) source ci_env/bin/activate
-(for windows) .\ci_env\Scripts\activate
+source ci_env/bin/activate
 ```
 Install with PIP
 ```
-# install torch with GPU support for example:
-pip3 install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu117
-# install clip-interrogator
-pip install clip-interrogator==0.5.4
-# or for very latest WIP with BLIP2 support
-#pip install clip-interrogator==0.6.0
+pip install -e git+https://github.com/openai/CLIP.git@main#egg=clip
+pip install -e git+https://github.com/pharmapsychotic/BLIP.git@lib#egg=blip
+pip install clip-interrogator
 ```
 You can then use it in your script
 ```python
 from PIL import Image
-from clip_interrogator import Config, Interrogator
+from clip_interrogator import Interrogator, Config
 image = Image.open(image_path).convert('RGB')
-ci = Interrogator(Config(clip_model_name="ViT-L-14/openai"))
+ci = Interrogator(Config(clip_model_name="ViT-L/14"))
 print(ci.interrogate(image))
 ```
-CLIP Interrogator uses OpenCLIP which supports many different pretrained CLIP models. For the best prompts for
-Stable Diffusion 1.X use `ViT-L-14/openai` for clip_model_name. For Stable Diffusion 2.0 use `ViT-H-14/laion2b_s32b_b79k`
-
-## Configuration
-
-The `Config` object lets you configure CLIP Interrogator's processing.
-* `clip_model_name`: which of the OpenCLIP pretrained CLIP models to use
-* `cache_path`: path where to save precomputed text embeddings
-* `download_cache`: when True will download the precomputed embeddings from huggingface
-* `chunk_size`: batch size for CLIP, use smaller for lower VRAM
-* `quiet`: when True no progress bars or text output will be displayed
-
-On systems with low VRAM you can call `config.apply_low_vram_defaults()` to reduce the amount of VRAM needed (at the cost of some speed and quality). The default settings use about 6.3GB of VRAM and the low VRAM settings use about 2.7GB.
-
-See the [run_cli.py](https://github.com/pharmapsychotic/clip-interrogator/blob/main/run_cli.py) and [run_gradio.py](https://github.com/pharmapsychotic/clip-interrogator/blob/main/run_gradio.py) for more examples on using Config and Interrogator classes.
-
-## Ranking against your own list of terms (requires version 0.6.0)
-
-```python
-from clip_interrogator import Config, Interrogator, LabelTable, load_list
-from PIL import Image
-
-ci = Interrogator(Config(blip_model_type=None))
-image = Image.open(image_path).convert('RGB')
-table = LabelTable(load_list('terms.txt'), 'terms', ci)
-best_match = table.rank(ci.image_to_features(image), top_count=1)[0]
-print(best_match)
-```
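The removed Configuration section above documents the `Config` options on the main branch. As a rough sketch of how those options fit together (the image path is a placeholder and the values mirror the documented defaults):

```python
from PIL import Image
from clip_interrogator import Config, Interrogator

config = Config(
    clip_model_name="ViT-L-14/openai",  # OpenCLIP model; use ViT-H-14/laion2b_s32b_b79k for SD 2.0
    cache_path="cache",                 # where precomputed text embeddings are saved
    download_cache=True,                # pull precomputed embeddings from huggingface
    chunk_size=2048,                    # CLIP batch size; lower it on small GPUs
    quiet=False,                        # True hides progress bars and text output
)
# config.apply_low_vram_defaults()      # optional: ~2.7GB of VRAM instead of ~6.3GB

ci = Interrogator(config)
image = Image.open("example.jpg").convert("RGB")  # placeholder image path
print(ci.interrogate(image))
```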

clip_interrogator.ipynb (214 changes)

@@ -1,25 +1,21 @@
{ {
"cells": [ "cells": [
{ {
"attachments": {},
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"id": "3jm8RYrLqvzz" "id": "3jm8RYrLqvzz"
}, },
"source": [ "source": [
"# CLIP Interrogator 2.4 by [@pharmapsychotic](https://twitter.com/pharmapsychotic) \n", "# CLIP Interrogator 2.1 by [@pharmapsychotic](https://twitter.com/pharmapsychotic) \n",
"\n",
"<br>\n",
"\n", "\n",
"Want to figure out what a good prompt might be to create new images like an existing one? The CLIP Interrogator is here to get you answers!\n", "Want to figure out what a good prompt might be to create new images like an existing one? The CLIP Interrogator is here to get you answers!\n",
"\n", "\n",
"<br>\n", "<br>\n",
"\n", "\n",
"For Stable Diffusion 1.X choose the **ViT-L** model and for Stable Diffusion 2.0+ choose the **ViT-H** CLIP Model.\n",
"\n",
"This version is specialized for producing nice prompts for use with Stable Diffusion and achieves higher alignment between generated text prompt and source image. You can try out the old [version 1](https://colab.research.google.com/github/pharmapsychotic/clip-interrogator/blob/v1/clip_interrogator.ipynb) to see how different CLIP models ranks terms. \n", "This version is specialized for producing nice prompts for use with Stable Diffusion and achieves higher alignment between generated text prompt and source image. You can try out the old [version 1](https://colab.research.google.com/github/pharmapsychotic/clip-interrogator/blob/v1/clip_interrogator.ipynb) to see how different CLIP models ranks terms. \n",
"\n", "\n",
"You can also run this on HuggingFace and Replicate<br>\n",
"[![Generic badge](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue.svg)](https://huggingface.co/spaces/pharma/CLIP-Interrogator) [![Replicate](https://replicate.com/pharmapsychotic/clip-interrogator/badge)](https://replicate.com/pharmapsychotic/clip-interrogator)\n",
"\n",
"<br>\n", "<br>\n",
"\n", "\n",
"If this notebook is helpful to you please consider buying me a coffee via [ko-fi](https://ko-fi.com/pharmapsychotic) or following me on [twitter](https://twitter.com/pharmapsychotic) for more cool Ai stuff. 🙂\n", "If this notebook is helpful to you please consider buying me a coffee via [ko-fi](https://ko-fi.com/pharmapsychotic) or following me on [twitter](https://twitter.com/pharmapsychotic) for more cool Ai stuff. 🙂\n",
@@ -50,67 +46,45 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"#@title Setup\n", "#@title Setup\n",
"import os, subprocess\n", "import argparse, subprocess, sys, time\n",
"\n", "\n",
"def setup():\n", "def setup():\n",
" install_cmds = [\n", " install_cmds = [\n",
" ['pip', 'install', 'gradio'],\n", " ['pip', 'install', 'ftfy', 'gradio', 'regex', 'tqdm', 'transformers==4.21.2', 'timm', 'fairscale', 'requests'],\n",
" ['pip', 'install', 'open_clip_torch'],\n", " ['pip', 'install', '-e', 'git+https://github.com/openai/CLIP.git@main#egg=clip'],\n",
" ['pip', 'install', 'clip-interrogator'],\n", " ['pip', 'install', '-e', 'git+https://github.com/pharmapsychotic/BLIP.git@lib#egg=blip'],\n",
" ['git', 'clone', 'https://github.com/pharmapsychotic/clip-interrogator.git']\n",
" ]\n", " ]\n",
" for cmd in install_cmds:\n", " for cmd in install_cmds:\n",
" print(subprocess.run(cmd, stdout=subprocess.PIPE).stdout.decode('utf-8'))\n", " print(subprocess.run(cmd, stdout=subprocess.PIPE).stdout.decode('utf-8'))\n",
"\n", "\n",
"setup()\n", "setup()\n",
"\n", "\n",
"import sys\n",
"sys.path.append('src/blip')\n",
"sys.path.append('src/clip')\n",
"sys.path.append('clip-interrogator')\n",
"\n", "\n",
"caption_model_name = 'blip-large' #@param [\"blip-base\", \"blip-large\", \"git-large-coco\"]\n", "import clip\n",
"clip_model_name = 'ViT-L-14/openai' #@param [\"ViT-L-14/openai\", \"ViT-H-14/laion2b_s32b_b79k\"]\n",
"\n",
"import gradio as gr\n", "import gradio as gr\n",
"from clip_interrogator import Config, Interrogator\n", "import torch\n",
"from clip_interrogator import Interrogator, Config\n",
"\n", "\n",
"config = Config()\n", "ci = Interrogator(Config())\n",
"config.clip_model_name = clip_model_name\n",
"config.caption_model_name = caption_model_name\n",
"ci = Interrogator(config)\n",
"\n", "\n",
"def image_analysis(image):\n", "def inference(image, mode):\n",
" image = image.convert('RGB')\n",
" image_features = ci.image_to_features(image)\n",
"\n",
" top_mediums = ci.mediums.rank(image_features, 5)\n",
" top_artists = ci.artists.rank(image_features, 5)\n",
" top_movements = ci.movements.rank(image_features, 5)\n",
" top_trendings = ci.trendings.rank(image_features, 5)\n",
" top_flavors = ci.flavors.rank(image_features, 5)\n",
"\n",
" medium_ranks = {medium: sim for medium, sim in zip(top_mediums, ci.similarities(image_features, top_mediums))}\n",
" artist_ranks = {artist: sim for artist, sim in zip(top_artists, ci.similarities(image_features, top_artists))}\n",
" movement_ranks = {movement: sim for movement, sim in zip(top_movements, ci.similarities(image_features, top_movements))}\n",
" trending_ranks = {trending: sim for trending, sim in zip(top_trendings, ci.similarities(image_features, top_trendings))}\n",
" flavor_ranks = {flavor: sim for flavor, sim in zip(top_flavors, ci.similarities(image_features, top_flavors))}\n",
" \n",
" return medium_ranks, artist_ranks, movement_ranks, trending_ranks, flavor_ranks\n",
"\n",
"def image_to_prompt(image, mode):\n",
" ci.config.chunk_size = 2048 if ci.config.clip_model_name == \"ViT-L-14/openai\" else 1024\n",
" ci.config.flavor_intermediate_count = 2048 if ci.config.clip_model_name == \"ViT-L-14/openai\" else 1024\n",
" image = image.convert('RGB')\n", " image = image.convert('RGB')\n",
" if mode == 'best':\n", " if mode == 'best':\n",
" return ci.interrogate(image)\n", " return ci.interrogate(image)\n",
" elif mode == 'classic':\n", " elif mode == 'classic':\n",
" return ci.interrogate_classic(image)\n", " return ci.interrogate_classic(image)\n",
" elif mode == 'fast':\n", " else:\n",
" return ci.interrogate_fast(image)\n", " return ci.interrogate_fast(image)\n"
" elif mode == 'negative':\n",
" return ci.interrogate_negative(image)\n",
" "
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 4,
"metadata": { "metadata": {
"cellView": "form", "cellView": "form",
"colab": { "colab": {
@@ -120,45 +94,63 @@
"id": "Pf6qkFG6MPRj", "id": "Pf6qkFG6MPRj",
"outputId": "8d542b56-8be7-453d-bf27-d0540a774c7d" "outputId": "8d542b56-8be7-453d-bf27-d0540a774c7d"
}, },
"outputs": [], "outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Colab notebook detected. To show errors in colab notebook, set `debug=True` in `launch()`\n",
"\n",
"Using Embedded Colab Mode (NEW). If you have issues, please use share=True and file an issue at https://github.com/gradio-app/gradio/\n",
"Note: opening the browser inspector may crash Embedded Colab Mode.\n",
"\n",
"To create a public link, set `share=True` in `launch()`.\n"
]
},
{
"data": {
"application/javascript": "(async (port, path, width, height, cache, element) => {\n if (!google.colab.kernel.accessAllowed && !cache) {\n return;\n }\n element.appendChild(document.createTextNode(''));\n const url = await google.colab.kernel.proxyPort(port, {cache});\n\n const external_link = document.createElement('div');\n external_link.innerHTML = `\n <div style=\"font-family: monospace; margin-bottom: 0.5rem\">\n Running on <a href=${new URL(path, url).toString()} target=\"_blank\">\n https://localhost:${port}${path}\n </a>\n </div>\n `;\n element.appendChild(external_link);\n\n const iframe = document.createElement('iframe');\n iframe.src = new URL(path, url).toString();\n iframe.height = height;\n iframe.allow = \"autoplay; camera; microphone; clipboard-read; clipboard-write;\"\n iframe.width = width;\n iframe.style.border = 0;\n element.appendChild(iframe);\n })(7860, \"/\", \"100%\", 500, false, window.element)",
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"(<gradio.routes.App at 0x7f894e553710>, 'http://127.0.0.1:7860/', None)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [ "source": [
"#@title Image to prompt! 🖼 -> 📝\n", "#@title Image to prompt! 🖼 -> 📝\n",
" \n", " \n",
"def prompt_tab():\n", "inputs = [\n",
" with gr.Column():\n", " gr.inputs.Image(type='pil'),\n",
" with gr.Row():\n", " gr.Radio(['best', 'classic', 'fast'], label='', value='best'),\n",
" image = gr.Image(type='pil', label=\"Image\")\n", "]\n",
" with gr.Column():\n", "outputs = [\n",
" mode = gr.Radio(['best', 'fast', 'classic', 'negative'], label='Mode', value='best')\n", " gr.outputs.Textbox(label=\"Output\"),\n",
" prompt = gr.Textbox(label=\"Prompt\")\n", "]\n",
" button = gr.Button(\"Generate prompt\")\n", "\n",
" button.click(image_to_prompt, inputs=[image, mode], outputs=prompt)\n", "io = gr.Interface(\n",
"\n", " inference, \n",
"def analyze_tab():\n", " inputs, \n",
" with gr.Column():\n", " outputs, \n",
" with gr.Row():\n", " allow_flagging=False,\n",
" image = gr.Image(type='pil', label=\"Image\")\n", ")\n",
" with gr.Row():\n", "io.launch()\n"
" medium = gr.Label(label=\"Medium\", num_top_classes=5)\n",
" artist = gr.Label(label=\"Artist\", num_top_classes=5) \n",
" movement = gr.Label(label=\"Movement\", num_top_classes=5)\n",
" trending = gr.Label(label=\"Trending\", num_top_classes=5)\n",
" flavor = gr.Label(label=\"Flavor\", num_top_classes=5)\n",
" button = gr.Button(\"Analyze\")\n",
" button.click(image_analysis, inputs=image, outputs=[medium, artist, movement, trending, flavor])\n",
"\n",
"with gr.Blocks() as ui:\n",
" with gr.Tab(\"Prompt\"):\n",
" prompt_tab()\n",
" with gr.Tab(\"Analyze\"):\n",
" analyze_tab()\n",
"\n",
"ui.launch(show_api=False, debug=False)\n"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": null,
"metadata": { "metadata": {
"cellView": "form", "cellView": "form",
"id": "OGmvkzITN4Hz" "id": "OGmvkzITN4Hz"
@@ -167,69 +159,43 @@
"source": [ "source": [
"#@title Batch process a folder of images 📁 -> 📝\n", "#@title Batch process a folder of images 📁 -> 📝\n",
"\n", "\n",
"#@markdown This will generate prompts for every image in a folder and either save results \n", "#@markdown This will generate prompts for every image in a folder and save results to desc.csv in the same folder.\n",
"#@markdown to a desc.csv file in the same folder or rename the files to contain their prompts.\n",
"#@markdown The renamed files work well for [DreamBooth extension](https://github.com/d8ahazard/sd_dreambooth_extension)\n",
"#@markdown in the [Stable Diffusion Web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui).\n",
"#@markdown You can use the generated csv in the [Stable Diffusion Finetuning](https://colab.research.google.com/drive/1vrh_MUSaAMaC5tsLWDxkFILKJ790Z4Bl?usp=sharing)\n", "#@markdown You can use the generated csv in the [Stable Diffusion Finetuning](https://colab.research.google.com/drive/1vrh_MUSaAMaC5tsLWDxkFILKJ790Z4Bl?usp=sharing)\n",
"#@markdown notebook or use it to train a different model or just print it out for fun. \n",
"#@markdown If you make something cool, I'd love to see it! Tag me on twitter or something. 😀\n",
"\n", "\n",
"import csv\n", "import csv\n",
"import os\n", "import os\n",
"from IPython.display import clear_output, display\n", "from IPython.display import display\n",
"from PIL import Image\n", "from PIL import Image\n",
"from tqdm import tqdm\n", "from tqdm import tqdm\n",
"\n", "\n",
"folder_path = \"/content/my_images\" #@param {type:\"string\"}\n", "folder_path = \"/content/my_images\" #@param {type:\"string\"}\n",
"prompt_mode = 'best' #@param [\"best\",\"fast\",\"classic\",\"negative\"]\n", "mode = 'best' #@param [\"best\",\"classic\", \"fast\"]\n",
"output_mode = 'rename' #@param [\"desc.csv\",\"rename\"]\n",
"max_filename_len = 128 #@param {type:\"integer\"}\n",
"\n",
"\n", "\n",
"def sanitize_for_filename(prompt: str, max_len: int) -> str:\n",
" name = \"\".join(c for c in prompt if (c.isalnum() or c in \",._-! \"))\n",
" name = name.strip()[:(max_len-4)] # extra space for extension\n",
" return name\n",
"\n", "\n",
"ci.config.quiet = True\n", "files = [f for f in os.listdir(folder_path) if f.endswith('.jpg') or f.endswith('.png')] if os.path.exists(folder_path) else []\n",
"\n",
"files = [f for f in os.listdir(folder_path) if f.endswith('.jpg') or f.endswith('.png')] if os.path.exists(folder_path) else []\n",
"prompts = []\n", "prompts = []\n",
"for idx, file in enumerate(tqdm(files, desc='Generating prompts')):\n", "for file in files:\n",
" if idx > 0 and idx % 100 == 0:\n",
" clear_output(wait=True)\n",
"\n",
" image = Image.open(os.path.join(folder_path, file)).convert('RGB')\n", " image = Image.open(os.path.join(folder_path, file)).convert('RGB')\n",
" prompt = image_to_prompt(image, prompt_mode)\n", " prompt = inference(image, mode)\n",
" prompts.append(prompt)\n", " prompts.append(prompt)\n",
"\n", "\n",
" print(prompt)\n",
" thumb = image.copy()\n", " thumb = image.copy()\n",
" thumb.thumbnail([256, 256])\n", " thumb.thumbnail([256, 256])\n",
" display(thumb)\n", " display(thumb)\n",
"\n", "\n",
" if output_mode == 'rename':\n", " print(prompt)\n",
" name = sanitize_for_filename(prompt, max_filename_len)\n",
" ext = os.path.splitext(file)[1]\n",
" filename = name + ext\n",
" idx = 1\n",
" while os.path.exists(os.path.join(folder_path, filename)):\n",
" print(f'File {filename} already exists, trying {idx+1}...')\n",
" filename = f\"{name}_{idx}{ext}\"\n",
" idx += 1\n",
" os.rename(os.path.join(folder_path, file), os.path.join(folder_path, filename))\n",
"\n", "\n",
"if len(prompts):\n", "if len(prompts):\n",
" if output_mode == 'desc.csv':\n", " csv_path = os.path.join(folder_path, 'desc.csv')\n",
" csv_path = os.path.join(folder_path, 'desc.csv')\n", " with open(csv_path, 'w', encoding='utf-8', newline='') as f:\n",
" with open(csv_path, 'w', encoding='utf-8', newline='') as f:\n", " w = csv.writer(f, quoting=csv.QUOTE_MINIMAL)\n",
" w = csv.writer(f, quoting=csv.QUOTE_MINIMAL)\n", " w.writerow(['image', 'prompt'])\n",
" w.writerow(['image', 'prompt'])\n", " for file, prompt in zip(files, prompts):\n",
" for file, prompt in zip(files, prompts):\n", " w.writerow([file, prompt])\n",
" w.writerow([file, prompt])\n", "\n",
"\n", " print(f\"\\n\\n\\n\\nGenerated {len(prompts)} and saved to {csv_path}, enjoy!\")\n",
" print(f\"\\n\\n\\n\\nGenerated {len(prompts)} prompts and saved to {csv_path}, enjoy!\")\n",
" else:\n",
" print(f\"\\n\\n\\n\\nGenerated {len(prompts)} prompts and renamed your files, enjoy!\")\n",
"else:\n", "else:\n",
" print(f\"Sorry, I couldn't find any images in {folder_path}\")\n" " print(f\"Sorry, I couldn't find any images in {folder_path}\")\n"
] ]
@@ -242,7 +208,7 @@
"provenance": [] "provenance": []
}, },
"kernelspec": { "kernelspec": {
"display_name": "Python 3.7.15 ('py37')", "display_name": "Python 3.8.10 ('venv': venv)",
"language": "python", "language": "python",
"name": "python3" "name": "python3"
}, },
@@ -256,12 +222,12 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.9.5" "version": "3.8.10"
}, },
"orig_nbformat": 4, "orig_nbformat": 4,
"vscode": { "vscode": {
"interpreter": { "interpreter": {
"hash": "1f51d5616d3bc2b87a82685314c5be1ec9a49b6e0cb1f707bfa2acb6c45f3e5f" "hash": "f7a8d9541664ade9cff251487a19c76f2dd1b4c864d158f07ee26d1b0fd5c9a1"
} }
} }
}, },

clip_interrogator/__init__.py (4 changes)

@@ -1,4 +1,4 @@
-from .clip_interrogator import Config, Interrogator, LabelTable, list_caption_models, list_clip_models, load_list
+from .clip_interrogator import Interrogator, Config

-__version__ = '0.6.0'
+__version__ = '0.1.4'
 __author__ = 'pharmapsychotic'

clip_interrogator/clip_interrogator.py (535 changes)

@@ -1,398 +1,279 @@
import clip
import hashlib import hashlib
import inspect
import math import math
import numpy as np import numpy as np
import open_clip
import os import os
import requests import pickle
import time
import torch import torch
from dataclasses import dataclass from dataclasses import dataclass
from models.blip import blip_decoder
from PIL import Image from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM, BlipForConditionalGeneration, Blip2ForConditionalGeneration from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from tqdm import tqdm from tqdm import tqdm
from typing import List, Optional from typing import List
from safetensors.numpy import load_file, save_file
CAPTION_MODELS = {
'blip-base': 'Salesforce/blip-image-captioning-base', # 990MB
'blip-large': 'Salesforce/blip-image-captioning-large', # 1.9GB
'blip2-2.7b': 'Salesforce/blip2-opt-2.7b', # 15.5GB
'blip2-flan-t5-xl': 'Salesforce/blip2-flan-t5-xl', # 15.77GB
'git-large-coco': 'microsoft/git-large-coco', # 1.58GB
}
CACHE_URL_BASE = 'https://huggingface.co/pharmapsychotic/ci-preprocess/resolve/main/'
@dataclass @dataclass
class Config: class Config:
# models can optionally be passed in directly # models can optionally be passed in directly
caption_model = None blip_model = None
caption_processor = None
clip_model = None clip_model = None
clip_preprocess = None clip_preprocess = None
# blip settings # blip settings
caption_max_length: int = 32 blip_image_eval_size: int = 384
caption_model_name: Optional[str] = 'blip-large' # use a key from CAPTION_MODELS or None blip_max_length: int = 32
caption_offload: bool = False blip_model_url: str = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
blip_num_beams: int = 8
# clip settings # clip settings
clip_model_name: str = 'ViT-L-14/openai' clip_model_name: str = 'ViT-L/14'
clip_model_path: Optional[str] = None
clip_offload: bool = False
# interrogator settings # interrogator settings
cache_path: str = 'cache' # path to store cached text embeddings cache_path: str = 'cache'
download_cache: bool = True # when true, cached embeds are downloaded from huggingface chunk_size: int = 2048
chunk_size: int = 2048 # batch size for CLIP, use smaller for lower VRAM
data_path: str = os.path.join(os.path.dirname(__file__), 'data') data_path: str = os.path.join(os.path.dirname(__file__), 'data')
device: str = ("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu") device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
flavor_intermediate_count: int = 2048 flavor_intermediate_count: int = 2048
quiet: bool = False # when quiet progress bars are not shown
def apply_low_vram_defaults(self):
self.caption_model_name = 'blip-base'
self.caption_offload = True
self.clip_offload = True
self.chunk_size = 1024
self.flavor_intermediate_count = 1024
class Interrogator(): class Interrogator():
def __init__(self, config: Config): def __init__(self, config: Config):
self.config = config self.config = config
self.device = config.device self.device = config.device
self.dtype = torch.float16 if self.device == 'cuda' else torch.float32
self.caption_offloaded = True
self.clip_offloaded = True
self.load_caption_model()
self.load_clip_model()
def load_caption_model(self):
if self.config.caption_model is None and self.config.caption_model_name:
if not self.config.quiet:
print(f"Loading caption model {self.config.caption_model_name}...")
model_path = CAPTION_MODELS[self.config.caption_model_name]
if self.config.caption_model_name.startswith('git-'):
caption_model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float32)
elif self.config.caption_model_name.startswith('blip2-'):
caption_model = Blip2ForConditionalGeneration.from_pretrained(model_path, torch_dtype=self.dtype)
else:
caption_model = BlipForConditionalGeneration.from_pretrained(model_path, torch_dtype=self.dtype)
self.caption_processor = AutoProcessor.from_pretrained(model_path)
caption_model.eval()
if not self.config.caption_offload:
caption_model = caption_model.to(self.config.device)
self.caption_model = caption_model
else:
self.caption_model = self.config.caption_model
self.caption_processor = self.config.caption_processor
def load_clip_model(self): if config.blip_model is None:
start_time = time.time() print("Loading BLIP model...")
config = self.config blip_path = os.path.dirname(inspect.getfile(blip_decoder))
configs_path = os.path.join(os.path.dirname(blip_path), 'configs')
clip_model_name, clip_model_pretrained_name = config.clip_model_name.split('/', 2) med_config = os.path.join(configs_path, 'med_config.json')
blip_model = blip_decoder(
pretrained=config.blip_model_url,
image_size=config.blip_image_eval_size,
vit='large',
med_config=med_config
)
blip_model.eval()
blip_model = blip_model.to(config.device)
self.blip_model = blip_model
else:
self.blip_model = config.blip_model
if config.clip_model is None: if config.clip_model is None:
if not config.quiet: print("Loading CLIP model...")
print(f"Loading CLIP model {config.clip_model_name}...") self.clip_model, self.clip_preprocess = clip.load(config.clip_model_name, device=config.device)
self.clip_model.to(config.device).eval()
self.clip_model, _, self.clip_preprocess = open_clip.create_model_and_transforms(
clip_model_name,
pretrained=clip_model_pretrained_name,
precision='fp16' if config.device == 'cuda' else 'fp32',
device=config.device,
jit=False,
cache_dir=config.clip_model_path
)
self.clip_model.eval()
else: else:
self.clip_model = config.clip_model self.clip_model = config.clip_model
self.clip_preprocess = config.clip_preprocess self.clip_preprocess = config.clip_preprocess
self.tokenize = open_clip.get_tokenizer(clip_model_name)
sites = ['Artstation', 'behance', 'cg society', 'cgsociety', 'deviantart', 'dribbble', sites = ['Artstation', 'behance', 'cg society', 'cgsociety', 'deviantart', 'dribble', 'flickr', 'instagram', 'pexels', 'pinterest', 'pixabay', 'pixiv', 'polycount', 'reddit', 'shutterstock', 'tumblr', 'unsplash', 'zbrush central']
'flickr', 'instagram', 'pexels', 'pinterest', 'pixabay', 'pixiv', 'polycount',
'reddit', 'shutterstock', 'tumblr', 'unsplash', 'zbrush central']
trending_list = [site for site in sites] trending_list = [site for site in sites]
trending_list.extend(["trending on "+site for site in sites]) trending_list.extend(["trending on "+site for site in sites])
trending_list.extend(["featured on "+site for site in sites]) trending_list.extend(["featured on "+site for site in sites])
trending_list.extend([site+" contest winner" for site in sites]) trending_list.extend([site+" contest winner" for site in sites])
raw_artists = load_list(config.data_path, 'artists.txt') raw_artists = _load_list(config.data_path, 'artists.txt')
artists = [f"by {a}" for a in raw_artists] artists = [f"by {a}" for a in raw_artists]
artists.extend([f"inspired by {a}" for a in raw_artists]) artists.extend([f"inspired by {a}" for a in raw_artists])
self._prepare_clip() self.artists = LabelTable(artists, "artists", self.clip_model, config)
self.artists = LabelTable(artists, "artists", self) self.flavors = LabelTable(_load_list(config.data_path, 'flavors.txt'), "flavors", self.clip_model, config)
self.flavors = LabelTable(load_list(config.data_path, 'flavors.txt'), "flavors", self) self.mediums = LabelTable(_load_list(config.data_path, 'mediums.txt'), "mediums", self.clip_model, config)
self.mediums = LabelTable(load_list(config.data_path, 'mediums.txt'), "mediums", self) self.movements = LabelTable(_load_list(config.data_path, 'movements.txt'), "movements", self.clip_model, config)
self.movements = LabelTable(load_list(config.data_path, 'movements.txt'), "movements", self) self.trendings = LabelTable(trending_list, "trendings", self.clip_model, config)
self.trendings = LabelTable(trending_list, "trendings", self)
self.negative = LabelTable(load_list(config.data_path, 'negative.txt'), "negative", self)
end_time = time.time()
if not config.quiet:
print(f"Loaded CLIP model and data in {end_time-start_time:.2f} seconds.")
def chain(
self,
image_features: torch.Tensor,
phrases: List[str],
best_prompt: str="",
best_sim: float=0,
min_count: int=8,
max_count: int=32,
desc="Chaining",
reverse: bool=False
) -> str:
self._prepare_clip()
phrases = set(phrases)
if not best_prompt:
best_prompt = self.rank_top(image_features, [f for f in phrases], reverse=reverse)
best_sim = self.similarity(image_features, best_prompt)
phrases.remove(best_prompt)
curr_prompt, curr_sim = best_prompt, best_sim
def check(addition: str, idx: int) -> bool:
nonlocal best_prompt, best_sim, curr_prompt, curr_sim
prompt = curr_prompt + ", " + addition
sim = self.similarity(image_features, prompt)
if reverse:
sim = -sim
if sim > best_sim:
best_prompt, best_sim = prompt, sim
if sim > curr_sim or idx < min_count:
curr_prompt, curr_sim = prompt, sim
return True
return False
for idx in tqdm(range(max_count), desc=desc, disable=self.config.quiet):
best = self.rank_top(image_features, [f"{curr_prompt}, {f}" for f in phrases], reverse=reverse)
flave = best[len(curr_prompt)+2:]
if not check(flave, idx):
break
if _prompt_at_max_len(curr_prompt, self.tokenize):
break
phrases.remove(flave)
return best_prompt
def generate_caption(self, pil_image: Image) -> str: def generate_caption(self, pil_image: Image) -> str:
assert self.caption_model is not None, "No caption model loaded." size = self.config.blip_image_eval_size
self._prepare_caption() gpu_image = transforms.Compose([
inputs = self.caption_processor(images=pil_image, return_tensors="pt").to(self.device) transforms.Resize((size, size), interpolation=InterpolationMode.BICUBIC),
if not self.config.caption_model_name.startswith('git-'): transforms.ToTensor(),
inputs = inputs.to(self.dtype) transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
tokens = self.caption_model.generate(**inputs, max_new_tokens=self.config.caption_max_length) ])(pil_image).unsqueeze(0).to(self.device)
return self.caption_processor.batch_decode(tokens, skip_special_tokens=True)[0].strip()
with torch.no_grad():
caption = self.blip_model.generate(
gpu_image,
sample=False,
num_beams=self.config.blip_num_beams,
max_length=self.config.blip_max_length,
min_length=5
)
return caption[0]
def image_to_features(self, image: Image) -> torch.Tensor: def image_to_features(self, image: Image) -> torch.Tensor:
self._prepare_clip()
images = self.clip_preprocess(image).unsqueeze(0).to(self.device) images = self.clip_preprocess(image).unsqueeze(0).to(self.device)
with torch.no_grad(), torch.cuda.amp.autocast(): with torch.no_grad():
image_features = self.clip_model.encode_image(images) image_features = self.clip_model.encode_image(images).float()
image_features /= image_features.norm(dim=-1, keepdim=True) image_features /= image_features.norm(dim=-1, keepdim=True)
return image_features return image_features
def interrogate_classic(self, image: Image, max_flavors: int=3, caption: Optional[str]=None) -> str: def interrogate_classic(self, image: Image, max_flaves: int=3) -> str:
"""Classic mode creates a prompt in a standard format first describing the image, caption = self.generate_caption(image)
then listing the artist, trending, movement, and flavor text modifiers."""
caption = caption or self.generate_caption(image)
image_features = self.image_to_features(image) image_features = self.image_to_features(image)
medium = self.mediums.rank(image_features, 1)[0] medium = self.mediums.rank(image_features, 1)[0]
artist = self.artists.rank(image_features, 1)[0] artist = self.artists.rank(image_features, 1)[0]
trending = self.trendings.rank(image_features, 1)[0] trending = self.trendings.rank(image_features, 1)[0]
movement = self.movements.rank(image_features, 1)[0] movement = self.movements.rank(image_features, 1)[0]
flaves = ", ".join(self.flavors.rank(image_features, max_flavors)) flaves = ", ".join(self.flavors.rank(image_features, max_flaves))
if caption.startswith(medium): if caption.startswith(medium):
prompt = f"{caption} {artist}, {trending}, {movement}, {flaves}" prompt = f"{caption} {artist}, {trending}, {movement}, {flaves}"
else: else:
prompt = f"{caption}, {medium} {artist}, {trending}, {movement}, {flaves}" prompt = f"{caption}, {medium} {artist}, {trending}, {movement}, {flaves}"
return _truncate_to_fit(prompt, self.tokenize) return _truncate_to_fit(prompt)
def interrogate_fast(self, image: Image, max_flavors: int=32, caption: Optional[str]=None) -> str: def interrogate_fast(self, image: Image) -> str:
"""Fast mode simply adds the top ranked terms after a caption. It generally results in caption = self.generate_caption(image)
better similarity between generated prompt and image than classic mode, but the prompts
are less readable."""
caption = caption or self.generate_caption(image)
image_features = self.image_to_features(image)
merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self)
tops = merged.rank(image_features, max_flavors)
return _truncate_to_fit(caption + ", " + ", ".join(tops), self.tokenize)
def interrogate_negative(self, image: Image, max_flavors: int = 32) -> str:
"""Negative mode chains together the most dissimilar terms to the image. It can be used
to help build a negative prompt to pair with the regular positive prompt and often
improve the results of generated images particularly with Stable Diffusion 2."""
image_features = self.image_to_features(image) image_features = self.image_to_features(image)
flaves = self.flavors.rank(image_features, self.config.flavor_intermediate_count, reverse=True) merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config)
flaves = flaves + self.negative.labels tops = merged.rank(image_features, 32)
return self.chain(image_features, flaves, max_count=max_flavors, reverse=True, desc="Negative chain") return _truncate_to_fit(caption + ", " + ", ".join(tops))
def interrogate(self, image: Image, min_flavors: int=8, max_flavors: int=32, caption: Optional[str]=None) -> str: def interrogate(self, image: Image) -> str:
caption = caption or self.generate_caption(image) caption = self.generate_caption(image)
image_features = self.image_to_features(image) image_features = self.image_to_features(image)
merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self) flaves = self.flavors.rank(image_features, self.config.flavor_intermediate_count)
flaves = merged.rank(image_features, self.config.flavor_intermediate_count) best_medium = self.mediums.rank(image_features, 1)[0]
best_prompt, best_sim = caption, self.similarity(image_features, caption) best_artist = self.artists.rank(image_features, 1)[0]
best_prompt = self.chain(image_features, flaves, best_prompt, best_sim, min_count=min_flavors, max_count=max_flavors, desc="Flavor chain") best_trending = self.trendings.rank(image_features, 1)[0]
best_movement = self.movements.rank(image_features, 1)[0]
fast_prompt = self.interrogate_fast(image, max_flavors, caption=caption)
classic_prompt = self.interrogate_classic(image, max_flavors, caption=caption) best_prompt = caption
candidates = [caption, classic_prompt, fast_prompt, best_prompt] best_sim = self.similarity(image_features, best_prompt)
return candidates[np.argmax(self.similarities(image_features, candidates))]
def check(addition: str) -> bool:
def rank_top(self, image_features: torch.Tensor, text_array: List[str], reverse: bool=False) -> str: nonlocal best_prompt, best_sim
self._prepare_clip() prompt = best_prompt + ", " + addition
text_tokens = self.tokenize([text for text in text_array]).to(self.device) sim = self.similarity(image_features, prompt)
with torch.no_grad(), torch.cuda.amp.autocast(): if sim > best_sim:
text_features = self.clip_model.encode_text(text_tokens) best_sim = sim
text_features /= text_features.norm(dim=-1, keepdim=True) best_prompt = prompt
similarity = text_features @ image_features.T return True
if reverse: return False
similarity = -similarity
return text_array[similarity.argmax().item()] def check_multi_batch(opts: List[str]):
nonlocal best_prompt, best_sim
def similarity(self, image_features: torch.Tensor, text: str) -> float: prompts = []
self._prepare_clip() for i in range(2**len(opts)):
text_tokens = self.tokenize([text]).to(self.device) prompt = best_prompt
with torch.no_grad(), torch.cuda.amp.autocast(): for bit in range(len(opts)):
text_features = self.clip_model.encode_text(text_tokens) if i & (1 << bit):
text_features /= text_features.norm(dim=-1, keepdim=True) prompt += ", " + opts[bit]
similarity = text_features @ image_features.T prompts.append(prompt)
return similarity[0][0].item()
t = LabelTable(prompts, None, self.clip_model, self.config)
def similarities(self, image_features: torch.Tensor, text_array: List[str]) -> List[float]: best_prompt = t.rank(image_features, 1)[0]
self._prepare_clip() best_sim = self.similarity(image_features, best_prompt)
text_tokens = self.tokenize([text for text in text_array]).to(self.device)
with torch.no_grad(), torch.cuda.amp.autocast(): check_multi_batch([best_medium, best_artist, best_trending, best_movement])
text_features = self.clip_model.encode_text(text_tokens)
text_features /= text_features.norm(dim=-1, keepdim=True) extended_flavors = set(flaves)
similarity = text_features @ image_features.T for _ in tqdm(range(25), desc="Flavor chain"):
return similarity.T[0].tolist() try:
best = self.rank_top(image_features, [f"{best_prompt}, {f}" for f in extended_flavors])
def _prepare_caption(self): flave = best[len(best_prompt)+2:]
if self.config.clip_offload and not self.clip_offloaded: if not check(flave):
self.clip_model = self.clip_model.to('cpu') break
self.clip_offloaded = True extended_flavors.remove(flave)
if self.caption_offloaded: except:
self.caption_model = self.caption_model.to(self.device) # exceeded max prompt length
self.caption_offloaded = False break
def _prepare_clip(self): return best_prompt
if self.config.caption_offload and not self.caption_offloaded:
self.caption_model = self.caption_model.to('cpu') def rank_top(self, image_features, text_array: List[str]) -> str:
self.caption_offloaded = True text_tokens = clip.tokenize([text for text in text_array]).to(self.device)
if self.clip_offloaded: with torch.no_grad():
self.clip_model = self.clip_model.to(self.device) text_features = self.clip_model.encode_text(text_tokens).float()
self.clip_offloaded = False text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = torch.zeros((1, len(text_array)), device=self.device)
for i in range(image_features.shape[0]):
similarity += (image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
_, top_labels = similarity.cpu().topk(1, dim=-1)
return text_array[top_labels[0][0].numpy()]
def similarity(self, image_features, text) -> np.float32:
text_tokens = clip.tokenize([text]).to(self.device)
with torch.no_grad():
text_features = self.clip_model.encode_text(text_tokens).float()
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T
return similarity[0][0]
class LabelTable(): class LabelTable():
def __init__(self, labels:List[str], desc:str, ci: Interrogator): def __init__(self, labels:List[str], desc:str, clip_model, config: Config):
clip_model, config = ci.clip_model, ci.config
self.chunk_size = config.chunk_size self.chunk_size = config.chunk_size
self.config = config
self.device = config.device self.device = config.device
self.embeds = []
self.labels = labels self.labels = labels
self.tokenize = ci.tokenize self.embeds = []
hash = hashlib.sha256(",".join(labels).encode()).hexdigest() hash = hashlib.sha256(",".join(labels).encode()).hexdigest()
sanitized_name = self.config.clip_model_name.replace('/', '_').replace('@', '_')
self._load_cached(desc, hash, sanitized_name) cache_filepath = None
if config.cache_path is not None and desc is not None:
os.makedirs(config.cache_path, exist_ok=True)
sanitized_name = config.clip_model_name.replace('/', '_').replace('@', '_')
cache_filepath = os.path.join(config.cache_path, f"{sanitized_name}_{desc}.pkl")
if desc is not None and os.path.exists(cache_filepath):
with open(cache_filepath, 'rb') as f:
data = pickle.load(f)
if data.get('hash') == hash:
self.labels = data['labels']
self.embeds = data['embeds']
if len(self.labels) != len(self.embeds): if len(self.labels) != len(self.embeds):
self.embeds = [] self.embeds = []
chunks = np.array_split(self.labels, max(1, len(self.labels)/config.chunk_size)) chunks = np.array_split(self.labels, max(1, len(self.labels)/config.chunk_size))
for chunk in tqdm(chunks, desc=f"Preprocessing {desc}" if desc else None, disable=self.config.quiet): for chunk in tqdm(chunks, desc=f"Preprocessing {desc}" if desc else None):
text_tokens = self.tokenize(chunk).to(self.device) text_tokens = clip.tokenize(chunk).to(self.device)
with torch.no_grad(), torch.cuda.amp.autocast(): with torch.no_grad():
text_features = clip_model.encode_text(text_tokens) text_features = clip_model.encode_text(text_tokens).float()
text_features /= text_features.norm(dim=-1, keepdim=True) text_features /= text_features.norm(dim=-1, keepdim=True)
text_features = text_features.half().cpu().numpy() text_features = text_features.half().cpu().numpy()
for i in range(text_features.shape[0]): for i in range(text_features.shape[0]):
self.embeds.append(text_features[i]) self.embeds.append(text_features[i])
if desc and self.config.cache_path: if cache_filepath is not None:
os.makedirs(self.config.cache_path, exist_ok=True) with open(cache_filepath, 'wb') as f:
cache_filepath = os.path.join(self.config.cache_path, f"{sanitized_name}_{desc}.safetensors") pickle.dump({
tensors = { "labels": self.labels,
"embeds": np.stack(self.embeds), "embeds": self.embeds,
"hash": np.array([ord(c) for c in hash], dtype=np.int8) "hash": hash,
} "model": config.clip_model_name
save_file(tensors, cache_filepath) }, f)
if self.device == 'cpu' or self.device == torch.device('cpu'):
self.embeds = [e.astype(np.float32) for e in self.embeds]
def _load_cached(self, desc:str, hash:str, sanitized_name:str) -> bool:
if self.config.cache_path is None or desc is None:
return False
cached_safetensors = os.path.join(self.config.cache_path, f"{sanitized_name}_{desc}.safetensors")
if self.config.download_cache and not os.path.exists(cached_safetensors):
download_url = CACHE_URL_BASE + f"{sanitized_name}_{desc}.safetensors"
try:
os.makedirs(self.config.cache_path, exist_ok=True)
_download_file(download_url, cached_safetensors, quiet=self.config.quiet)
except Exception as e:
print(f"Failed to download {download_url}")
print(e)
return False
if os.path.exists(cached_safetensors):
try:
tensors = load_file(cached_safetensors)
except Exception as e:
print(f"Failed to load {cached_safetensors}")
print(e)
return False
if 'hash' in tensors and 'embeds' in tensors:
if np.array_equal(tensors['hash'], np.array([ord(c) for c in hash], dtype=np.int8)):
self.embeds = tensors['embeds']
if len(self.embeds.shape) == 2:
self.embeds = [self.embeds[i] for i in range(self.embeds.shape[0])]
return True
return False
def _rank(self, image_features: torch.Tensor, text_embeds: torch.Tensor, top_count: int=1, reverse: bool=False) -> str: def _rank(self, image_features, text_embeds, top_count=1):
top_count = min(top_count, len(text_embeds)) top_count = min(top_count, len(text_embeds))
text_embeds = torch.stack([torch.from_numpy(t) for t in text_embeds]).to(self.device) similarity = torch.zeros((1, len(text_embeds))).to(self.device)
with torch.cuda.amp.autocast(): text_embeds = torch.stack([torch.from_numpy(t) for t in text_embeds]).float().to(self.device)
similarity = image_features @ text_embeds.T for i in range(image_features.shape[0]):
if reverse: similarity += (image_features[i].unsqueeze(0) @ text_embeds.T).softmax(dim=-1)
similarity = -similarity _, top_labels = similarity.cpu().topk(top_count, dim=-1)
_, top_labels = similarity.float().cpu().topk(top_count, dim=-1)
return [top_labels[0][i].numpy() for i in range(top_count)] return [top_labels[0][i].numpy() for i in range(top_count)]
def rank(self, image_features: torch.Tensor, top_count: int=1, reverse: bool=False) -> List[str]: def rank(self, image_features, top_count=1) -> List[str]:
if len(self.labels) <= self.chunk_size: if len(self.labels) <= self.chunk_size:
tops = self._rank(image_features, self.embeds, top_count=top_count, reverse=reverse) tops = self._rank(image_features, self.embeds, top_count=top_count)
return [self.labels[i] for i in tops] return [self.labels[i] for i in tops]
num_chunks = int(math.ceil(len(self.labels)/self.chunk_size)) num_chunks = int(math.ceil(len(self.labels)/self.chunk_size))
keep_per_chunk = int(self.chunk_size / num_chunks) keep_per_chunk = int(self.chunk_size / num_chunks)
top_labels, top_embeds = [], [] top_labels, top_embeds = [], []
for chunk_idx in tqdm(range(num_chunks), disable=self.config.quiet): for chunk_idx in tqdm(range(num_chunks)):
start = chunk_idx*self.chunk_size start = chunk_idx*self.chunk_size
stop = min(start+self.chunk_size, len(self.embeds)) stop = min(start+self.chunk_size, len(self.embeds))
tops = self._rank(image_features, self.embeds[start:stop], top_count=keep_per_chunk, reverse=reverse) tops = self._rank(image_features, self.embeds[start:stop], top_count=keep_per_chunk)
top_labels.extend([self.labels[start+i] for i in tops]) top_labels.extend([self.labels[start+i] for i in tops])
top_embeds.extend([self.embeds[start+i] for i in tops]) top_embeds.extend([self.embeds[start+i] for i in tops])
@@ -400,51 +281,23 @@ class LabelTable():
return [top_labels[i] for i in tops] return [top_labels[i] for i in tops]
def _download_file(url: str, filepath: str, chunk_size: int = 4*1024*1024, quiet: bool = False): def _load_list(data_path, filename) -> List[str]:
r = requests.get(url, stream=True) with open(os.path.join(data_path, filename), 'r', encoding='utf-8', errors='replace') as f:
if r.status_code != 200: items = [line.strip() for line in f.readlines()]
return return items
file_size = int(r.headers.get("Content-Length", 0))
filename = url.split("/")[-1]
progress = tqdm(total=file_size, unit="B", unit_scale=True, desc=filename, disable=quiet)
with open(filepath, "wb") as f:
for chunk in r.iter_content(chunk_size=chunk_size):
if chunk:
f.write(chunk)
progress.update(len(chunk))
progress.close()
def _merge_tables(tables: List[LabelTable], ci: Interrogator) -> LabelTable: def _merge_tables(tables: List[LabelTable], config: Config) -> LabelTable:
m = LabelTable([], None, ci) m = LabelTable([], None, None, config)
for table in tables: for table in tables:
m.labels.extend(table.labels) m.labels.extend(table.labels)
m.embeds.extend(table.embeds) m.embeds.extend(table.embeds)
return m return m
def _prompt_at_max_len(text: str, tokenize) -> bool: def _truncate_to_fit(text: str) -> str:
tokens = tokenize([text]) while True:
return tokens[0][-1] != 0 try:
_ = clip.tokenize([text])
def _truncate_to_fit(text: str, tokenize) -> str: return text
parts = text.split(', ') except:
new_text = parts[0] text = ",".join(text.split(",")[:-1])
for part in parts[1:]:
if _prompt_at_max_len(new_text + part, tokenize):
break
new_text += ', ' + part
return new_text
def list_caption_models() -> List[str]:
return list(CAPTION_MODELS.keys())
def list_clip_models() -> List[str]:
return ['/'.join(x) for x in open_clip.list_pretrained()]
def load_list(data_path: str, filename: Optional[str] = None) -> List[str]:
"""Load a list of strings from a file."""
if filename is not None:
data_path = os.path.join(data_path, filename)
with open(data_path, 'r', encoding='utf-8', errors='replace') as f:
items = [line.strip() for line in f.readlines()]
return items
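Beyond the Interrogator itself, the main-branch module shown above exports helpers for discovering available models and for ranking against a custom label list; a minimal sketch based on that code (the terms file and image path are placeholders):

```python
from PIL import Image
from clip_interrogator import (
    Config, Interrogator, LabelTable,
    list_caption_models, list_clip_models, load_list,
)

print(list_caption_models())    # keys of CAPTION_MODELS, e.g. 'blip-base', 'blip-large', ...
print(list_clip_models()[:5])   # OpenCLIP "architecture/pretrained" pairs

# caption_model_name=None skips loading a caption model when only CLIP ranking is needed
ci = Interrogator(Config(caption_model_name=None))
table = LabelTable(load_list("terms.txt"), "terms", ci)                    # placeholder terms file
features = ci.image_to_features(Image.open("photo.jpg").convert("RGB"))   # placeholder image
print(table.rank(features, top_count=3))
```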

clip_interrogator/data/negative.txt (41 changes)

@@ -1,41 +0,0 @@
3d
b&w
bad anatomy
bad art
blur
blurry
cartoon
childish
close up
deformed
disconnected limbs
disfigured
disgusting
extra limb
extra limbs
floating limbs
grain
illustration
kitsch
long body
long neck
low quality
low-res
malformed hands
mangled
missing limb
mutated
mutation
mutilated
noisy
old
out of focus
over saturation
oversaturated
poorly drawn
poorly drawn face
poorly drawn hands
render
surreal
ugly
weird colors

cog.yaml (20 changes)

@@ -1,16 +1,20 @@
 build:
   gpu: true
-  cuda: "11.8"
-  python_version: "3.10"
+  cuda: "11.3"
+  python_version: "3.8"
   system_packages:
     - "libgl1-mesa-glx"
     - "libglib2.0-0"
   python_packages:
-    - "Pillow==10.0.0"
-    - "safetensors==0.3.3"
-    - "tqdm==4.66.1"
-    - "open_clip_torch==2.20.0"
-    - "accelerate==0.22.0"
-    - "transformers==4.33.1"
+    - "ipython==8.4.0"
+    - "fairscale==0.4.12"
+    - "transformers==4.21.2"
+    - "ftfy==6.1.1"
+    - "torch==1.11.0 --extra-index-url=https://download.pytorch.org/whl/cu113"
+    - "torchvision==0.12.0 --extra-index-url=https://download.pytorch.org/whl/cu113"
+  run:
+    - pip install -e git+https://github.com/pharmapsychotic/BLIP.git@lib#egg=blip
+    - pip install -e git+https://github.com/openai/CLIP.git@main#egg=clip
+    - mkdir -p /root/.cache/clip && wget --output-document "/root/.cache/clip/ViT-L-14.pt" "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"
 predict: "predict.py:Predictor"

predict.py (41 changes)

@@ -2,44 +2,17 @@ import sys
 from PIL import Image
 from cog import BasePredictor, Input, Path
-from clip_interrogator import Config, Interrogator
+sys.path.extend(["src/clip", "src/blip"])
+from clip_interrogator import Interrogator, Config

 class Predictor(BasePredictor):
     def setup(self):
-        self.ci = Interrogator(Config(
-            clip_model_name="ViT-L-14/openai",
-            clip_model_path='cache',
-            device='cuda:0',
-        ))
+        config = Config(device="cuda:0", clip_model_name='ViT-L/14')
+        self.ci = Interrogator(config)

-    def predict(
-        self,
-        image: Path = Input(description="Input image"),
-        clip_model_name: str = Input(
-            default="ViT-L-14/openai",
-            choices=["ViT-L-14/openai", "ViT-H-14/laion2b_s32b_b79k", "ViT-bigG-14/laion2b_s39b_b160k"],
-            description="Choose ViT-L for Stable Diffusion 1, ViT-H for Stable Diffusion 2, or ViT-bigG for Stable Diffusion XL.",
-        ),
-        mode: str = Input(
-            default="best",
-            choices=["best", "classic", "fast", "negative"],
-            description="Prompt mode (best takes 10-20 seconds, fast takes 1-2 seconds).",
-        ),
-    ) -> str:
+    def predict(self, image: Path = Input(description="Input image")) -> str:
         """Run a single prediction on the model"""
         image = Image.open(str(image)).convert("RGB")
-        self.switch_model(clip_model_name)
-        if mode == 'best':
-            return self.ci.interrogate(image)
-        elif mode == 'classic':
-            return self.ci.interrogate_classic(image)
-        elif mode == 'fast':
-            return self.ci.interrogate_fast(image)
-        elif mode == 'negative':
-            return self.ci.interrogate_negative(image)
-
-    def switch_model(self, clip_model_name: str):
-        if clip_model_name != self.ci.config.clip_model_name:
-            self.ci.config.clip_model_name = clip_model_name
-            self.ci.load_clip_model()
+        return self.ci.interrogate(image)
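predict.py is the Cog entry point that Replicate runs; once a version is published, the model can be invoked from the `replicate` Python client. A hedged sketch, assuming the pharmapsychotic/clip-interrogator model name, a pinned version hash (placeholder below), and a REPLICATE_API_TOKEN in the environment:

```python
import replicate

# "<version-hash>" is a placeholder; copy the current version id from the model page
output = replicate.run(
    "pharmapsychotic/clip-interrogator:<version-hash>",
    input={
        "image": open("photo.jpg", "rb"),      # placeholder local image
        "mode": "best",                        # best / classic / fast / negative, per predict.py above
        "clip_model_name": "ViT-L-14/openai",  # inputs mirror the main-branch predict() signature
    },
)
print(output)  # the generated prompt
```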

requirements.txt (8 changes)

@@ -1,9 +1,5 @@
-torch>=1.13.0
+torch
 torchvision
 Pillow
 requests
-safetensors
 tqdm
-open_clip_torch
-accelerate
-transformers>=4.27.1

run_cli.py (23 changes)

@@ -1,11 +1,12 @@
 #!/usr/bin/env python3
 import argparse
+import clip
 import csv
 import os
 import requests
 import torch
 from PIL import Image
-from clip_interrogator import Interrogator, Config, list_clip_models
+from clip_interrogator import Interrogator, Config

@@ -18,12 +19,10 @@ def inference(ci, image, mode):
 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument('-c', '--clip', default='ViT-L-14/openai', help='name of CLIP model to use')
-    parser.add_argument('-d', '--device', default='auto', help='device to use (auto, cuda or cpu)')
+    parser.add_argument('-c', '--clip', default='ViT-L/14', help='name of CLIP model to use')
     parser.add_argument('-f', '--folder', help='path to folder of images')
     parser.add_argument('-i', '--image', help='image file or url')
     parser.add_argument('-m', '--mode', default='best', help='best, classic, or fast')
-    parser.add_argument("--lowvram", action='store_true', help="Optimize settings for low VRAM")
     args = parser.parse_args()

     if not args.folder and not args.image:

@@ -35,24 +34,14 @@ def main():
         exit(1)

     # validate clip model name
-    models = list_clip_models()
-    if args.clip not in models:
+    if args.clip not in clip.available_models():
         print(f"Could not find CLIP model {args.clip}!")
-        print(f" available models: {models}")
+        print(f" available models: {clip.available_models()}")
         exit(1)

-    # select device
-    if args.device == 'auto':
-        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-        if not torch.cuda.is_available():
-            print("CUDA is not available, using CPU. Warning: this will be very slow!")
-    else:
-        device = torch.device(args.device)
-
     # generate a nice prompt
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     config = Config(device=device, clip_model_name=args.clip)
-    if args.lowvram:
-        config.apply_low_vram_defaults()
     ci = Interrogator(config)

     # process single image

run_gradio.py (118 changes)

@@ -1,99 +1,41 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import argparse import clip
import torch import gradio as gr
from clip_interrogator import Config, Interrogator, list_caption_models, list_clip_models from clip_interrogator import Interrogator, Config
try: ci = Interrogator(Config())
import gradio as gr
except ImportError:
print("Gradio is not installed, please install it with 'pip install gradio'")
exit(1)
parser = argparse.ArgumentParser()
parser.add_argument("--lowvram", action='store_true', help="Optimize settings for low VRAM")
parser.add_argument('-s', '--share', action='store_true', help='Create a public link')
args = parser.parse_args()
if not torch.cuda.is_available():
print("CUDA is not available, using CPU. Warning: this will be very slow!")
config = Config(cache_path="cache")
if args.lowvram:
config.apply_low_vram_defaults()
ci = Interrogator(config)
def image_analysis(image, clip_model_name):
if clip_model_name != ci.config.clip_model_name:
ci.config.clip_model_name = clip_model_name
ci.load_clip_model()
image = image.convert('RGB')
image_features = ci.image_to_features(image)
top_mediums = ci.mediums.rank(image_features, 5)
top_artists = ci.artists.rank(image_features, 5)
top_movements = ci.movements.rank(image_features, 5)
top_trendings = ci.trendings.rank(image_features, 5)
top_flavors = ci.flavors.rank(image_features, 5)
medium_ranks = {medium: sim for medium, sim in zip(top_mediums, ci.similarities(image_features, top_mediums))}
artist_ranks = {artist: sim for artist, sim in zip(top_artists, ci.similarities(image_features, top_artists))}
movement_ranks = {movement: sim for movement, sim in zip(top_movements, ci.similarities(image_features, top_movements))}
trending_ranks = {trending: sim for trending, sim in zip(top_trendings, ci.similarities(image_features, top_trendings))}
flavor_ranks = {flavor: sim for flavor, sim in zip(top_flavors, ci.similarities(image_features, top_flavors))}
return medium_ranks, artist_ranks, movement_ranks, trending_ranks, flavor_ranks
def image_to_prompt(image, mode, clip_model_name, blip_model_name):
if blip_model_name != ci.config.caption_model_name:
ci.config.caption_model_name = blip_model_name
ci.load_caption_model()
def inference(image, mode, clip_model_name, blip_max_length, blip_num_beams):
global ci
if clip_model_name != ci.config.clip_model_name: if clip_model_name != ci.config.clip_model_name:
ci.config.clip_model_name = clip_model_name ci = Interrogator(Config(clip_model_name=clip_model_name))
ci.load_clip_model() ci.config.blip_max_length = int(blip_max_length)
ci.config.blip_num_beams = int(blip_num_beams)
image = image.convert('RGB') image = image.convert('RGB')
if mode == 'best': if mode == 'best':
return ci.interrogate(image) return ci.interrogate(image)
elif mode == 'classic': elif mode == 'classic':
return ci.interrogate_classic(image) return ci.interrogate_classic(image)
elif mode == 'fast': else:
return ci.interrogate_fast(image) return ci.interrogate_fast(image)
elif mode == 'negative':
return ci.interrogate_negative(image) inputs = [
gr.inputs.Image(type='pil'),
def prompt_tab(): gr.Radio(['best', 'classic', 'fast'], label='Mode', value='best'),
with gr.Column(): gr.Dropdown(clip.available_models(), value='ViT-L/14', label='CLIP Model'),
with gr.Row(): gr.Number(value=32, label='Caption Max Length'),
image = gr.Image(type='pil', label="Image") gr.Number(value=64, label='Caption Num Beams'),
with gr.Column(): ]
mode = gr.Radio(['best', 'fast', 'classic', 'negative'], label='Mode', value='best') outputs = [
clip_model = gr.Dropdown(list_clip_models(), value=ci.config.clip_model_name, label='CLIP Model') gr.outputs.Textbox(label="Output"),
blip_model = gr.Dropdown(list_caption_models(), value=ci.config.caption_model_name, label='Caption Model') ]
prompt = gr.Textbox(label="Prompt")
button = gr.Button("Generate prompt") io = gr.Interface(
button.click(image_to_prompt, inputs=[image, mode, clip_model, blip_model], outputs=prompt) inference,
inputs,
def analyze_tab(): outputs,
with gr.Column(): title="🕵 CLIP Interrogator 🕵",
with gr.Row(): allow_flagging=False,
image = gr.Image(type='pil', label="Image") )
model = gr.Dropdown(list_clip_models(), value='ViT-L-14/openai', label='CLIP Model') io.launch()
with gr.Row():
medium = gr.Label(label="Medium", num_top_classes=5)
artist = gr.Label(label="Artist", num_top_classes=5)
movement = gr.Label(label="Movement", num_top_classes=5)
trending = gr.Label(label="Trending", num_top_classes=5)
flavor = gr.Label(label="Flavor", num_top_classes=5)
button = gr.Button("Analyze")
button.click(image_analysis, inputs=[image, model], outputs=[medium, artist, movement, trending, flavor])
with gr.Blocks() as ui:
gr.Markdown("# <center>🕵 CLIP Interrogator 🕵</center>")
with gr.Tab("Prompt"):
prompt_tab()
with gr.Tab("Analyze"):
analyze_tab()
ui.launch(show_api=False, debug=True, share=args.share)

setup.py (4 changes)

@@ -5,13 +5,13 @@ from setuptools import setup, find_packages
 setup(
     name="clip-interrogator",
-    version="0.6.0",
+    version="0.1.4",
     license='MIT',
     author='pharmapsychotic',
     author_email='me@pharmapsychotic.com',
     url='https://github.com/pharmapsychotic/clip-interrogator',
     description="Generate a prompt from an image",
-    long_description=open('README.md', encoding='utf-8').read(),
+    long_description=open('README.md').read(),
     long_description_content_type="text/markdown",
     packages=find_packages(),
     install_requires=[
