6 changed files with 332 additions and 4 deletions
@ -1,10 +1,19 @@
|
||||
# clip-interrogator |
||||
|
||||
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pharmapsychotic/clip-interrogator/blob/main/clip_interrogator.ipynb) Version 2 |
||||
Run Version 2 on Colab, HuggingFace, and Replicate! |
||||
|
||||
[![Generic badge](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue.svg)](https://huggingface.co/spaces/pharma/CLIP-Interrogator) Version 2 |
||||
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pharmapsychotic/clip-interrogator/blob/main/clip_interrogator.ipynb) [![Generic badge](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue.svg)](https://huggingface.co/spaces/pharma/CLIP-Interrogator) [![Replicate](https://replicate.com/cjwbw/clip-interrogator/badge)](https://replicate.com/cjwbw/clip-interrogator) |
||||
|
||||
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pharmapsychotic/clip-interrogator/blob/v1/clip_interrogator.ipynb) Version 1 |
||||
|
||||
The CLIP Interrogator uses the OpenAI CLIP models to test a given image against a variety of artists, mediums, and styles to study how the different models see the content of the image. It also combines the results with BLIP caption to suggest a text prompt to create more images similar to what was given. |
||||
<br> |
||||
|
||||
Version 1 still available in Colab for comparing different CLIP models |
||||
|
||||
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pharmapsychotic/clip-interrogator/blob/v1/clip_interrogator.ipynb) |
||||
|
||||
|
||||
<br> |
||||
|
||||
*Want to figure out what a good prompt might be to create new images like an existing one? The **CLIP Interrogator** is here to get you answers!* |
||||
|
||||
The **CLIP Interrogator** is a prompt engineering tool that combines OpenAI's [CLIP](https://openai.com/blog/clip/) and Salesforce's [BLIP](https://blog.salesforceairesearch.com/blip-bootstrapping-language-image-pretraining/) to optimize text prompts to match a given image. Use the resulting prompts with text-to-image models like Stable Diffusion. |
||||
|
@ -0,0 +1 @@
|
||||
from .interrogate import CLIPInterrogator, Config, LabelTable |
@ -0,0 +1,260 @@
|
||||
import clip |
||||
import hashlib |
||||
import inspect |
||||
import math |
||||
import numpy as np |
||||
import os |
||||
import pickle |
||||
import torch |
||||
|
||||
from dataclasses import dataclass |
||||
from models.blip import blip_decoder |
||||
from PIL import Image |
||||
from torchvision import transforms |
||||
from torchvision.transforms.functional import InterpolationMode |
||||
from tqdm import tqdm |
||||
from typing import List |
||||
|
||||
|
||||
@dataclass |
||||
class Config: |
||||
# models can optionally be passed in directly |
||||
blip_model = None |
||||
clip_model = None |
||||
clip_preprocess = None |
||||
|
||||
# blip settings |
||||
blip_image_eval_size: int = 384 |
||||
blip_max_length: int = 20 |
||||
blip_model_url: str = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth' |
||||
blip_num_beams: int = 3 |
||||
|
||||
# clip settings |
||||
clip_model_name: str = 'ViT-L/14' |
||||
|
||||
# interrogator settings |
||||
cache_path: str = 'cache' |
||||
chunk_size: int = 2048 |
||||
data_path: str = 'data' |
||||
device: str = 'cuda' if torch.cuda.is_available() else 'cpu' |
||||
flavor_intermediate_count: int = 2048 |
||||
|
||||
|
||||
def _load_list(data_path, filename) -> List[str]: |
||||
with open(os.path.join(data_path, filename), 'r', encoding='utf-8', errors='replace') as f: |
||||
items = [line.strip() for line in f.readlines()] |
||||
return items |
||||
|
||||
|
||||
class CLIPInterrogator(): |
||||
def __init__(self, config: Config): |
||||
self.config = config |
||||
self.device = config.device |
||||
|
||||
if config.blip_model is None: |
||||
print("Loading BLIP model...") |
||||
blip_path = os.path.dirname(inspect.getfile(blip_decoder)) |
||||
configs_path = os.path.join(os.path.dirname(blip_path), 'configs') |
||||
med_config = os.path.join(configs_path, 'med_config.json') |
||||
blip_model = blip_decoder( |
||||
pretrained=config.blip_model_url, |
||||
image_size=config.blip_image_eval_size, |
||||
vit='large', |
||||
med_config=med_config |
||||
) |
||||
blip_model.eval() |
||||
blip_model = blip_model.to(config.device) |
||||
self.blip_model = blip_model |
||||
else: |
||||
self.blip_model = config.blip_model |
||||
|
||||
if config.clip_model is None: |
||||
print("Loading CLIP model...") |
||||
self.clip_model, self.clip_preprocess = clip.load(config.clip_model_name, device=config.device) |
||||
self.clip_model.to(config.device).eval() |
||||
else: |
||||
self.clip_model = config.clip_model |
||||
self.clip_preprocess = config.clip_preprocess |
||||
|
||||
sites = ['Artstation', 'behance', 'cg society', 'cgsociety', 'deviantart', 'dribble', 'flickr', 'instagram', 'pexels', 'pinterest', 'pixabay', 'pixiv', 'polycount', 'reddit', 'shutterstock', 'tumblr', 'unsplash', 'zbrush central'] |
||||
trending_list = [site for site in sites] |
||||
trending_list.extend(["trending on "+site for site in sites]) |
||||
trending_list.extend(["featured on "+site for site in sites]) |
||||
trending_list.extend([site+" contest winner" for site in sites]) |
||||
|
||||
raw_artists = _load_list(config.data_path, 'artists.txt') |
||||
artists = [f"by {a}" for a in raw_artists] |
||||
artists.extend([f"inspired by {a}" for a in raw_artists]) |
||||
|
||||
self.artists = LabelTable(artists, "artists", self.clip_model, config) |
||||
self.flavors = LabelTable(_load_list(config.data_path, 'flavors.txt'), "flavors", self.clip_model, config) |
||||
self.mediums = LabelTable(_load_list(config.data_path, 'mediums.txt'), "mediums", self.clip_model, config) |
||||
self.movements = LabelTable(_load_list(config.data_path, 'movements.txt'), "movements", self.clip_model, config) |
||||
self.trendings = LabelTable(trending_list, "trendings", self.clip_model, config) |
||||
|
||||
def generate_caption(self, pil_image: Image) -> str: |
||||
size = self.config.blip_image_eval_size |
||||
gpu_image = transforms.Compose([ |
||||
transforms.Resize((size, size), interpolation=InterpolationMode.BICUBIC), |
||||
transforms.ToTensor(), |
||||
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) |
||||
])(pil_image).unsqueeze(0).to(self.device) |
||||
|
||||
with torch.no_grad(): |
||||
caption = self.blip_model.generate( |
||||
gpu_image, |
||||
sample=False, |
||||
num_beams=self.config.blip_num_beams, |
||||
max_length=self.config.blip_max_length, |
||||
min_length=5 |
||||
) |
||||
return caption[0] |
||||
|
||||
def interrogate(self, image: Image) -> str: |
||||
caption = self.generate_caption(image) |
||||
|
||||
images = self.clip_preprocess(image).unsqueeze(0).to(self.device) |
||||
with torch.no_grad(): |
||||
image_features = self.clip_model.encode_image(images).float() |
||||
image_features /= image_features.norm(dim=-1, keepdim=True) |
||||
|
||||
flaves = self.flavors.rank(image_features, self.config.flavor_intermediate_count) |
||||
best_medium = self.mediums.rank(image_features, 1)[0] |
||||
best_artist = self.artists.rank(image_features, 1)[0] |
||||
best_trending = self.trendings.rank(image_features, 1)[0] |
||||
best_movement = self.movements.rank(image_features, 1)[0] |
||||
|
||||
best_prompt = caption |
||||
best_sim = self.similarity(image_features, best_prompt) |
||||
|
||||
def check(addition: str) -> bool: |
||||
nonlocal best_prompt, best_sim |
||||
prompt = best_prompt + ", " + addition |
||||
sim = self.similarity(image_features, prompt) |
||||
if sim > best_sim: |
||||
best_sim = sim |
||||
best_prompt = prompt |
||||
return True |
||||
return False |
||||
|
||||
def check_multi_batch(opts: List[str]): |
||||
nonlocal best_prompt, best_sim |
||||
prompts = [] |
||||
for i in range(2**len(opts)): |
||||
prompt = best_prompt |
||||
for bit in range(len(opts)): |
||||
if i & (1 << bit): |
||||
prompt += ", " + opts[bit] |
||||
prompts.append(prompt) |
||||
|
||||
t = LabelTable(prompts, None, self.clip_model, self.config) |
||||
best_prompt = t.rank(image_features, 1)[0] |
||||
best_sim = self.similarity(image_features, best_prompt) |
||||
|
||||
check_multi_batch([best_medium, best_artist, best_trending, best_movement]) |
||||
|
||||
extended_flavors = set(flaves) |
||||
for _ in tqdm(range(25), desc="Flavor chain"): |
||||
try: |
||||
best = self.rank_top(image_features, [f"{best_prompt}, {f}" for f in extended_flavors]) |
||||
flave = best[len(best_prompt)+2:] |
||||
if not check(flave): |
||||
break |
||||
extended_flavors.remove(flave) |
||||
except: |
||||
# exceeded max prompt length |
||||
break |
||||
|
||||
return best_prompt |
||||
|
||||
def rank_top(self, image_features, text_array: List[str]) -> str: |
||||
text_tokens = clip.tokenize([text for text in text_array]).to(self.device) |
||||
with torch.no_grad(): |
||||
text_features = self.clip_model.encode_text(text_tokens).float() |
||||
text_features /= text_features.norm(dim=-1, keepdim=True) |
||||
|
||||
similarity = torch.zeros((1, len(text_array)), device=self.device) |
||||
for i in range(image_features.shape[0]): |
||||
similarity += (image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1) |
||||
|
||||
_, top_labels = similarity.cpu().topk(1, dim=-1) |
||||
return text_array[top_labels[0][0].numpy()] |
||||
|
||||
def similarity(self, image_features, text) -> np.float32: |
||||
text_tokens = clip.tokenize([text]).to(self.device) |
||||
with torch.no_grad(): |
||||
text_features = self.clip_model.encode_text(text_tokens).float() |
||||
text_features /= text_features.norm(dim=-1, keepdim=True) |
||||
similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T |
||||
return similarity[0][0] |
||||
|
||||
|
||||
class LabelTable(): |
||||
def __init__(self, labels:List[str], desc:str, clip_model, config: Config): |
||||
self.chunk_size = config.chunk_size |
||||
self.device = config.device |
||||
self.labels = labels |
||||
self.embeds = [] |
||||
|
||||
hash = hashlib.sha256(",".join(labels).encode()).hexdigest() |
||||
|
||||
cache_filepath = None |
||||
if config.cache_path is not None and desc is not None: |
||||
os.makedirs(config.cache_path, exist_ok=True) |
||||
sanitized_name = config.clip_model_name.replace('/', '_').replace('@', '_') |
||||
cache_filepath = os.path.join(config.cache_path, f"{sanitized_name}_{desc}.pkl") |
||||
if desc is not None and os.path.exists(cache_filepath): |
||||
with open(cache_filepath, 'rb') as f: |
||||
data = pickle.load(f) |
||||
if data.get('hash') == hash: |
||||
self.labels = data['labels'] |
||||
self.embeds = data['embeds'] |
||||
|
||||
if len(self.labels) != len(self.embeds): |
||||
self.embeds = [] |
||||
chunks = np.array_split(self.labels, max(1, len(self.labels)/config.chunk_size)) |
||||
for chunk in tqdm(chunks, desc=f"Preprocessing {desc}" if desc else None): |
||||
text_tokens = clip.tokenize(chunk).to(self.device) |
||||
with torch.no_grad(): |
||||
text_features = clip_model.encode_text(text_tokens).float() |
||||
text_features /= text_features.norm(dim=-1, keepdim=True) |
||||
text_features = text_features.half().cpu().numpy() |
||||
for i in range(text_features.shape[0]): |
||||
self.embeds.append(text_features[i]) |
||||
|
||||
if cache_filepath is not None: |
||||
with open(cache_filepath, 'wb') as f: |
||||
pickle.dump({ |
||||
"labels": self.labels, |
||||
"embeds": self.embeds, |
||||
"hash": hash, |
||||
"model": config.clip_model_name |
||||
}, f) |
||||
|
||||
def _rank(self, image_features, text_embeds, top_count=1): |
||||
top_count = min(top_count, len(text_embeds)) |
||||
similarity = torch.zeros((1, len(text_embeds))).to(self.device) |
||||
text_embeds = torch.stack([torch.from_numpy(t) for t in text_embeds]).float().to(self.device) |
||||
for i in range(image_features.shape[0]): |
||||
similarity += (image_features[i].unsqueeze(0) @ text_embeds.T).softmax(dim=-1) |
||||
_, top_labels = similarity.cpu().topk(top_count, dim=-1) |
||||
return [top_labels[0][i].numpy() for i in range(top_count)] |
||||
|
||||
def rank(self, image_features, top_count=1) -> List[str]: |
||||
if len(self.labels) <= self.chunk_size: |
||||
tops = self._rank(image_features, self.embeds, top_count=top_count) |
||||
return [self.labels[i] for i in tops] |
||||
|
||||
num_chunks = int(math.ceil(len(self.labels)/self.chunk_size)) |
||||
keep_per_chunk = int(self.chunk_size / num_chunks) |
||||
|
||||
top_labels, top_embeds = [], [] |
||||
for chunk_idx in tqdm(range(num_chunks)): |
||||
start = chunk_idx*self.chunk_size |
||||
stop = min(start+self.chunk_size, len(self.embeds)) |
||||
tops = self._rank(image_features, self.embeds[start:stop], top_count=keep_per_chunk) |
||||
top_labels.extend([self.labels[start+i] for i in tops]) |
||||
top_embeds.extend([self.embeds[start+i] for i in tops]) |
||||
|
||||
tops = self._rank(image_features, top_embeds, top_count=top_count) |
||||
return [top_labels[i] for i in tops] |
@ -0,0 +1,47 @@
|
||||
#!/usr/bin/env python3 |
||||
|
||||
import argparse |
||||
import clip |
||||
import requests |
||||
import torch |
||||
|
||||
from PIL import Image |
||||
|
||||
from clip_interrogator import CLIPInterrogator, Config |
||||
|
||||
|
||||
def main(): |
||||
parser = argparse.ArgumentParser() |
||||
parser.add_argument('-i', '--image', help='image file or url') |
||||
parser.add_argument('-c', '--clip', default='ViT-L/14', help='name of CLIP model to use') |
||||
|
||||
args = parser.parse_args() |
||||
if not args.image: |
||||
parser.print_help() |
||||
exit(1) |
||||
|
||||
# load image |
||||
image_path = args.image |
||||
if str(image_path).startswith('http://') or str(image_path).startswith('https://'): |
||||
image = Image.open(requests.get(image_path, stream=True).raw).convert('RGB') |
||||
else: |
||||
image = Image.open(image_path).convert('RGB') |
||||
if not image: |
||||
print(f'Error opening image {image_path}') |
||||
exit(1) |
||||
|
||||
# validate clip model name |
||||
if args.clip not in clip.available_models(): |
||||
print(f"Could not find CLIP model {args.clip}!") |
||||
print(f" available models: {clip.available_models()}") |
||||
exit(1) |
||||
|
||||
# generate a nice prompt |
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
||||
config = Config(device=device, clip_model_name=args.clip, data_path='data') |
||||
interrogator = CLIPInterrogator(config) |
||||
prompt = interrogator.interrogate(image) |
||||
print(prompt) |
||||
|
||||
if __name__ == "__main__": |
||||
main() |
Loading…
Reference in new issue