
Support for many different caption models:

blip-base, blip-large, blip2-2.7b, blip2-flan-t5-xl, git-large-coco
pharmapsychotic committed 2 years ago (commit ce9d271aa1, refs: pull/69/head, v0.6.0)
7 changed files:

  1. clip_interrogator.ipynb (53 lines changed)
  2. clip_interrogator/__init__.py (4 lines changed)
  3. clip_interrogator/clip_interrogator.py (130 lines changed)
  4. requirements.txt (4 lines changed)
  5. run_cli.py (5 lines changed)
  6. run_gradio.py (19 lines changed)
  7. setup.py (2 lines changed)
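For context, here is a minimal sketch of how the new caption-model selection is meant to be used from the library after this change. It assumes clip-interrogator 0.6.0 is installed and uses a placeholder image path; the exact output depends on the chosen models.

```python
# Sketch only: pick one of the new caption models via Config (clip-interrogator 0.6.0).
# 'example.jpg' is a placeholder path, not part of this commit.
from PIL import Image
from clip_interrogator import Config, Interrogator, list_caption_models

print(list_caption_models())  # ['blip-base', 'blip-large', 'blip2-2.7b', 'blip2-flan-t5-xl', 'git-large-coco']

config = Config(clip_model_name='ViT-L-14/openai')
config.caption_model_name = 'blip2-2.7b'   # any key from CAPTION_MODELS
ci = Interrogator(config)
print(ci.interrogate(Image.open('example.jpg').convert('RGB')))
```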

clip_interrogator.ipynb (53 lines changed)

@@ -7,7 +7,7 @@
     "id": "3jm8RYrLqvzz"
    },
    "source": [
-    "# CLIP Interrogator 2.3 by [@pharmapsychotic](https://twitter.com/pharmapsychotic) \n",
+    "# CLIP Interrogator 2.4 by [@pharmapsychotic](https://twitter.com/pharmapsychotic) \n",
     "\n",
     "Want to figure out what a good prompt might be to create new images like an existing one? The CLIP Interrogator is here to get you answers!\n",
     "\n",
@@ -29,7 +29,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 1,
+    "execution_count": null,
     "metadata": {
      "cellView": "form",
      "id": "aP9FjmWxtLKJ"
@@ -42,7 +42,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 2,
+    "execution_count": null,
     "metadata": {
      "cellView": "form",
      "id": "xpPKQR40qvz2"
@@ -54,8 +54,7 @@
     "\n",
     "def setup():\n",
     "    install_cmds = [\n",
-    "        ['pip', 'install', 'transformers==4.15.0'],\n",
     "        ['pip', 'install', 'gradio'],\n",
     "        ['pip', 'install', 'open_clip_torch'],\n",
     "        ['pip', 'install', 'clip-interrogator'],\n",
     "    ]\n",
@@ -65,16 +64,15 @@
     "setup()\n",
     "\n",
     "\n",
+    "caption_model_name = 'blip-large' #@param [\"blip-base\", \"blip-large\", \"git-large-coco\"]\n",
     "clip_model_name = 'ViT-L-14/openai' #@param [\"ViT-L-14/openai\", \"ViT-H-14/laion2b_s32b_b79k\"]\n",
     "\n",
-    "\n",
     "import gradio as gr\n",
     "from clip_interrogator import Config, Interrogator\n",
     "\n",
     "config = Config()\n",
-    "config.blip_num_beams = 64\n",
-    "config.blip_offload = False\n",
     "config.clip_model_name = clip_model_name\n",
+    "config.caption_model_name = caption_model_name\n",
     "ci = Interrogator(config)\n",
     "\n",
     "def image_analysis(image):\n",
@@ -112,7 +110,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 4,
+    "execution_count": null,
     "metadata": {
      "cellView": "form",
      "colab": {
@@ -122,40 +120,7 @@
      "id": "Pf6qkFG6MPRj",
      "outputId": "8d542b56-8be7-453d-bf27-d0540a774c7d"
     },
-    "outputs": [
-     {
-      "name": "stdout",
-      "output_type": "stream",
-      "text": [
-       "Colab notebook detected. To show errors in colab notebook, set `debug=True` in `launch()`\n",
-       "\n",
-       "Using Embedded Colab Mode (NEW). If you have issues, please use share=True and file an issue at https://github.com/gradio-app/gradio/\n",
-       "Note: opening the browser inspector may crash Embedded Colab Mode.\n",
-       "\n",
-       "To create a public link, set `share=True` in `launch()`.\n"
-      ]
-     },
-     {
-      "data": {
-       "application/javascript": "(async (port, path, width, height, cache, element) => {\n if (!google.colab.kernel.accessAllowed && !cache) {\n return;\n }\n element.appendChild(document.createTextNode(''));\n const url = await google.colab.kernel.proxyPort(port, {cache});\n\n const external_link = document.createElement('div');\n external_link.innerHTML = `\n <div style=\"font-family: monospace; margin-bottom: 0.5rem\">\n Running on <a href=${new URL(path, url).toString()} target=\"_blank\">\n https://localhost:${port}${path}\n </a>\n </div>\n `;\n element.appendChild(external_link);\n\n const iframe = document.createElement('iframe');\n iframe.src = new URL(path, url).toString();\n iframe.height = height;\n iframe.allow = \"autoplay; camera; microphone; clipboard-read; clipboard-write;\"\n iframe.width = width;\n iframe.style.border = 0;\n element.appendChild(iframe);\n })(7860, \"/\", \"100%\", 500, false, window.element)",
-       "text/plain": [
-        "<IPython.core.display.Javascript object>"
-       ]
-      },
-      "metadata": {},
-      "output_type": "display_data"
-     },
-     {
-      "data": {
-       "text/plain": [
-        "(<gradio.routes.App at 0x7f894e553710>, 'http://127.0.0.1:7860/', None)"
-       ]
-      },
-      "execution_count": 4,
-      "metadata": {},
-      "output_type": "execute_result"
-     }
-    ],
+    "outputs": [],
     "source": [
     "#@title Image to prompt! 🖼 -> 📝\n",
     " \n",
@@ -291,7 +256,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.7.15 (default, Nov 24 2022, 18:44:54) [MSC v.1916 64 bit (AMD64)]"
+   "version": "3.9.5"
   },
   "orig_nbformat": 4,
   "vscode": {

clip_interrogator/__init__.py (4 lines changed)

@@ -1,4 +1,4 @@
-from .clip_interrogator import Config, Interrogator, LabelTable, load_list
+from .clip_interrogator import Config, Interrogator, LabelTable, list_caption_models, list_clip_models, load_list
 
-__version__ = '0.5.5'
+__version__ = '0.6.0'
 __author__ = 'pharmapsychotic'

clip_interrogator/clip_interrogator.py (130 lines changed)

@@ -1,5 +1,4 @@
 import hashlib
-import inspect
 import math
 import numpy as np
 import open_clip
@@ -9,18 +8,19 @@ import time
 import torch
 
 from dataclasses import dataclass
-from blip.models.blip import blip_decoder, BLIP_Decoder
 from PIL import Image
-from torchvision import transforms
-from torchvision.transforms.functional import InterpolationMode
+from transformers import AutoProcessor, AutoModelForCausalLM, BlipForConditionalGeneration, Blip2ForConditionalGeneration
 from tqdm import tqdm
 from typing import List, Optional
 
 from safetensors.numpy import load_file, save_file
 
-BLIP_MODELS = {
-    'base': 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth',
-    'large': 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
+CAPTION_MODELS = {
+    'blip-base': 'Salesforce/blip-image-captioning-base',   # 990MB
+    'blip-large': 'Salesforce/blip-image-captioning-large', # 1.9GB
+    'blip2-2.7b': 'Salesforce/blip2-opt-2.7b',              # 15.5GB
+    'blip2-flan-t5-xl': 'Salesforce/blip2-flan-t5-xl',      # 15.77GB
+    'git-large-coco': 'microsoft/git-large-coco',           # 1.58GB
 }
 
 CACHE_URL_BASE = 'https://huggingface.co/pharma/ci-preprocess/resolve/main/'
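Given the checkpoint sizes noted above (the BLIP-2 variants are around 15 GB), it can help to pre-fetch weights into the local Hugging Face cache before constructing an Interrogator. A sketch, assuming huggingface_hub is available (it is pulled in as a dependency of transformers); this is not part of the commit itself:

```python
# Sketch: warm the local Hugging Face cache for some CAPTION_MODELS entries
# so the first Interrogator() call does not block on a large download.
from huggingface_hub import snapshot_download

snapshot_download(repo_id='Salesforce/blip2-opt-2.7b')   # ~15.5GB per the table above
snapshot_download(repo_id='microsoft/git-large-coco')    # ~1.58GB
```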
@@ -29,16 +29,15 @@ CACHE_URL_BASE = 'https://huggingface.co/pharma/ci-preprocess/resolve/main/'
 @dataclass
 class Config:
     # models can optionally be passed in directly
-    blip_model: Optional[BLIP_Decoder] = None
+    caption_model = None
+    caption_processor = None
     clip_model = None
     clip_preprocess = None
 
     # blip settings
-    blip_image_eval_size: int = 384
-    blip_max_length: int = 32
-    blip_model_type: Optional[str] = 'large' # use 'base', 'large' or None
-    blip_num_beams: int = 8
-    blip_offload: bool = False
+    caption_max_length: int = 32
+    caption_model_name: Optional[str] = 'blip-large' # use a key from CAPTION_MODELS or None
+    caption_offload: bool = False
 
     # clip settings
     clip_model_name: str = 'ViT-L-14/openai'
@@ -55,8 +54,8 @@ class Config:
     quiet: bool = False # when quiet progress bars are not shown
 
     def apply_low_vram_defaults(self):
-        self.blip_model_type = 'base'
-        self.blip_offload = True
+        self.caption_model_name = 'blip-base'
+        self.caption_offload = True
         self.clip_offload = True
         self.chunk_size = 1024
         self.flavor_intermediate_count = 1024
@@ -65,29 +64,33 @@ class Interrogator():
     def __init__(self, config: Config):
         self.config = config
         self.device = config.device
-        self.blip_offloaded = True
+        self.dtype = torch.float16 if self.device == 'cuda' else torch.float32
+        self.caption_offloaded = True
         self.clip_offloaded = True
+        self.load_caption_model()
+        self.load_clip_model()
 
-        if config.blip_model is None and config.blip_model_type:
-            if not config.quiet:
-                print("Loading BLIP model...")
-            blip_path = os.path.dirname(inspect.getfile(blip_decoder))
-            configs_path = os.path.join(os.path.dirname(blip_path), 'configs')
-            med_config = os.path.join(configs_path, 'med_config.json')
-            blip_model = blip_decoder(
-                pretrained=BLIP_MODELS[config.blip_model_type],
-                image_size=config.blip_image_eval_size,
-                vit=config.blip_model_type,
-                med_config=med_config
-            )
-            blip_model.eval()
-            if not self.config.blip_offload:
-                blip_model = blip_model.to(config.device)
-            self.blip_model = blip_model
+    def load_caption_model(self):
+        if self.config.caption_model is None and self.config.caption_model_name:
+            if not self.config.quiet:
+                print(f"Loading caption model {self.config.caption_model_name}...")
+
+            model_path = CAPTION_MODELS[self.config.caption_model_name]
+            if self.config.caption_model_name.startswith('git-'):
+                caption_model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float32)
+            elif self.config.caption_model_name.startswith('blip2-'):
+                caption_model = Blip2ForConditionalGeneration.from_pretrained(model_path, torch_dtype=self.dtype)
+            else:
+                caption_model = BlipForConditionalGeneration.from_pretrained(model_path, torch_dtype=self.dtype)
+            self.caption_processor = AutoProcessor.from_pretrained(model_path)
+
+            caption_model.eval()
+            if not self.config.caption_offload:
+                caption_model = caption_model.to(self.config.device)
+            self.caption_model = caption_model
         else:
-            self.blip_model = config.blip_model
-
-        self.load_clip_model()
+            self.caption_model = self.config.caption_model
+            self.caption_processor = self.config.caption_processor
 
     def load_clip_model(self):
         start_time = time.time()
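The else branch above keeps the existing escape hatch: a caption model and processor that were loaded elsewhere can be handed to Config directly, and load_caption_model() will reuse them instead of downloading anything. A sketch of that path, assuming a CUDA device and the blip-large checkpoint:

```python
# Sketch: supply a pre-loaded captioner instead of letting the Interrogator load it.
# Assumes a CUDA device; adjust dtype/device for CPU-only setups.
import torch
from transformers import AutoProcessor, BlipForConditionalGeneration
from clip_interrogator import Config, Interrogator

model_path = 'Salesforce/blip-image-captioning-large'
config = Config()
config.caption_model = BlipForConditionalGeneration.from_pretrained(
    model_path, torch_dtype=torch.float16).to('cuda').eval()
config.caption_processor = AutoProcessor.from_pretrained(model_path)
ci = Interrogator(config)  # load_caption_model() reuses the objects above
```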
@@ -97,7 +100,7 @@ class Interrogator():
         if config.clip_model is None:
             if not config.quiet:
-                print("Loading CLIP model...")
+                print(f"Loading CLIP model {config.clip_model_name}...")
 
             self.clip_model, _, self.clip_preprocess = open_clip.create_model_and_transforms(
                 clip_model_name,
@@ -183,26 +186,13 @@
         return best_prompt
 
     def generate_caption(self, pil_image: Image) -> str:
-        assert self.blip_model is not None, "No BLIP model loaded."
-        self._prepare_blip()
-
-        size = self.config.blip_image_eval_size
-        gpu_image = transforms.Compose([
-            transforms.Resize((size, size), interpolation=InterpolationMode.BICUBIC),
-            transforms.ToTensor(),
-            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
-        ])(pil_image).unsqueeze(0).to(self.device)
-
-        with torch.no_grad():
-            caption = self.blip_model.generate(
-                gpu_image,
-                sample=False,
-                num_beams=self.config.blip_num_beams,
-                max_length=self.config.blip_max_length,
-                min_length=5
-            )
-        return caption[0]
+        assert self.caption_model is not None, "No caption model loaded."
+        self._prepare_caption()
+        inputs = self.caption_processor(images=pil_image, return_tensors="pt").to(self.device)
+        if not self.config.caption_model_name.startswith('git-'):
+            inputs = inputs.to(self.dtype)
+        tokens = self.caption_model.generate(**inputs, max_new_tokens=self.config.caption_max_length)
+        return self.caption_processor.batch_decode(tokens, skip_special_tokens=True)[0].strip()
 
     def image_to_features(self, image: Image) -> torch.Tensor:
         self._prepare_clip()
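With the hand-rolled torchvision preprocessing gone, captioning is now a single processor/generate round trip, and it can be called on its own without running the full interrogation. A sketch (the image path is a placeholder):

```python
# Sketch: caption-only usage of the reworked generate_caption() path.
from PIL import Image
from clip_interrogator import Config, Interrogator

ci = Interrogator(Config(caption_model_name='git-large-coco'))
image = Image.open('example.jpg').convert('RGB')   # placeholder path
print(ci.generate_caption(image))                  # plain caption, no flavor ranking
```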
@@ -237,7 +227,7 @@
         are less readable."""
         caption = caption or self.generate_caption(image)
         image_features = self.image_to_features(image)
-        merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config)
+        merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self)
         tops = merged.rank(image_features, max_flavors)
         return _truncate_to_fit(caption + ", " + ", ".join(tops), self.tokenize)
@@ -254,7 +244,7 @@
         caption = caption or self.generate_caption(image)
         image_features = self.image_to_features(image)
-        merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config)
+        merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self)
         flaves = merged.rank(image_features, self.config.flavor_intermediate_count)
         best_prompt, best_sim = caption, self.similarity(image_features, caption)
         best_prompt = self.chain(image_features, flaves, best_prompt, best_sim, min_count=min_flavors, max_count=max_flavors, desc="Flavor chain")
@@ -293,18 +283,18 @@
         similarity = text_features @ image_features.T
         return similarity.T[0].tolist()
 
-    def _prepare_blip(self):
+    def _prepare_caption(self):
         if self.config.clip_offload and not self.clip_offloaded:
             self.clip_model = self.clip_model.to('cpu')
             self.clip_offloaded = True
-        if self.blip_offloaded:
-            self.blip_model = self.blip_model.to(self.device)
-            self.blip_offloaded = False
+        if self.caption_offloaded:
+            self.caption_model = self.caption_model.to(self.device)
+            self.caption_offloaded = False
 
     def _prepare_clip(self):
-        if self.config.blip_offload and not self.blip_offloaded:
-            self.blip_model = self.blip_model.to('cpu')
-            self.blip_offloaded = True
+        if self.config.caption_offload and not self.caption_offloaded:
+            self.caption_model = self.caption_model.to('cpu')
+            self.caption_offloaded = True
         if self.clip_offloaded:
             self.clip_model = self.clip_model.to(self.device)
             self.clip_offloaded = False
@@ -425,8 +415,8 @@ def _download_file(url: str, filepath: str, chunk_size: int = 4*1024*1024, quiet
             progress.update(len(chunk))
     progress.close()
 
-def _merge_tables(tables: List[LabelTable], config: Config) -> LabelTable:
-    m = LabelTable([], None, None, None, config)
+def _merge_tables(tables: List[LabelTable], ci: Interrogator) -> LabelTable:
+    m = LabelTable([], None, ci)
     for table in tables:
         m.labels.extend(table.labels)
         m.embeds.extend(table.embeds)
@@ -445,6 +435,12 @@ def _truncate_to_fit(text: str, tokenize) -> str:
         new_text += ', ' + part
     return new_text
 
+def list_caption_models() -> List[str]:
+    return list(CAPTION_MODELS.keys())
+
+def list_clip_models() -> List[str]:
+    return ['/'.join(x) for x in open_clip.list_pretrained()]
+
 def load_list(data_path: str, filename: Optional[str] = None) -> List[str]:
     """Load a list of strings from a file."""
     if filename is not None:
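The two new helpers give callers a supported-model listing without reaching into CAPTION_MODELS or open_clip directly; run_cli.py and run_gradio.py below switch over to them. For example:

```python
# Sketch: enumerate supported model names via the new helpers.
from clip_interrogator import list_caption_models, list_clip_models

print(list_caption_models())   # ['blip-base', 'blip-large', 'blip2-2.7b', 'blip2-flan-t5-xl', 'git-large-coco']
print(list_clip_models()[:3])  # e.g. ['RN50/openai', 'RN50/yfcc15m', ...] depending on the open_clip version
```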

requirements.txt (4 lines changed)

@@ -5,5 +5,5 @@ requests
 safetensors
 tqdm
 open_clip_torch
-blip-ci
-transformers>=4.15.0,<=4.26.1
+accelerate
+transformers>=4.27.1

run_cli.py (5 lines changed)

@@ -1,12 +1,11 @@
 #!/usr/bin/env python3
 import argparse
 import csv
-import open_clip
 import os
 import requests
 import torch
 
 from PIL import Image
-from clip_interrogator import Interrogator, Config
+from clip_interrogator import Interrogator, Config, list_clip_models
 
 def inference(ci, image, mode):
     image = image.convert('RGB')
@@ -36,7 +35,7 @@ def main():
         exit(1)
 
     # validate clip model name
-    models = ['/'.join(x) for x in open_clip.list_pretrained()]
+    models = list_clip_models()
     if args.clip not in models:
         print(f"Could not find CLIP model {args.clip}!")
         print(f"   available models: {models}")

run_gradio.py (19 lines changed)

@@ -1,8 +1,7 @@
 #!/usr/bin/env python3
 import argparse
-import open_clip
 import torch
 
-from clip_interrogator import Config, Interrogator
+from clip_interrogator import Config, Interrogator, list_caption_models, list_clip_models
 
 try:
     import gradio as gr
@@ -45,7 +44,11 @@ def image_analysis(image, clip_model_name):
     return medium_ranks, artist_ranks, movement_ranks, trending_ranks, flavor_ranks
 
-def image_to_prompt(image, mode, clip_model_name):
+def image_to_prompt(image, mode, clip_model_name, blip_model_name):
+    if blip_model_name != ci.config.caption_model_name:
+        ci.config.caption_model_name = blip_model_name
+        ci.load_caption_model()
+
     if clip_model_name != ci.config.clip_model_name:
         ci.config.clip_model_name = clip_model_name
         ci.load_clip_model()
@@ -60,25 +63,23 @@ def image_to_prompt(image, mode, clip_model_name):
     elif mode == 'negative':
         return ci.interrogate_negative(image)
 
-models = ['/'.join(x) for x in open_clip.list_pretrained()]
-
 def prompt_tab():
     with gr.Column():
         with gr.Row():
             image = gr.Image(type='pil', label="Image")
             with gr.Column():
                 mode = gr.Radio(['best', 'fast', 'classic', 'negative'], label='Mode', value='best')
-                model = gr.Dropdown(models, value='ViT-L-14/openai', label='CLIP Model')
+                clip_model = gr.Dropdown(list_clip_models(), value=ci.config.clip_model_name, label='CLIP Model')
+                blip_model = gr.Dropdown(list_caption_models(), value=ci.config.caption_model_name, label='Caption Model')
         prompt = gr.Textbox(label="Prompt")
         button = gr.Button("Generate prompt")
-        button.click(image_to_prompt, inputs=[image, mode, model], outputs=prompt)
+        button.click(image_to_prompt, inputs=[image, mode, clip_model, blip_model], outputs=prompt)
 
 def analyze_tab():
     with gr.Column():
         with gr.Row():
             image = gr.Image(type='pil', label="Image")
-            model = gr.Dropdown(models, value='ViT-L-14/openai', label='CLIP Model')
+            model = gr.Dropdown(list_clip_models(), value='ViT-L-14/openai', label='CLIP Model')
         with gr.Row():
             medium = gr.Label(label="Medium", num_top_classes=5)
             artist = gr.Label(label="Artist", num_top_classes=5)
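The updated handler above also shows the intended way to switch models on a live Interrogator: update the config field and call the matching load method. The same pattern works outside gradio, for example:

```python
# Sketch: swap the caption model on an existing Interrogator without rebuilding it,
# mirroring the check-and-reload logic in image_to_prompt() above.
def use_caption_model(ci, name: str) -> None:
    if name != ci.config.caption_model_name:
        ci.config.caption_model_name = name
        ci.load_caption_model()

use_caption_model(ci, 'blip2-flan-t5-xl')  # 'ci' is an existing Interrogator instance
```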

setup.py (2 lines changed)

@@ -5,7 +5,7 @@ from setuptools import setup, find_packages
 setup(
     name="clip-interrogator",
-    version="0.5.5",
+    version="0.6.0",
     license='MIT',
     author='pharmapsychotic',
     author_email='me@pharmapsychotic.com',
