diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..984cf86 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +*.pyc +.vscode/ +cache/ +venv/ diff --git a/README.md b/README.md index ba8e3bd..72ef8a2 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,19 @@ # clip-interrogator -[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pharmapsychotic/clip-interrogator/blob/main/clip_interrogator.ipynb) Version 2 +Run Version 2 on Colab, HuggingFace, and Replicate! -[![Generic badge](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue.svg)](https://huggingface.co/spaces/pharma/CLIP-Interrogator) Version 2 +[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pharmapsychotic/clip-interrogator/blob/main/clip_interrogator.ipynb) [![Generic badge](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue.svg)](https://huggingface.co/spaces/pharma/CLIP-Interrogator) [![Replicate](https://replicate.com/cjwbw/clip-interrogator/badge)](https://replicate.com/cjwbw/clip-interrogator) -[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pharmapsychotic/clip-interrogator/blob/v1/clip_interrogator.ipynb) Version 1 -The CLIP Interrogator uses the OpenAI CLIP models to test a given image against a variety of artists, mediums, and styles to study how the different models see the content of the image. It also combines the results with BLIP caption to suggest a text prompt to create more images similar to what was given. +
+Version 1 still available in Colab for comparing different CLIP models + +[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pharmapsychotic/clip-interrogator/blob/v1/clip_interrogator.ipynb) + + +
+ +*Want to figure out what a good prompt might be to create new images like an existing one? The **CLIP Interrogator** is here to get you answers!* + +The **CLIP Interrogator** is a prompt engineering tool that combines OpenAI's [CLIP](https://openai.com/blog/clip/) and Salesforce's [BLIP](https://blog.salesforceairesearch.com/blip-bootstrapping-language-image-pretraining/) to optimize text prompts to match a given image. Use the resulting prompts with text-to-image models like Stable Diffusion. diff --git a/clip_interrogator/__init__.py b/clip_interrogator/__init__.py new file mode 100644 index 0000000..4f47c51 --- /dev/null +++ b/clip_interrogator/__init__.py @@ -0,0 +1 @@ +from .interrogate import CLIPInterrogator, Config, LabelTable \ No newline at end of file diff --git a/clip_interrogator/interrogate.py b/clip_interrogator/interrogate.py new file mode 100644 index 0000000..a61b847 --- /dev/null +++ b/clip_interrogator/interrogate.py @@ -0,0 +1,260 @@ +import clip +import hashlib +import inspect +import math +import numpy as np +import os +import pickle +import torch + +from dataclasses import dataclass +from models.blip import blip_decoder +from PIL import Image +from torchvision import transforms +from torchvision.transforms.functional import InterpolationMode +from tqdm import tqdm +from typing import List + + +@dataclass +class Config: + # models can optionally be passed in directly + blip_model = None + clip_model = None + clip_preprocess = None + + # blip settings + blip_image_eval_size: int = 384 + blip_max_length: int = 20 + blip_model_url: str = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth' + blip_num_beams: int = 3 + + # clip settings + clip_model_name: str = 'ViT-L/14' + + # interrogator settings + cache_path: str = 'cache' + chunk_size: int = 2048 + data_path: str = 'data' + device: str = 'cuda' if torch.cuda.is_available() else 'cpu' + flavor_intermediate_count: int = 2048 + + +def _load_list(data_path, filename) -> List[str]: + with open(os.path.join(data_path, filename), 'r', encoding='utf-8', errors='replace') as f: + items = [line.strip() for line in f.readlines()] + return items + + +class CLIPInterrogator(): + def __init__(self, config: Config): + self.config = config + self.device = config.device + + if config.blip_model is None: + print("Loading BLIP model...") + blip_path = os.path.dirname(inspect.getfile(blip_decoder)) + configs_path = os.path.join(os.path.dirname(blip_path), 'configs') + med_config = os.path.join(configs_path, 'med_config.json') + blip_model = blip_decoder( + pretrained=config.blip_model_url, + image_size=config.blip_image_eval_size, + vit='large', + med_config=med_config + ) + blip_model.eval() + blip_model = blip_model.to(config.device) + self.blip_model = blip_model + else: + self.blip_model = config.blip_model + + if config.clip_model is None: + print("Loading CLIP model...") + self.clip_model, self.clip_preprocess = clip.load(config.clip_model_name, device=config.device) + self.clip_model.to(config.device).eval() + else: + self.clip_model = config.clip_model + self.clip_preprocess = config.clip_preprocess + + sites = ['Artstation', 'behance', 'cg society', 'cgsociety', 'deviantart', 'dribble', 'flickr', 'instagram', 'pexels', 'pinterest', 'pixabay', 'pixiv', 'polycount', 'reddit', 'shutterstock', 'tumblr', 'unsplash', 'zbrush central'] + trending_list = [site for site in sites] + trending_list.extend(["trending on "+site for site in sites]) + trending_list.extend(["featured on "+site for site in sites]) + trending_list.extend([site+" contest winner" for site in sites]) + + raw_artists = _load_list(config.data_path, 'artists.txt') + artists = [f"by {a}" for a in raw_artists] + artists.extend([f"inspired by {a}" for a in raw_artists]) + + self.artists = LabelTable(artists, "artists", self.clip_model, config) + self.flavors = LabelTable(_load_list(config.data_path, 'flavors.txt'), "flavors", self.clip_model, config) + self.mediums = LabelTable(_load_list(config.data_path, 'mediums.txt'), "mediums", self.clip_model, config) + self.movements = LabelTable(_load_list(config.data_path, 'movements.txt'), "movements", self.clip_model, config) + self.trendings = LabelTable(trending_list, "trendings", self.clip_model, config) + + def generate_caption(self, pil_image: Image) -> str: + size = self.config.blip_image_eval_size + gpu_image = transforms.Compose([ + transforms.Resize((size, size), interpolation=InterpolationMode.BICUBIC), + transforms.ToTensor(), + transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) + ])(pil_image).unsqueeze(0).to(self.device) + + with torch.no_grad(): + caption = self.blip_model.generate( + gpu_image, + sample=False, + num_beams=self.config.blip_num_beams, + max_length=self.config.blip_max_length, + min_length=5 + ) + return caption[0] + + def interrogate(self, image: Image) -> str: + caption = self.generate_caption(image) + + images = self.clip_preprocess(image).unsqueeze(0).to(self.device) + with torch.no_grad(): + image_features = self.clip_model.encode_image(images).float() + image_features /= image_features.norm(dim=-1, keepdim=True) + + flaves = self.flavors.rank(image_features, self.config.flavor_intermediate_count) + best_medium = self.mediums.rank(image_features, 1)[0] + best_artist = self.artists.rank(image_features, 1)[0] + best_trending = self.trendings.rank(image_features, 1)[0] + best_movement = self.movements.rank(image_features, 1)[0] + + best_prompt = caption + best_sim = self.similarity(image_features, best_prompt) + + def check(addition: str) -> bool: + nonlocal best_prompt, best_sim + prompt = best_prompt + ", " + addition + sim = self.similarity(image_features, prompt) + if sim > best_sim: + best_sim = sim + best_prompt = prompt + return True + return False + + def check_multi_batch(opts: List[str]): + nonlocal best_prompt, best_sim + prompts = [] + for i in range(2**len(opts)): + prompt = best_prompt + for bit in range(len(opts)): + if i & (1 << bit): + prompt += ", " + opts[bit] + prompts.append(prompt) + + t = LabelTable(prompts, None, self.clip_model, self.config) + best_prompt = t.rank(image_features, 1)[0] + best_sim = self.similarity(image_features, best_prompt) + + check_multi_batch([best_medium, best_artist, best_trending, best_movement]) + + extended_flavors = set(flaves) + for _ in tqdm(range(25), desc="Flavor chain"): + try: + best = self.rank_top(image_features, [f"{best_prompt}, {f}" for f in extended_flavors]) + flave = best[len(best_prompt)+2:] + if not check(flave): + break + extended_flavors.remove(flave) + except: + # exceeded max prompt length + break + + return best_prompt + + def rank_top(self, image_features, text_array: List[str]) -> str: + text_tokens = clip.tokenize([text for text in text_array]).to(self.device) + with torch.no_grad(): + text_features = self.clip_model.encode_text(text_tokens).float() + text_features /= text_features.norm(dim=-1, keepdim=True) + + similarity = torch.zeros((1, len(text_array)), device=self.device) + for i in range(image_features.shape[0]): + similarity += (image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1) + + _, top_labels = similarity.cpu().topk(1, dim=-1) + return text_array[top_labels[0][0].numpy()] + + def similarity(self, image_features, text) -> np.float32: + text_tokens = clip.tokenize([text]).to(self.device) + with torch.no_grad(): + text_features = self.clip_model.encode_text(text_tokens).float() + text_features /= text_features.norm(dim=-1, keepdim=True) + similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T + return similarity[0][0] + + +class LabelTable(): + def __init__(self, labels:List[str], desc:str, clip_model, config: Config): + self.chunk_size = config.chunk_size + self.device = config.device + self.labels = labels + self.embeds = [] + + hash = hashlib.sha256(",".join(labels).encode()).hexdigest() + + cache_filepath = None + if config.cache_path is not None and desc is not None: + os.makedirs(config.cache_path, exist_ok=True) + sanitized_name = config.clip_model_name.replace('/', '_').replace('@', '_') + cache_filepath = os.path.join(config.cache_path, f"{sanitized_name}_{desc}.pkl") + if desc is not None and os.path.exists(cache_filepath): + with open(cache_filepath, 'rb') as f: + data = pickle.load(f) + if data.get('hash') == hash: + self.labels = data['labels'] + self.embeds = data['embeds'] + + if len(self.labels) != len(self.embeds): + self.embeds = [] + chunks = np.array_split(self.labels, max(1, len(self.labels)/config.chunk_size)) + for chunk in tqdm(chunks, desc=f"Preprocessing {desc}" if desc else None): + text_tokens = clip.tokenize(chunk).to(self.device) + with torch.no_grad(): + text_features = clip_model.encode_text(text_tokens).float() + text_features /= text_features.norm(dim=-1, keepdim=True) + text_features = text_features.half().cpu().numpy() + for i in range(text_features.shape[0]): + self.embeds.append(text_features[i]) + + if cache_filepath is not None: + with open(cache_filepath, 'wb') as f: + pickle.dump({ + "labels": self.labels, + "embeds": self.embeds, + "hash": hash, + "model": config.clip_model_name + }, f) + + def _rank(self, image_features, text_embeds, top_count=1): + top_count = min(top_count, len(text_embeds)) + similarity = torch.zeros((1, len(text_embeds))).to(self.device) + text_embeds = torch.stack([torch.from_numpy(t) for t in text_embeds]).float().to(self.device) + for i in range(image_features.shape[0]): + similarity += (image_features[i].unsqueeze(0) @ text_embeds.T).softmax(dim=-1) + _, top_labels = similarity.cpu().topk(top_count, dim=-1) + return [top_labels[0][i].numpy() for i in range(top_count)] + + def rank(self, image_features, top_count=1) -> List[str]: + if len(self.labels) <= self.chunk_size: + tops = self._rank(image_features, self.embeds, top_count=top_count) + return [self.labels[i] for i in tops] + + num_chunks = int(math.ceil(len(self.labels)/self.chunk_size)) + keep_per_chunk = int(self.chunk_size / num_chunks) + + top_labels, top_embeds = [], [] + for chunk_idx in tqdm(range(num_chunks)): + start = chunk_idx*self.chunk_size + stop = min(start+self.chunk_size, len(self.embeds)) + tops = self._rank(image_features, self.embeds[start:stop], top_count=keep_per_chunk) + top_labels.extend([self.labels[start+i] for i in tops]) + top_embeds.extend([self.embeds[start+i] for i in tops]) + + tops = self._rank(image_features, top_embeds, top_count=top_count) + return [top_labels[i] for i in tops] diff --git a/main.py b/main.py new file mode 100755 index 0000000..0ae9392 --- /dev/null +++ b/main.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 + +import argparse +import clip +import requests +import torch + +from PIL import Image + +from clip_interrogator import CLIPInterrogator, Config + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('-i', '--image', help='image file or url') + parser.add_argument('-c', '--clip', default='ViT-L/14', help='name of CLIP model to use') + + args = parser.parse_args() + if not args.image: + parser.print_help() + exit(1) + + # load image + image_path = args.image + if str(image_path).startswith('http://') or str(image_path).startswith('https://'): + image = Image.open(requests.get(image_path, stream=True).raw).convert('RGB') + else: + image = Image.open(image_path).convert('RGB') + if not image: + print(f'Error opening image {image_path}') + exit(1) + + # validate clip model name + if args.clip not in clip.available_models(): + print(f"Could not find CLIP model {args.clip}!") + print(f" available models: {clip.available_models()}") + exit(1) + + # generate a nice prompt + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + config = Config(device=device, clip_model_name=args.clip, data_path='data') + interrogator = CLIPInterrogator(config) + prompt = interrogator.interrogate(image) + print(prompt) + +if __name__ == "__main__": + main() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..8fb214e --- /dev/null +++ b/requirements.txt @@ -0,0 +1,7 @@ +torch +torchvision +Pillow +requests +tqdm +-e git+https://github.com/openai/CLIP.git@main#egg=clip +-e git+https://github.com/pharmapsychotic/BLIP.git@lib#egg=blip