
Support for many different caption models:

blip-base, blip-large, blip2-2.7b, blip2-flan-t5-xl, git-large-coco
pull/69/head v0.6.0
pharmapsychotic, 2 years ago
commit ce9d271aa1
  1. clip_interrogator.ipynb (53 changed lines)
  2. clip_interrogator/__init__.py (4 changed lines)
  3. clip_interrogator/clip_interrogator.py (130 changed lines)
  4. requirements.txt (4 changed lines)
  5. run_cli.py (5 changed lines)
  6. run_gradio.py (19 changed lines)
  7. setup.py (2 changed lines)
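
The headline change: captioning is no longer tied to the bundled BLIP decoder, and any of the models listed above can be selected through Config.caption_model_name. A minimal sketch of the updated usage (the image file name is just a placeholder, not part of this commit):

    from PIL import Image
    from clip_interrogator import Config, Interrogator

    config = Config()
    config.caption_model_name = 'blip2-flan-t5-xl'   # any key from CAPTION_MODELS
    config.clip_model_name = 'ViT-L-14/openai'
    ci = Interrogator(config)

    image = Image.open('example.jpg').convert('RGB')
    print(ci.interrogate(image))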

clip_interrogator.ipynb (53 changed lines)

@@ -7,7 +7,7 @@
"id": "3jm8RYrLqvzz"
},
"source": [
"# CLIP Interrogator 2.3 by [@pharmapsychotic](https://twitter.com/pharmapsychotic) \n",
"# CLIP Interrogator 2.4 by [@pharmapsychotic](https://twitter.com/pharmapsychotic) \n",
"\n",
"Want to figure out what a good prompt might be to create new images like an existing one? The CLIP Interrogator is here to get you answers!\n",
"\n",
@@ -29,7 +29,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "aP9FjmWxtLKJ"
@@ -42,7 +42,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "xpPKQR40qvz2"
@@ -54,8 +54,7 @@
"\n",
"def setup():\n",
" install_cmds = [\n",
" ['pip', 'install', 'transformers==4.15.0'],\n",
" ['pip', 'install', 'gradio'],\n",
" ['pip', 'install', 'gradio'],\n",
" ['pip', 'install', 'open_clip_torch'],\n",
" ['pip', 'install', 'clip-interrogator'],\n",
" ]\n",
@@ -65,16 +64,15 @@
"setup()\n",
"\n",
"\n",
"caption_model_name = 'blip-large' #@param [\"blip-base\", \"blip-large\", \"git-large-coco\"]\n",
"clip_model_name = 'ViT-L-14/openai' #@param [\"ViT-L-14/openai\", \"ViT-H-14/laion2b_s32b_b79k\"]\n",
"\n",
"\n",
"import gradio as gr\n",
"from clip_interrogator import Config, Interrogator\n",
"\n",
"config = Config()\n",
"config.blip_num_beams = 64\n",
"config.blip_offload = False\n",
"config.clip_model_name = clip_model_name\n",
"config.caption_model_name = caption_model_name\n",
"ci = Interrogator(config)\n",
"\n",
"def image_analysis(image):\n",
@@ -112,7 +110,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {
"cellView": "form",
"colab": {
@@ -122,40 +120,7 @@
"id": "Pf6qkFG6MPRj",
"outputId": "8d542b56-8be7-453d-bf27-d0540a774c7d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Colab notebook detected. To show errors in colab notebook, set `debug=True` in `launch()`\n",
"\n",
"Using Embedded Colab Mode (NEW). If you have issues, please use share=True and file an issue at https://github.com/gradio-app/gradio/\n",
"Note: opening the browser inspector may crash Embedded Colab Mode.\n",
"\n",
"To create a public link, set `share=True` in `launch()`.\n"
]
},
{
"data": {
"application/javascript": "(async (port, path, width, height, cache, element) => {\n if (!google.colab.kernel.accessAllowed && !cache) {\n return;\n }\n element.appendChild(document.createTextNode(''));\n const url = await google.colab.kernel.proxyPort(port, {cache});\n\n const external_link = document.createElement('div');\n external_link.innerHTML = `\n <div style=\"font-family: monospace; margin-bottom: 0.5rem\">\n Running on <a href=${new URL(path, url).toString()} target=\"_blank\">\n https://localhost:${port}${path}\n </a>\n </div>\n `;\n element.appendChild(external_link);\n\n const iframe = document.createElement('iframe');\n iframe.src = new URL(path, url).toString();\n iframe.height = height;\n iframe.allow = \"autoplay; camera; microphone; clipboard-read; clipboard-write;\"\n iframe.width = width;\n iframe.style.border = 0;\n element.appendChild(iframe);\n })(7860, \"/\", \"100%\", 500, false, window.element)",
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"(<gradio.routes.App at 0x7f894e553710>, 'http://127.0.0.1:7860/', None)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"#@title Image to prompt! 🖼 -> 📝\n",
" \n",
@@ -291,7 +256,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.15 (default, Nov 24 2022, 18:44:54) [MSC v.1916 64 bit (AMD64)]"
"version": "3.9.5"
},
"orig_nbformat": 4,
"vscode": {

clip_interrogator/__init__.py (4 changed lines)

@@ -1,4 +1,4 @@
from .clip_interrogator import Config, Interrogator, LabelTable, load_list
from .clip_interrogator import Config, Interrogator, LabelTable, list_caption_models, list_clip_models, load_list
__version__ = '0.5.5'
__version__ = '0.6.0'
__author__ = 'pharmapsychotic'

clip_interrogator/clip_interrogator.py (130 changed lines)

@@ -1,5 +1,4 @@
import hashlib
import inspect
import math
import numpy as np
import open_clip
@@ -9,18 +8,19 @@ import time
import torch
from dataclasses import dataclass
from blip.models.blip import blip_decoder, BLIP_Decoder
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoProcessor, AutoModelForCausalLM, BlipForConditionalGeneration, Blip2ForConditionalGeneration
from tqdm import tqdm
from typing import List, Optional
from safetensors.numpy import load_file, save_file
BLIP_MODELS = {
'base': 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth',
'large': 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
CAPTION_MODELS = {
'blip-base': 'Salesforce/blip-image-captioning-base', # 990MB
'blip-large': 'Salesforce/blip-image-captioning-large', # 1.9GB
'blip2-2.7b': 'Salesforce/blip2-opt-2.7b', # 15.5GB
'blip2-flan-t5-xl': 'Salesforce/blip2-flan-t5-xl', # 15.77GB
'git-large-coco': 'microsoft/git-large-coco', # 1.58GB
}
CACHE_URL_BASE = 'https://huggingface.co/pharma/ci-preprocess/resolve/main/'
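
Each CAPTION_MODELS entry maps a short model name to a Hugging Face Hub repo id, with rough download sizes in the comments. The smallest entry can be exercised directly with transformers, independent of this library; the following is only an illustrative sketch, not code from this commit, and the image path is a placeholder:

    from PIL import Image
    from transformers import AutoProcessor, BlipForConditionalGeneration

    model_id = 'Salesforce/blip-image-captioning-base'   # CAPTION_MODELS['blip-base']
    processor = AutoProcessor.from_pretrained(model_id)
    model = BlipForConditionalGeneration.from_pretrained(model_id).eval()

    image = Image.open('example.jpg').convert('RGB')
    inputs = processor(images=image, return_tensors='pt')
    tokens = model.generate(**inputs, max_new_tokens=32)
    print(processor.batch_decode(tokens, skip_special_tokens=True)[0].strip())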
@@ -29,16 +29,15 @@ CACHE_URL_BASE = 'https://huggingface.co/pharma/ci-preprocess/resolve/main/'
@dataclass
class Config:
# models can optionally be passed in directly
blip_model: Optional[BLIP_Decoder] = None
caption_model = None
caption_processor = None
clip_model = None
clip_preprocess = None
# blip settings
blip_image_eval_size: int = 384
blip_max_length: int = 32
blip_model_type: Optional[str] = 'large' # use 'base', 'large' or None
blip_num_beams: int = 8
blip_offload: bool = False
caption_max_length: int = 32
caption_model_name: Optional[str] = 'blip-large' # use a key from CAPTION_MODELS or None
caption_offload: bool = False
# clip settings
clip_model_name: str = 'ViT-L-14/openai'
@@ -55,8 +54,8 @@ class Config:
quiet: bool = False # when quiet progress bars are not shown
def apply_low_vram_defaults(self):
self.blip_model_type = 'base'
self.blip_offload = True
self.caption_model_name = 'blip-base'
self.caption_offload = True
self.clip_offload = True
self.chunk_size = 1024
self.flavor_intermediate_count = 1024
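
With the rename, the low-VRAM path now selects the smallest caption model and enables offloading instead of toggling the old blip_model_type. From the caller's side nothing changes; a minimal sketch:

    from clip_interrogator import Config, Interrogator

    config = Config()
    config.apply_low_vram_defaults()   # blip-base captioner, caption/CLIP offloading, smaller chunk sizes
    ci = Interrogator(config)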
@@ -65,29 +64,33 @@ class Interrogator():
def __init__(self, config: Config):
self.config = config
self.device = config.device
self.blip_offloaded = True
self.dtype = torch.float16 if self.device == 'cuda' else torch.float32
self.caption_offloaded = True
self.clip_offloaded = True
self.load_caption_model()
self.load_clip_model()
if config.blip_model is None and config.blip_model_type:
if not config.quiet:
print("Loading BLIP model...")
blip_path = os.path.dirname(inspect.getfile(blip_decoder))
configs_path = os.path.join(os.path.dirname(blip_path), 'configs')
med_config = os.path.join(configs_path, 'med_config.json')
blip_model = blip_decoder(
pretrained=BLIP_MODELS[config.blip_model_type],
image_size=config.blip_image_eval_size,
vit=config.blip_model_type,
med_config=med_config
)
blip_model.eval()
if not self.config.blip_offload:
blip_model = blip_model.to(config.device)
self.blip_model = blip_model
def load_caption_model(self):
if self.config.caption_model is None and self.config.caption_model_name:
if not self.config.quiet:
print(f"Loading caption model {self.config.caption_model_name}...")
model_path = CAPTION_MODELS[self.config.caption_model_name]
if self.config.caption_model_name.startswith('git-'):
caption_model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float32)
elif self.config.caption_model_name.startswith('blip2-'):
caption_model = Blip2ForConditionalGeneration.from_pretrained(model_path, torch_dtype=self.dtype)
else:
caption_model = BlipForConditionalGeneration.from_pretrained(model_path, torch_dtype=self.dtype)
self.caption_processor = AutoProcessor.from_pretrained(model_path)
caption_model.eval()
if not self.config.caption_offload:
caption_model = caption_model.to(self.config.device)
self.caption_model = caption_model
else:
self.blip_model = config.blip_model
self.load_clip_model()
self.caption_model = self.config.caption_model
self.caption_processor = self.config.caption_processor
def load_clip_model(self):
start_time = time.time()
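
Because load_caption_model() reads everything from self.config, the caption backend can be swapped on a live Interrogator without rebuilding it, which is what the gradio UI further down does. A short sketch, reusing the ci instance from the first example and a model name from CAPTION_MODELS:

    ci.config.caption_model_name = 'git-large-coco'
    ci.load_caption_model()   # replaces ci.caption_model and ci.caption_processor in place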
@@ -97,7 +100,7 @@ class Interrogator():
if config.clip_model is None:
if not config.quiet:
print("Loading CLIP model...")
print(f"Loading CLIP model {config.clip_model_name}...")
self.clip_model, _, self.clip_preprocess = open_clip.create_model_and_transforms(
clip_model_name,
@@ -183,26 +186,13 @@ class Interrogator():
return best_prompt
def generate_caption(self, pil_image: Image) -> str:
assert self.blip_model is not None, "No BLIP model loaded."
self._prepare_blip()
size = self.config.blip_image_eval_size
gpu_image = transforms.Compose([
transforms.Resize((size, size), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
])(pil_image).unsqueeze(0).to(self.device)
with torch.no_grad():
caption = self.blip_model.generate(
gpu_image,
sample=False,
num_beams=self.config.blip_num_beams,
max_length=self.config.blip_max_length,
min_length=5
)
return caption[0]
assert self.caption_model is not None, "No caption model loaded."
self._prepare_caption()
inputs = self.caption_processor(images=pil_image, return_tensors="pt").to(self.device)
if not self.config.caption_model_name.startswith('git-'):
inputs = inputs.to(self.dtype)
tokens = self.caption_model.generate(**inputs, max_new_tokens=self.config.caption_max_length)
return self.caption_processor.batch_decode(tokens, skip_special_tokens=True)[0].strip()
def image_to_features(self, image: Image) -> torch.Tensor:
self._prepare_clip()
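
The rewritten generate_caption delegates preprocessing and decoding to the transformers processor, so callers only deal with PIL images and plain strings. Illustrative usage against the ci instance from the first sketch (the file name is a placeholder):

    from PIL import Image

    image = Image.open('photo.jpg').convert('RGB')
    caption = ci.generate_caption(image)   # returns a plain-text caption string
    print(caption)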
@@ -237,7 +227,7 @@ class Interrogator():
are less readable."""
caption = caption or self.generate_caption(image)
image_features = self.image_to_features(image)
merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config)
merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self)
tops = merged.rank(image_features, max_flavors)
return _truncate_to_fit(caption + ", " + ", ".join(tops), self.tokenize)
@@ -254,7 +244,7 @@
caption = caption or self.generate_caption(image)
image_features = self.image_to_features(image)
merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config)
merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self)
flaves = merged.rank(image_features, self.config.flavor_intermediate_count)
best_prompt, best_sim = caption, self.similarity(image_features, caption)
best_prompt = self.chain(image_features, flaves, best_prompt, best_sim, min_count=min_flavors, max_count=max_flavors, desc="Flavor chain")
@@ -293,18 +283,18 @@ class Interrogator():
similarity = text_features @ image_features.T
return similarity.T[0].tolist()
def _prepare_blip(self):
def _prepare_caption(self):
if self.config.clip_offload and not self.clip_offloaded:
self.clip_model = self.clip_model.to('cpu')
self.clip_offloaded = True
if self.blip_offloaded:
self.blip_model = self.blip_model.to(self.device)
self.blip_offloaded = False
if self.caption_offloaded:
self.caption_model = self.caption_model.to(self.device)
self.caption_offloaded = False
def _prepare_clip(self):
if self.config.blip_offload and not self.blip_offloaded:
self.blip_model = self.blip_model.to('cpu')
self.blip_offloaded = True
if self.config.caption_offload and not self.caption_offloaded:
self.caption_model = self.caption_model.to('cpu')
self.caption_offloaded = True
if self.clip_offloaded:
self.clip_model = self.clip_model.to(self.device)
self.clip_offloaded = False
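
With offloading enabled, _prepare_caption and _prepare_clip keep at most one of the two models on the GPU: preparing one pushes the other back to the CPU. A sketch of a configuration that triggers this behaviour, using the flag names from this commit:

    from clip_interrogator import Config, Interrogator

    config = Config()
    config.caption_model_name = 'blip2-2.7b'   # large captioner, worth offloading
    config.caption_offload = True              # caption model stays on CPU until generate_caption runs
    config.clip_offload = True                 # CLIP model stays on CPU until ranking runs
    ci = Interrogator(config)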
@@ -425,8 +415,8 @@ def _download_file(url: str, filepath: str, chunk_size: int = 4*1024*1024, quiet
progress.update(len(chunk))
progress.close()
def _merge_tables(tables: List[LabelTable], config: Config) -> LabelTable:
m = LabelTable([], None, None, None, config)
def _merge_tables(tables: List[LabelTable], ci: Interrogator) -> LabelTable:
m = LabelTable([], None, ci)
for table in tables:
m.labels.extend(table.labels)
m.embeds.extend(table.embeds)
@@ -445,6 +435,12 @@ def _truncate_to_fit(text: str, tokenize) -> str:
new_text += ', ' + part
return new_text
def list_caption_models() -> List[str]:
return list(CAPTION_MODELS.keys())
def list_clip_models() -> List[str]:
return ['/'.join(x) for x in open_clip.list_pretrained()]
def load_list(data_path: str, filename: Optional[str] = None) -> List[str]:
"""Load a list of strings from a file."""
if filename is not None:
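
The two new helpers give the CLI and the gradio front-end a single source of truth for valid model names. For example:

    from clip_interrogator import list_caption_models, list_clip_models

    print(list_caption_models())
    # ['blip-base', 'blip-large', 'blip2-2.7b', 'blip2-flan-t5-xl', 'git-large-coco']
    print(list_clip_models()[:3])
    # e.g. ['RN50/openai', 'RN50/yfcc15m', 'RN50/cc12m'], depending on the open_clip version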

requirements.txt (4 changed lines)

@@ -5,5 +5,5 @@ requests
safetensors
tqdm
open_clip_torch
blip-ci
transformers>=4.15.0,<=4.26.1
accelerate
transformers>=4.27.1

run_cli.py (5 changed lines)

@@ -1,12 +1,11 @@
#!/usr/bin/env python3
import argparse
import csv
import open_clip
import os
import requests
import torch
from PIL import Image
from clip_interrogator import Interrogator, Config
from clip_interrogator import Interrogator, Config, list_clip_models
def inference(ci, image, mode):
image = image.convert('RGB')
@@ -36,7 +35,7 @@ def main():
exit(1)
# validate clip model name
models = ['/'.join(x) for x in open_clip.list_pretrained()]
models = list_clip_models()
if args.clip not in models:
print(f"Could not find CLIP model {args.clip}!")
print(f" available models: {models}")

run_gradio.py (19 changed lines)

@@ -1,8 +1,7 @@
#!/usr/bin/env python3
import argparse
import open_clip
import torch
from clip_interrogator import Config, Interrogator
from clip_interrogator import Config, Interrogator, list_caption_models, list_clip_models
try:
import gradio as gr
@@ -45,7 +44,11 @@ def image_analysis(image, clip_model_name):
return medium_ranks, artist_ranks, movement_ranks, trending_ranks, flavor_ranks
def image_to_prompt(image, mode, clip_model_name):
def image_to_prompt(image, mode, clip_model_name, blip_model_name):
if blip_model_name != ci.config.caption_model_name:
ci.config.caption_model_name = blip_model_name
ci.load_caption_model()
if clip_model_name != ci.config.clip_model_name:
ci.config.clip_model_name = clip_model_name
ci.load_clip_model()
@@ -60,25 +63,23 @@ def image_to_prompt(image, mode, clip_model_name):
elif mode == 'negative':
return ci.interrogate_negative(image)
models = ['/'.join(x) for x in open_clip.list_pretrained()]
def prompt_tab():
with gr.Column():
with gr.Row():
image = gr.Image(type='pil', label="Image")
with gr.Column():
mode = gr.Radio(['best', 'fast', 'classic', 'negative'], label='Mode', value='best')
model = gr.Dropdown(models, value='ViT-L-14/openai', label='CLIP Model')
clip_model = gr.Dropdown(list_clip_models(), value=ci.config.clip_model_name, label='CLIP Model')
blip_model = gr.Dropdown(list_caption_models(), value=ci.config.caption_model_name, label='Caption Model')
prompt = gr.Textbox(label="Prompt")
button = gr.Button("Generate prompt")
button.click(image_to_prompt, inputs=[image, mode, model], outputs=prompt)
button.click(image_to_prompt, inputs=[image, mode, clip_model, blip_model], outputs=prompt)
def analyze_tab():
with gr.Column():
with gr.Row():
image = gr.Image(type='pil', label="Image")
model = gr.Dropdown(models, value='ViT-L-14/openai', label='CLIP Model')
model = gr.Dropdown(list_clip_models(), value='ViT-L-14/openai', label='CLIP Model')
with gr.Row():
medium = gr.Label(label="Medium", num_top_classes=5)
artist = gr.Label(label="Artist", num_top_classes=5)

setup.py (2 changed lines)

@@ -5,7 +5,7 @@ from setuptools import setup, find_packages
setup(
name="clip-interrogator",
version="0.5.5",
version="0.6.0",
license='MIT',
author='pharmapsychotic',
author_email='me@pharmapsychotic.com',
