diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..984cf86
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,4 @@
+*.pyc
+.vscode/
+cache/
+venv/
diff --git a/README.md b/README.md
index ba8e3bd..72ef8a2 100644
--- a/README.md
+++ b/README.md
@@ -1,10 +1,19 @@
# clip-interrogator
-[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pharmapsychotic/clip-interrogator/blob/main/clip_interrogator.ipynb) Version 2
+Run Version 2 on Colab, HuggingFace, and Replicate!
-[![Generic badge](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue.svg)](https://huggingface.co/spaces/pharma/CLIP-Interrogator) Version 2
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pharmapsychotic/clip-interrogator/blob/main/clip_interrogator.ipynb) [![Generic badge](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue.svg)](https://huggingface.co/spaces/pharma/CLIP-Interrogator) [![Replicate](https://replicate.com/cjwbw/clip-interrogator/badge)](https://replicate.com/cjwbw/clip-interrogator)
-[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pharmapsychotic/clip-interrogator/blob/v1/clip_interrogator.ipynb) Version 1
-The CLIP Interrogator uses the OpenAI CLIP models to test a given image against a variety of artists, mediums, and styles to study how the different models see the content of the image. It also combines the results with BLIP caption to suggest a text prompt to create more images similar to what was given.
+
+Version 1 is still available on Colab for comparing different CLIP models.
+
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pharmapsychotic/clip-interrogator/blob/v1/clip_interrogator.ipynb)
+
+
+
+
+*Want to figure out what a good prompt might be to create new images like an existing one? The **CLIP Interrogator** is here to get you answers!*
+
+The **CLIP Interrogator** is a prompt engineering tool that combines OpenAI's [CLIP](https://openai.com/blog/clip/) and Salesforce's [BLIP](https://blog.salesforceairesearch.com/blip-bootstrapping-language-image-pretraining/) to optimize text prompts to match a given image. Use the resulting prompts with text-to-image models like Stable Diffusion.
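+
+A minimal way to drive it from Python, mirroring `main.py` in this repo (install the packages from `requirements.txt` first; the image path below is a placeholder, and `data_path='data'` assumes you run from the repo root where the `data/` folder lives):
+
+```python
+from PIL import Image
+
+from clip_interrogator import CLIPInterrogator, Config
+
+# open the image you want a prompt for (placeholder filename)
+image = Image.open('your_image.jpg').convert('RGB')
+
+# Config picks CUDA automatically when available
+config = Config(clip_model_name='ViT-L/14', data_path='data')
+interrogator = CLIPInterrogator(config)
+print(interrogator.interrogate(image))
+```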
diff --git a/clip_interrogator/__init__.py b/clip_interrogator/__init__.py
new file mode 100644
index 0000000..4f47c51
--- /dev/null
+++ b/clip_interrogator/__init__.py
@@ -0,0 +1 @@
+from .interrogate import CLIPInterrogator, Config, LabelTable
\ No newline at end of file
diff --git a/clip_interrogator/interrogate.py b/clip_interrogator/interrogate.py
new file mode 100644
index 0000000..a61b847
--- /dev/null
+++ b/clip_interrogator/interrogate.py
@@ -0,0 +1,260 @@
+import clip
+import hashlib
+import inspect
+import math
+import numpy as np
+import os
+import pickle
+import torch
+
+from dataclasses import dataclass
+from models.blip import blip_decoder
+from PIL import Image
+from torchvision import transforms
+from torchvision.transforms.functional import InterpolationMode
+from tqdm import tqdm
+from typing import Any, List
+
+
+@dataclass
+class Config:
+ # models can optionally be passed in directly
+    blip_model: Any = None
+    clip_model: Any = None
+    clip_preprocess: Any = None
+
+ # blip settings
+ blip_image_eval_size: int = 384
+ blip_max_length: int = 20
+ blip_model_url: str = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
+ blip_num_beams: int = 3
+
+ # clip settings
+ clip_model_name: str = 'ViT-L/14'
+
+ # interrogator settings
+ cache_path: str = 'cache'
+ chunk_size: int = 2048
+ data_path: str = 'data'
+ device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
+ flavor_intermediate_count: int = 2048
+
+
+def _load_list(data_path, filename) -> List[str]:
+ with open(os.path.join(data_path, filename), 'r', encoding='utf-8', errors='replace') as f:
+ items = [line.strip() for line in f.readlines()]
+ return items
+
+
+class CLIPInterrogator():
+ def __init__(self, config: Config):
+ self.config = config
+ self.device = config.device
+
+ if config.blip_model is None:
+ print("Loading BLIP model...")
+ blip_path = os.path.dirname(inspect.getfile(blip_decoder))
+ configs_path = os.path.join(os.path.dirname(blip_path), 'configs')
+ med_config = os.path.join(configs_path, 'med_config.json')
+ blip_model = blip_decoder(
+ pretrained=config.blip_model_url,
+ image_size=config.blip_image_eval_size,
+ vit='large',
+ med_config=med_config
+ )
+ blip_model.eval()
+ blip_model = blip_model.to(config.device)
+ self.blip_model = blip_model
+ else:
+ self.blip_model = config.blip_model
+
+ if config.clip_model is None:
+ print("Loading CLIP model...")
+ self.clip_model, self.clip_preprocess = clip.load(config.clip_model_name, device=config.device)
+ self.clip_model.to(config.device).eval()
+ else:
+ self.clip_model = config.clip_model
+ self.clip_preprocess = config.clip_preprocess
+
+ sites = ['Artstation', 'behance', 'cg society', 'cgsociety', 'deviantart', 'dribble', 'flickr', 'instagram', 'pexels', 'pinterest', 'pixabay', 'pixiv', 'polycount', 'reddit', 'shutterstock', 'tumblr', 'unsplash', 'zbrush central']
+        trending_list = list(sites)
+ trending_list.extend(["trending on "+site for site in sites])
+ trending_list.extend(["featured on "+site for site in sites])
+ trending_list.extend([site+" contest winner" for site in sites])
+
+ raw_artists = _load_list(config.data_path, 'artists.txt')
+ artists = [f"by {a}" for a in raw_artists]
+ artists.extend([f"inspired by {a}" for a in raw_artists])
+
+ self.artists = LabelTable(artists, "artists", self.clip_model, config)
+ self.flavors = LabelTable(_load_list(config.data_path, 'flavors.txt'), "flavors", self.clip_model, config)
+ self.mediums = LabelTable(_load_list(config.data_path, 'mediums.txt'), "mediums", self.clip_model, config)
+ self.movements = LabelTable(_load_list(config.data_path, 'movements.txt'), "movements", self.clip_model, config)
+ self.trendings = LabelTable(trending_list, "trendings", self.clip_model, config)
+
+    def generate_caption(self, pil_image: Image.Image) -> str:
+ size = self.config.blip_image_eval_size
+ gpu_image = transforms.Compose([
+ transforms.Resize((size, size), interpolation=InterpolationMode.BICUBIC),
+ transforms.ToTensor(),
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+ ])(pil_image).unsqueeze(0).to(self.device)
+
+ with torch.no_grad():
+ caption = self.blip_model.generate(
+ gpu_image,
+ sample=False,
+ num_beams=self.config.blip_num_beams,
+ max_length=self.config.blip_max_length,
+ min_length=5
+ )
+ return caption[0]
+
+    def interrogate(self, image: Image.Image) -> str:
+ caption = self.generate_caption(image)
+
+ images = self.clip_preprocess(image).unsqueeze(0).to(self.device)
+ with torch.no_grad():
+ image_features = self.clip_model.encode_image(images).float()
+ image_features /= image_features.norm(dim=-1, keepdim=True)
+
+ flaves = self.flavors.rank(image_features, self.config.flavor_intermediate_count)
+ best_medium = self.mediums.rank(image_features, 1)[0]
+ best_artist = self.artists.rank(image_features, 1)[0]
+ best_trending = self.trendings.rank(image_features, 1)[0]
+ best_movement = self.movements.rank(image_features, 1)[0]
+
+ best_prompt = caption
+ best_sim = self.similarity(image_features, best_prompt)
+
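+        # greedy check: keep a candidate addition only if it improves similarity over the current best prompt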
+ def check(addition: str) -> bool:
+ nonlocal best_prompt, best_sim
+ prompt = best_prompt + ", " + addition
+ sim = self.similarity(image_features, prompt)
+ if sim > best_sim:
+ best_sim = sim
+ best_prompt = prompt
+ return True
+ return False
+
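+        # build all 2^len(opts) combinations of the candidate additions and keep the best-scoring prompt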
+ def check_multi_batch(opts: List[str]):
+ nonlocal best_prompt, best_sim
+ prompts = []
+ for i in range(2**len(opts)):
+ prompt = best_prompt
+ for bit in range(len(opts)):
+ if i & (1 << bit):
+ prompt += ", " + opts[bit]
+ prompts.append(prompt)
+
+ t = LabelTable(prompts, None, self.clip_model, self.config)
+ best_prompt = t.rank(image_features, 1)[0]
+ best_sim = self.similarity(image_features, best_prompt)
+
+ check_multi_batch([best_medium, best_artist, best_trending, best_movement])
+
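+        # greedily chain flavors onto the prompt (up to 25) while each addition keeps improving similarity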
+ extended_flavors = set(flaves)
+ for _ in tqdm(range(25), desc="Flavor chain"):
+ try:
+ best = self.rank_top(image_features, [f"{best_prompt}, {f}" for f in extended_flavors])
+ flave = best[len(best_prompt)+2:]
+ if not check(flave):
+ break
+ extended_flavors.remove(flave)
+            except Exception:
+                # clip.tokenize raises once the prompt exceeds CLIP's maximum token length
+ break
+
+ return best_prompt
+
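+    # return the single text from text_array that CLIP scores as most similar to the image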
+ def rank_top(self, image_features, text_array: List[str]) -> str:
+        text_tokens = clip.tokenize(text_array).to(self.device)
+ with torch.no_grad():
+ text_features = self.clip_model.encode_text(text_tokens).float()
+ text_features /= text_features.norm(dim=-1, keepdim=True)
+
+ similarity = torch.zeros((1, len(text_array)), device=self.device)
+ for i in range(image_features.shape[0]):
+ similarity += (image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
+
+ _, top_labels = similarity.cpu().topk(1, dim=-1)
+ return text_array[top_labels[0][0].numpy()]
+
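+    # cosine similarity between the normalized image embedding and the embedding of a single text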
+ def similarity(self, image_features, text) -> np.float32:
+ text_tokens = clip.tokenize([text]).to(self.device)
+ with torch.no_grad():
+ text_features = self.clip_model.encode_text(text_tokens).float()
+ text_features /= text_features.norm(dim=-1, keepdim=True)
+ similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T
+ return similarity[0][0]
+
+
+class LabelTable():
+ def __init__(self, labels:List[str], desc:str, clip_model, config: Config):
+ self.chunk_size = config.chunk_size
+ self.device = config.device
+ self.labels = labels
+ self.embeds = []
+
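+        # cache text embeddings on disk, keyed by CLIP model name and a hash of the label list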
+ hash = hashlib.sha256(",".join(labels).encode()).hexdigest()
+
+ cache_filepath = None
+ if config.cache_path is not None and desc is not None:
+ os.makedirs(config.cache_path, exist_ok=True)
+ sanitized_name = config.clip_model_name.replace('/', '_').replace('@', '_')
+ cache_filepath = os.path.join(config.cache_path, f"{sanitized_name}_{desc}.pkl")
+ if desc is not None and os.path.exists(cache_filepath):
+ with open(cache_filepath, 'rb') as f:
+ data = pickle.load(f)
+ if data.get('hash') == hash:
+ self.labels = data['labels']
+ self.embeds = data['embeds']
+
+ if len(self.labels) != len(self.embeds):
+ self.embeds = []
+            chunks = np.array_split(self.labels, max(1, math.ceil(len(self.labels)/config.chunk_size)))
+ for chunk in tqdm(chunks, desc=f"Preprocessing {desc}" if desc else None):
+ text_tokens = clip.tokenize(chunk).to(self.device)
+ with torch.no_grad():
+ text_features = clip_model.encode_text(text_tokens).float()
+ text_features /= text_features.norm(dim=-1, keepdim=True)
+ text_features = text_features.half().cpu().numpy()
+ for i in range(text_features.shape[0]):
+ self.embeds.append(text_features[i])
+
+ if cache_filepath is not None:
+ with open(cache_filepath, 'wb') as f:
+ pickle.dump({
+ "labels": self.labels,
+ "embeds": self.embeds,
+ "hash": hash,
+ "model": config.clip_model_name
+ }, f)
+
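+    # score a batch of cached text embeddings against the image and return the indices of the top matches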
+ def _rank(self, image_features, text_embeds, top_count=1):
+ top_count = min(top_count, len(text_embeds))
+ similarity = torch.zeros((1, len(text_embeds))).to(self.device)
+ text_embeds = torch.stack([torch.from_numpy(t) for t in text_embeds]).float().to(self.device)
+ for i in range(image_features.shape[0]):
+ similarity += (image_features[i].unsqueeze(0) @ text_embeds.T).softmax(dim=-1)
+ _, top_labels = similarity.cpu().topk(top_count, dim=-1)
+ return [top_labels[0][i].numpy() for i in range(top_count)]
+
+ def rank(self, image_features, top_count=1) -> List[str]:
+ if len(self.labels) <= self.chunk_size:
+ tops = self._rank(image_features, self.embeds, top_count=top_count)
+ return [self.labels[i] for i in tops]
+
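+        # two-stage ranking for large tables: keep the best labels from each chunk, then re-rank the keepers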
+ num_chunks = int(math.ceil(len(self.labels)/self.chunk_size))
+ keep_per_chunk = int(self.chunk_size / num_chunks)
+
+ top_labels, top_embeds = [], []
+ for chunk_idx in tqdm(range(num_chunks)):
+ start = chunk_idx*self.chunk_size
+ stop = min(start+self.chunk_size, len(self.embeds))
+ tops = self._rank(image_features, self.embeds[start:stop], top_count=keep_per_chunk)
+ top_labels.extend([self.labels[start+i] for i in tops])
+ top_embeds.extend([self.embeds[start+i] for i in tops])
+
+ tops = self._rank(image_features, top_embeds, top_count=top_count)
+ return [top_labels[i] for i in tops]
diff --git a/main.py b/main.py
new file mode 100755
index 0000000..0ae9392
--- /dev/null
+++ b/main.py
@@ -0,0 +1,47 @@
+#!/usr/bin/env python3
+
+import argparse
+import clip
+import requests
+import torch
+
+from PIL import Image
+
+from clip_interrogator import CLIPInterrogator, Config
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-i', '--image', help='image file or url')
+ parser.add_argument('-c', '--clip', default='ViT-L/14', help='name of CLIP model to use')
+
+ args = parser.parse_args()
+ if not args.image:
+ parser.print_help()
+ exit(1)
+
+    # load image (Image.open raises on failure rather than returning a falsy value)
+    image_path = args.image
+    try:
+        if str(image_path).startswith('http://') or str(image_path).startswith('https://'):
+            image = Image.open(requests.get(image_path, stream=True).raw).convert('RGB')
+        else:
+            image = Image.open(image_path).convert('RGB')
+    except Exception as e:
+        print(f'Error opening image {image_path}: {e}')
+        exit(1)
+
+ # validate clip model name
+ if args.clip not in clip.available_models():
+ print(f"Could not find CLIP model {args.clip}!")
+ print(f" available models: {clip.available_models()}")
+ exit(1)
+
+ # generate a nice prompt
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ config = Config(device=device, clip_model_name=args.clip, data_path='data')
+ interrogator = CLIPInterrogator(config)
+ prompt = interrogator.interrogate(image)
+ print(prompt)
+
+if __name__ == "__main__":
+ main()
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..8fb214e
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,7 @@
+torch
+torchvision
+Pillow
+requests
+tqdm
+-e git+https://github.com/openai/CLIP.git@main#egg=clip
+-e git+https://github.com/pharmapsychotic/BLIP.git@lib#egg=blip