diff --git a/clip_interrogator.ipynb b/clip_interrogator.ipynb
index 440fc2c..e56fdef 100755
--- a/clip_interrogator.ipynb
+++ b/clip_interrogator.ipynb
@@ -7,7 +7,7 @@
"id": "3jm8RYrLqvzz"
},
"source": [
- "# CLIP Interrogator 2.3 by [@pharmapsychotic](https://twitter.com/pharmapsychotic) \n",
+ "# CLIP Interrogator 2.4 by [@pharmapsychotic](https://twitter.com/pharmapsychotic) \n",
"\n",
"Want to figure out what a good prompt might be to create new images like an existing one? The CLIP Interrogator is here to get you answers!\n",
"\n",
@@ -29,7 +29,7 @@
},
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": null,
"metadata": {
"cellView": "form",
"id": "aP9FjmWxtLKJ"
@@ -42,7 +42,7 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": null,
"metadata": {
"cellView": "form",
"id": "xpPKQR40qvz2"
@@ -54,8 +54,7 @@
"\n",
"def setup():\n",
" install_cmds = [\n",
- " ['pip', 'install', 'transformers==4.15.0'],\n",
- " ['pip', 'install', 'gradio'],\n",
+ " ['pip', 'install', 'gradio'],\n",
" ['pip', 'install', 'open_clip_torch'],\n",
" ['pip', 'install', 'clip-interrogator'],\n",
" ]\n",
@@ -65,16 +64,15 @@
"setup()\n",
"\n",
"\n",
+ "caption_model_name = 'blip-large' #@param [\"blip-base\", \"blip-large\", \"git-large-coco\"]\n",
"clip_model_name = 'ViT-L-14/openai' #@param [\"ViT-L-14/openai\", \"ViT-H-14/laion2b_s32b_b79k\"]\n",
"\n",
- "\n",
"import gradio as gr\n",
"from clip_interrogator import Config, Interrogator\n",
"\n",
"config = Config()\n",
- "config.blip_num_beams = 64\n",
- "config.blip_offload = False\n",
"config.clip_model_name = clip_model_name\n",
+ "config.caption_model_name = caption_model_name\n",
"ci = Interrogator(config)\n",
"\n",
"def image_analysis(image):\n",
@@ -112,7 +110,7 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": null,
"metadata": {
"cellView": "form",
"colab": {
@@ -122,40 +120,7 @@
"id": "Pf6qkFG6MPRj",
"outputId": "8d542b56-8be7-453d-bf27-d0540a774c7d"
},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Colab notebook detected. To show errors in colab notebook, set `debug=True` in `launch()`\n",
- "\n",
- "Using Embedded Colab Mode (NEW). If you have issues, please use share=True and file an issue at https://github.com/gradio-app/gradio/\n",
- "Note: opening the browser inspector may crash Embedded Colab Mode.\n",
- "\n",
- "To create a public link, set `share=True` in `launch()`.\n"
- ]
- },
- {
- "data": {
- "application/javascript": "(async (port, path, width, height, cache, element) => {\n if (!google.colab.kernel.accessAllowed && !cache) {\n return;\n }\n element.appendChild(document.createTextNode(''));\n const url = await google.colab.kernel.proxyPort(port, {cache});\n\n const external_link = document.createElement('div');\n external_link.innerHTML = `\n
\n `;\n element.appendChild(external_link);\n\n const iframe = document.createElement('iframe');\n iframe.src = new URL(path, url).toString();\n iframe.height = height;\n iframe.allow = \"autoplay; camera; microphone; clipboard-read; clipboard-write;\"\n iframe.width = width;\n iframe.style.border = 0;\n element.appendChild(iframe);\n })(7860, \"/\", \"100%\", 500, false, window.element)",
- "text/plain": [
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/plain": [
- "(, 'http://127.0.0.1:7860/', None)"
- ]
- },
- "execution_count": 4,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"#@title Image to prompt! 🖼️ -> 📝\n",
" \n",
@@ -291,7 +256,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.7.15 (default, Nov 24 2022, 18:44:54) [MSC v.1916 64 bit (AMD64)]"
+ "version": "3.9.5"
},
"orig_nbformat": 4,
"vscode": {
diff --git a/clip_interrogator/__init__.py b/clip_interrogator/__init__.py
index 9a2936a..d560ce3 100644
--- a/clip_interrogator/__init__.py
+++ b/clip_interrogator/__init__.py
@@ -1,4 +1,4 @@
-from .clip_interrogator import Config, Interrogator, LabelTable, load_list
+from .clip_interrogator import Config, Interrogator, LabelTable, list_caption_models, list_clip_models, load_list
-__version__ = '0.5.5'
+__version__ = '0.6.0'
__author__ = 'pharmapsychotic'
\ No newline at end of file
diff --git a/clip_interrogator/clip_interrogator.py b/clip_interrogator/clip_interrogator.py
index 5d936fe..e7fcb5a 100644
--- a/clip_interrogator/clip_interrogator.py
+++ b/clip_interrogator/clip_interrogator.py
@@ -1,5 +1,4 @@
import hashlib
-import inspect
import math
import numpy as np
import open_clip
@@ -9,18 +8,19 @@ import time
import torch
from dataclasses import dataclass
-from blip.models.blip import blip_decoder, BLIP_Decoder
from PIL import Image
-from torchvision import transforms
-from torchvision.transforms.functional import InterpolationMode
+from transformers import AutoProcessor, AutoModelForCausalLM, BlipForConditionalGeneration, Blip2ForConditionalGeneration
from tqdm import tqdm
from typing import List, Optional
from safetensors.numpy import load_file, save_file
-BLIP_MODELS = {
- 'base': 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth',
- 'large': 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
+CAPTION_MODELS = {
+ 'blip-base': 'Salesforce/blip-image-captioning-base', # 990MB
+ 'blip-large': 'Salesforce/blip-image-captioning-large', # 1.9GB
+ 'blip2-2.7b': 'Salesforce/blip2-opt-2.7b', # 15.5GB
+ 'blip2-flan-t5-xl': 'Salesforce/blip2-flan-t5-xl', # 15.77GB
+ 'git-large-coco': 'microsoft/git-large-coco', # 1.58GB
}
CACHE_URL_BASE = 'https://huggingface.co/pharma/ci-preprocess/resolve/main/'
@@ -29,16 +29,15 @@ CACHE_URL_BASE = 'https://huggingface.co/pharma/ci-preprocess/resolve/main/'
@dataclass
class Config:
# models can optionally be passed in directly
- blip_model: Optional[BLIP_Decoder] = None
+ caption_model = None
+ caption_processor = None
clip_model = None
clip_preprocess = None
# blip settings
- blip_image_eval_size: int = 384
- blip_max_length: int = 32
- blip_model_type: Optional[str] = 'large' # use 'base', 'large' or None
- blip_num_beams: int = 8
- blip_offload: bool = False
+ caption_max_length: int = 32
+ caption_model_name: Optional[str] = 'blip-large' # use a key from CAPTION_MODELS or None
+ caption_offload: bool = False
# clip settings
clip_model_name: str = 'ViT-L-14/openai'
@@ -55,8 +54,8 @@ class Config:
quiet: bool = False # when quiet progress bars are not shown
def apply_low_vram_defaults(self):
- self.blip_model_type = 'base'
- self.blip_offload = True
+ self.caption_model_name = 'blip-base'
+ self.caption_offload = True
self.clip_offload = True
self.chunk_size = 1024
self.flavor_intermediate_count = 1024
@@ -65,29 +64,33 @@ class Interrogator():
def __init__(self, config: Config):
self.config = config
self.device = config.device
- self.blip_offloaded = True
+ self.dtype = torch.float16 if self.device == 'cuda' else torch.float32
+ self.caption_offloaded = True
self.clip_offloaded = True
+ self.load_caption_model()
+ self.load_clip_model()
- if config.blip_model is None and config.blip_model_type:
- if not config.quiet:
- print("Loading BLIP model...")
- blip_path = os.path.dirname(inspect.getfile(blip_decoder))
- configs_path = os.path.join(os.path.dirname(blip_path), 'configs')
- med_config = os.path.join(configs_path, 'med_config.json')
- blip_model = blip_decoder(
- pretrained=BLIP_MODELS[config.blip_model_type],
- image_size=config.blip_image_eval_size,
- vit=config.blip_model_type,
- med_config=med_config
- )
- blip_model.eval()
- if not self.config.blip_offload:
- blip_model = blip_model.to(config.device)
- self.blip_model = blip_model
+ def load_caption_model(self):
+ if self.config.caption_model is None and self.config.caption_model_name:
+ if not self.config.quiet:
+ print(f"Loading caption model {self.config.caption_model_name}...")
+
+ model_path = CAPTION_MODELS[self.config.caption_model_name]
+ if self.config.caption_model_name.startswith('git-'):
+ caption_model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float32)
+ elif self.config.caption_model_name.startswith('blip2-'):
+ caption_model = Blip2ForConditionalGeneration.from_pretrained(model_path, torch_dtype=self.dtype)
+ else:
+ caption_model = BlipForConditionalGeneration.from_pretrained(model_path, torch_dtype=self.dtype)
+ self.caption_processor = AutoProcessor.from_pretrained(model_path)
+
+ caption_model.eval()
+ if not self.config.caption_offload:
+ caption_model = caption_model.to(self.config.device)
+ self.caption_model = caption_model
else:
- self.blip_model = config.blip_model
-
- self.load_clip_model()
+ self.caption_model = self.config.caption_model
+ self.caption_processor = self.config.caption_processor
def load_clip_model(self):
start_time = time.time()
@@ -97,7 +100,7 @@ class Interrogator():
if config.clip_model is None:
if not config.quiet:
- print("Loading CLIP model...")
+ print(f"Loading CLIP model {config.clip_model_name}...")
self.clip_model, _, self.clip_preprocess = open_clip.create_model_and_transforms(
clip_model_name,
@@ -183,26 +186,13 @@ class Interrogator():
return best_prompt
def generate_caption(self, pil_image: Image) -> str:
- assert self.blip_model is not None, "No BLIP model loaded."
- self._prepare_blip()
-
- size = self.config.blip_image_eval_size
- gpu_image = transforms.Compose([
- transforms.Resize((size, size), interpolation=InterpolationMode.BICUBIC),
- transforms.ToTensor(),
- transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
- ])(pil_image).unsqueeze(0).to(self.device)
-
- with torch.no_grad():
- caption = self.blip_model.generate(
- gpu_image,
- sample=False,
- num_beams=self.config.blip_num_beams,
- max_length=self.config.blip_max_length,
- min_length=5
- )
-
- return caption[0]
+ assert self.caption_model is not None, "No caption model loaded."
+ self._prepare_caption()
+ inputs = self.caption_processor(images=pil_image, return_tensors="pt").to(self.device)
+ if not self.config.caption_model_name.startswith('git-'):
+ inputs = inputs.to(self.dtype)
+ tokens = self.caption_model.generate(**inputs, max_new_tokens=self.config.caption_max_length)
+ return self.caption_processor.batch_decode(tokens, skip_special_tokens=True)[0].strip()
def image_to_features(self, image: Image) -> torch.Tensor:
self._prepare_clip()
@@ -237,7 +227,7 @@ class Interrogator():
are less readable."""
caption = caption or self.generate_caption(image)
image_features = self.image_to_features(image)
- merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config)
+ merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self)
tops = merged.rank(image_features, max_flavors)
return _truncate_to_fit(caption + ", " + ", ".join(tops), self.tokenize)
@@ -254,7 +244,7 @@ class Interrogator():
caption = caption or self.generate_caption(image)
image_features = self.image_to_features(image)
- merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config)
+ merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self)
flaves = merged.rank(image_features, self.config.flavor_intermediate_count)
best_prompt, best_sim = caption, self.similarity(image_features, caption)
best_prompt = self.chain(image_features, flaves, best_prompt, best_sim, min_count=min_flavors, max_count=max_flavors, desc="Flavor chain")
@@ -293,18 +283,18 @@ class Interrogator():
similarity = text_features @ image_features.T
return similarity.T[0].tolist()
- def _prepare_blip(self):
+ def _prepare_caption(self):
if self.config.clip_offload and not self.clip_offloaded:
self.clip_model = self.clip_model.to('cpu')
self.clip_offloaded = True
- if self.blip_offloaded:
- self.blip_model = self.blip_model.to(self.device)
- self.blip_offloaded = False
+ if self.caption_offloaded:
+ self.caption_model = self.caption_model.to(self.device)
+ self.caption_offloaded = False
def _prepare_clip(self):
- if self.config.blip_offload and not self.blip_offloaded:
- self.blip_model = self.blip_model.to('cpu')
- self.blip_offloaded = True
+ if self.config.caption_offload and not self.caption_offloaded:
+ self.caption_model = self.caption_model.to('cpu')
+ self.caption_offloaded = True
if self.clip_offloaded:
self.clip_model = self.clip_model.to(self.device)
self.clip_offloaded = False
@@ -425,8 +415,8 @@ def _download_file(url: str, filepath: str, chunk_size: int = 4*1024*1024, quiet
progress.update(len(chunk))
progress.close()
-def _merge_tables(tables: List[LabelTable], config: Config) -> LabelTable:
- m = LabelTable([], None, None, None, config)
+def _merge_tables(tables: List[LabelTable], ci: Interrogator) -> LabelTable:
+ m = LabelTable([], None, ci)
for table in tables:
m.labels.extend(table.labels)
m.embeds.extend(table.embeds)
@@ -445,6 +435,12 @@ def _truncate_to_fit(text: str, tokenize) -> str:
new_text += ', ' + part
return new_text
+def list_caption_models() -> List[str]:
+ return list(CAPTION_MODELS.keys())
+
+def list_clip_models() -> List[str]:
+ return ['/'.join(x) for x in open_clip.list_pretrained()]
+
def load_list(data_path: str, filename: Optional[str] = None) -> List[str]:
"""Load a list of strings from a file."""
if filename is not None:
diff --git a/requirements.txt b/requirements.txt
index 1c73285..d6ff090 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,5 +5,5 @@ requests
safetensors
tqdm
open_clip_torch
-blip-ci
-transformers>=4.15.0,<=4.26.1
+accelerate
+transformers>=4.27.1
\ No newline at end of file
diff --git a/run_cli.py b/run_cli.py
index 59d563d..b1a5ef7 100755
--- a/run_cli.py
+++ b/run_cli.py
@@ -1,12 +1,11 @@
#!/usr/bin/env python3
import argparse
import csv
-import open_clip
import os
import requests
import torch
from PIL import Image
-from clip_interrogator import Interrogator, Config
+from clip_interrogator import Interrogator, Config, list_clip_models
def inference(ci, image, mode):
image = image.convert('RGB')
@@ -36,7 +35,7 @@ def main():
exit(1)
# validate clip model name
- models = ['/'.join(x) for x in open_clip.list_pretrained()]
+ models = list_clip_models()
if args.clip not in models:
print(f"Could not find CLIP model {args.clip}!")
print(f" available models: {models}")
diff --git a/run_gradio.py b/run_gradio.py
index 938171a..0a178bd 100755
--- a/run_gradio.py
+++ b/run_gradio.py
@@ -1,8 +1,7 @@
#!/usr/bin/env python3
import argparse
-import open_clip
import torch
-from clip_interrogator import Config, Interrogator
+from clip_interrogator import Config, Interrogator, list_caption_models, list_clip_models
try:
import gradio as gr
@@ -45,7 +44,11 @@ def image_analysis(image, clip_model_name):
return medium_ranks, artist_ranks, movement_ranks, trending_ranks, flavor_ranks
-def image_to_prompt(image, mode, clip_model_name):
+def image_to_prompt(image, mode, clip_model_name, blip_model_name):
+ if blip_model_name != ci.config.caption_model_name:
+ ci.config.caption_model_name = blip_model_name
+ ci.load_caption_model()
+
if clip_model_name != ci.config.clip_model_name:
ci.config.clip_model_name = clip_model_name
ci.load_clip_model()
@@ -60,25 +63,23 @@ def image_to_prompt(image, mode, clip_model_name):
elif mode == 'negative':
return ci.interrogate_negative(image)
-
-models = ['/'.join(x) for x in open_clip.list_pretrained()]
-
def prompt_tab():
with gr.Column():
with gr.Row():
image = gr.Image(type='pil', label="Image")
with gr.Column():
mode = gr.Radio(['best', 'fast', 'classic', 'negative'], label='Mode', value='best')
- model = gr.Dropdown(models, value='ViT-L-14/openai', label='CLIP Model')
+ clip_model = gr.Dropdown(list_clip_models(), value=ci.config.clip_model_name, label='CLIP Model')
+ blip_model = gr.Dropdown(list_caption_models(), value=ci.config.caption_model_name, label='Caption Model')
prompt = gr.Textbox(label="Prompt")
button = gr.Button("Generate prompt")
- button.click(image_to_prompt, inputs=[image, mode, model], outputs=prompt)
+ button.click(image_to_prompt, inputs=[image, mode, clip_model, blip_model], outputs=prompt)
def analyze_tab():
with gr.Column():
with gr.Row():
image = gr.Image(type='pil', label="Image")
- model = gr.Dropdown(models, value='ViT-L-14/openai', label='CLIP Model')
+ model = gr.Dropdown(list_clip_models(), value='ViT-L-14/openai', label='CLIP Model')
with gr.Row():
medium = gr.Label(label="Medium", num_top_classes=5)
artist = gr.Label(label="Artist", num_top_classes=5)
diff --git a/setup.py b/setup.py
index f2806e0..a6db97a 100644
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@ from setuptools import setup, find_packages
setup(
name="clip-interrogator",
- version="0.5.5",
+ version="0.6.0",
license='MIT',
author='pharmapsychotic',
author_email='me@pharmapsychotic.com',