From 55b1770386f8c25778702532def9429d2f8bc07c Mon Sep 17 00:00:00 2001 From: pharmapsychotic Date: Sun, 13 Nov 2022 16:14:46 -0600 Subject: [PATCH] Update Replicate cog to use clip_interrogator library --- .gitignore | 1 + README.md | 3 - cog.yaml | 10 +- predict.py | 335 +---------------------------------------------------- 4 files changed, 8 insertions(+), 341 deletions(-) diff --git a/.gitignore b/.gitignore index b55406c..b5d2001 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ *.pyc +.cog/ .vscode/ cache/ clip-interrogator/ diff --git a/README.md b/README.md index b24e236..86e0502 100644 --- a/README.md +++ b/README.md @@ -8,9 +8,6 @@ Run Version 2 on Colab, HuggingFace, and Replicate! [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pharmapsychotic/clip-interrogator/blob/main/clip_interrogator.ipynb) [![Generic badge](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue.svg)](https://huggingface.co/spaces/pharma/CLIP-Interrogator) [![Replicate](https://replicate.com/cjwbw/clip-interrogator/badge)](https://replicate.com/cjwbw/clip-interrogator) -[![Replicate](https://replicate.com/cjwbw/clip-interrogator/badge)](https://replicate.com/cjwbw/clip-interrogator) - -
Version 1 still available in Colab for comparing different CLIP models diff --git a/cog.yaml b/cog.yaml index f9aed3f..c9a4bde 100644 --- a/cog.yaml +++ b/cog.yaml @@ -13,16 +13,8 @@ build: - "torch==1.11.0 --extra-index-url=https://download.pytorch.org/whl/cu113" - "torchvision==0.12.0 --extra-index-url=https://download.pytorch.org/whl/cu113" run: - - pip install -e git+https://github.com/pharmapsychotic/BLIP.git@main#egg=blip + - pip install -e git+https://github.com/pharmapsychotic/BLIP.git@lib#egg=blip - pip install -e git+https://github.com/openai/CLIP.git@main#egg=clip - - mkdir -p /root/.cache/clip && wget --output-document "/root/.cache/clip/RN50.pt" "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt" - - mkdir -p /root/.cache/clip && wget --output-document "/root/.cache/clip/RN101.pt" "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt" - - mkdir -p /root/.cache/clip && wget --output-document "/root/.cache/clip/RN50x4.pt" "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt" - - mkdir -p /root/.cache/clip && wget --output-document "/root/.cache/clip/RN50x16.pt" "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt" - - mkdir -p /root/.cache/clip && wget --output-document "/root/.cache/clip/RN50x64.pt" "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt" - - mkdir -p /root/.cache/clip && wget --output-document "/root/.cache/clip/ViT-B-32.pt" "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt" - - mkdir -p /root/.cache/clip && wget --output-document "/root/.cache/clip/ViT-B-16.pt" "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt" - mkdir -p /root/.cache/clip && wget --output-document "/root/.cache/clip/ViT-L-14.pt" "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt" - - mkdir -p /root/.cache/clip && wget --output-document "/root/.cache/clip/ViT-L-14-336px.pt" "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt" predict: "predict.py:Predictor" diff --git a/predict.py b/predict.py index 3a57625..c085797 100644 --- a/predict.py +++ b/predict.py @@ -1,341 +1,18 @@ import sys - -sys.path.append("src/clip") -sys.path.append("src/blip") - -import os -import hashlib -import math -import numpy as np -import pickle -from tqdm import tqdm from PIL import Image -import torch -from torchvision import transforms -from torchvision.transforms.functional import InterpolationMode -import clip -from models.blip import blip_decoder from cog import BasePredictor, Input, Path +sys.path.extend(["src/clip", "src/blip"]) -DATA_PATH = "data" -chunk_size = 2048 -flavor_intermediate_count = 2048 -blip_image_eval_size = 384 +from clip_interrogator import Interrogator, Config class Predictor(BasePredictor): def setup(self): - """Load the model into memory to make running multiple predictions efficient""" - - self.device = "cuda:0" - - print("Loading BLIP model...") - self.blip_model = blip_decoder( - pretrained="weights/model_large_caption.pth", # downloaded with wget https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth - image_size=blip_image_eval_size, - vit="large", - med_config="src/blip/configs/med_config.json", - ) - self.blip_model.eval() - self.blip_model = self.blip_model.to(self.device) - - print("Loading CLIP model...") - self.clip_models, self.clip_preprocess = {}, {} - for clip_model_name in [ - "ViT-B/32", - "ViT-B/16", - "ViT-L/14", - "ViT-L/14@336px", - "RN101", - "RN50", - "RN50x4", - "RN50x16", - "RN50x64", - ]: - ( - self.clip_models[clip_model_name], - self.clip_preprocess[clip_model_name], - ) = clip.load(clip_model_name, device=self.device) - self.clip_models[clip_model_name].cuda().eval() - - sites = [ - "Artstation", - "behance", - "cg society", - "cgsociety", - "deviantart", - "dribble", - "flickr", - "instagram", - "pexels", - "pinterest", - "pixabay", - "pixiv", - "polycount", - "reddit", - "shutterstock", - "tumblr", - "unsplash", - "zbrush central", - ] - self.trending_list = [site for site in sites] - self.trending_list.extend(["trending on " + site for site in sites]) - self.trending_list.extend(["featured on " + site for site in sites]) - self.trending_list.extend([site + " contest winner" for site in sites]) - raw_artists = load_list(f"{DATA_PATH}/artists.txt") - self.artists = [f"by {a}" for a in raw_artists] - self.artists.extend([f"inspired by {a}" for a in raw_artists]) + config = Config(device="cuda:0", clip_model_name='ViT-L/14') + self.ci = Interrogator(config) - def predict( - self, - image: Path = Input(description="Input image"), - clip_model_name: str = Input( - default="ViT-L/14", - choices=[ - "ViT-B/32", - "ViT-B/16", - "ViT-L/14", - "ViT-L/14@336px", - "RN101", - "RN50", - "RN50x4", - "RN50x16", - "RN50x64", - ], - description="Choose a clip model.", - ), - ) -> str: + def predict(self, image: Path = Input(description="Input image")) -> str: """Run a single prediction on the model""" - clip_model = self.clip_models[clip_model_name] - clip_preprocess = self.clip_preprocess[clip_model_name] - - artists = LabelTable(self.artists, "artists", clip_model_name, clip_model) - flavors = LabelTable( - load_list(f"{DATA_PATH}/flavors.txt"), - "flavors", - clip_model_name, - clip_model, - ) - mediums = LabelTable( - load_list(f"{DATA_PATH}/mediums.txt"), - "mediums", - clip_model_name, - clip_model, - ) - movements = LabelTable( - load_list(f"{DATA_PATH}/movements.txt"), - "movements", - clip_model_name, - clip_model, - ) - trendings = LabelTable( - self.trending_list, "trendings", clip_model_name, clip_model - ) - image = Image.open(str(image)).convert("RGB") - - labels = [flavors, mediums, artists, trendings, movements] - - prompt = interrogate( - image, - clip_model_name, - clip_preprocess, - clip_model, - self.blip_model, - *labels, - ) - - return prompt - - -class LabelTable: - def __init__(self, labels, desc, clip_model_name, clip_model): - self.labels = labels - self.embeds = [] - - hash = hashlib.sha256(",".join(labels).encode()).hexdigest() - - os.makedirs("./cache", exist_ok=True) - cache_filepath = f"./cache/{desc}.pkl" - if desc is not None and os.path.exists(cache_filepath): - with open(cache_filepath, "rb") as f: - data = pickle.load(f) - if data.get("hash") == hash and data.get("model") == clip_model_name: - self.labels = data["labels"] - self.embeds = data["embeds"] - - if len(self.labels) != len(self.embeds): - self.embeds = [] - chunks = np.array_split(self.labels, max(1, len(self.labels) / chunk_size)) - for chunk in tqdm(chunks, desc=f"Preprocessing {desc}" if desc else None): - text_tokens = clip.tokenize(chunk).cuda() - with torch.no_grad(): - text_features = clip_model.encode_text(text_tokens).float() - text_features /= text_features.norm(dim=-1, keepdim=True) - text_features = text_features.half().cpu().numpy() - for i in range(text_features.shape[0]): - self.embeds.append(text_features[i]) - - with open(cache_filepath, "wb") as f: - pickle.dump( - { - "labels": self.labels, - "embeds": self.embeds, - "hash": hash, - "model": clip_model_name, - }, - f, - ) - - def _rank(self, image_features, text_embeds, device="cuda", top_count=1): - top_count = min(top_count, len(text_embeds)) - similarity = torch.zeros((1, len(text_embeds))).to(device) - text_embeds = ( - torch.stack([torch.from_numpy(t) for t in text_embeds]).float().to(device) - ) - for i in range(image_features.shape[0]): - similarity += (image_features[i].unsqueeze(0) @ text_embeds.T).softmax( - dim=-1 - ) - _, top_labels = similarity.cpu().topk(top_count, dim=-1) - return [top_labels[0][i].numpy() for i in range(top_count)] - - def rank(self, image_features, top_count=1): - if len(self.labels) <= chunk_size: - tops = self._rank(image_features, self.embeds, top_count=top_count) - return [self.labels[i] for i in tops] - - num_chunks = int(math.ceil(len(self.labels) / chunk_size)) - keep_per_chunk = int(chunk_size / num_chunks) - - top_labels, top_embeds = [], [] - for chunk_idx in tqdm(range(num_chunks)): - start = chunk_idx * chunk_size - stop = min(start + chunk_size, len(self.embeds)) - tops = self._rank( - image_features, self.embeds[start:stop], top_count=keep_per_chunk - ) - top_labels.extend([self.labels[start + i] for i in tops]) - top_embeds.extend([self.embeds[start + i] for i in tops]) - - tops = self._rank(image_features, top_embeds, top_count=top_count) - return [top_labels[i] for i in tops] - - -def generate_caption(pil_image, blip_model, device="cuda"): - gpu_image = ( - transforms.Compose( - [ - transforms.Resize( - (blip_image_eval_size, blip_image_eval_size), - interpolation=InterpolationMode.BICUBIC, - ), - transforms.ToTensor(), - transforms.Normalize( - (0.48145466, 0.4578275, 0.40821073), - (0.26862954, 0.26130258, 0.27577711), - ), - ] - )(pil_image) - .unsqueeze(0) - .to(device) - ) - - with torch.no_grad(): - caption = blip_model.generate( - gpu_image, sample=False, num_beams=3, max_length=20, min_length=5 - ) - return caption[0] - - -def rank_top(image_features, text_array, clip_model, device="cuda"): - text_tokens = clip.tokenize([text for text in text_array]).cuda() - with torch.no_grad(): - text_features = clip_model.encode_text(text_tokens).float() - text_features /= text_features.norm(dim=-1, keepdim=True) - - similarity = torch.zeros((1, len(text_array)), device=device) - for i in range(image_features.shape[0]): - similarity += (image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1) - - _, top_labels = similarity.cpu().topk(1, dim=-1) - return text_array[top_labels[0][0].numpy()] - - -def similarity(image_features, text, clip_model): - text_tokens = clip.tokenize([text]).cuda() - with torch.no_grad(): - text_features = clip_model.encode_text(text_tokens).float() - text_features /= text_features.norm(dim=-1, keepdim=True) - similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T - return similarity[0][0] - - -def load_list(filename): - with open(filename, "r", encoding="utf-8", errors="replace") as f: - items = [line.strip() for line in f.readlines()] - return items - - -def interrogate(image, clip_model_name, clip_preprocess, clip_model, blip_model, *args): - flavors, mediums, artists, trendings, movements = args - caption = generate_caption(image, blip_model) - - images = clip_preprocess(image).unsqueeze(0).cuda() - with torch.no_grad(): - image_features = clip_model.encode_image(images).float() - image_features /= image_features.norm(dim=-1, keepdim=True) - - flaves = flavors.rank(image_features, flavor_intermediate_count) - best_medium = mediums.rank(image_features, 1)[0] - best_artist = artists.rank(image_features, 1)[0] - best_trending = trendings.rank(image_features, 1)[0] - best_movement = movements.rank(image_features, 1)[0] - - best_prompt = caption - best_sim = similarity(image_features, best_prompt, clip_model) - - def check(addition): - nonlocal best_prompt, best_sim - prompt = best_prompt + ", " + addition - sim = similarity(image_features, prompt, clip_model) - if sim > best_sim: - best_sim = sim - best_prompt = prompt - return True - return False - - def check_multi_batch(opts): - nonlocal best_prompt, best_sim - prompts = [] - for i in range(2 ** len(opts)): - prompt = best_prompt - for bit in range(len(opts)): - if i & (1 << bit): - prompt += ", " + opts[bit] - prompts.append(prompt) - - t = LabelTable(prompts, None, clip_model_name, clip_model) - best_prompt = t.rank(image_features, 1)[0] - best_sim = similarity(image_features, best_prompt, clip_model) - - check_multi_batch([best_medium, best_artist, best_trending, best_movement]) - - extended_flavors = set(flaves) - for _ in tqdm(range(25), desc="Flavor chain"): - try: - best = rank_top( - image_features, - [f"{best_prompt}, {f}" for f in extended_flavors], - clip_model, - ) - flave = best[len(best_prompt) + 2 :] - if not check(flave): - break - extended_flavors.remove(flave) - except: - # exceeded max prompt length - break - - return best_prompt + return self.ci.interrogate(image)