Browse Source

Update Replicate cog to use clip_interrogator library

replicate
pharmapsychotic 2 years ago
parent
commit
55b1770386
  1. 1
      .gitignore
  2. 3
      README.md
  3. 10
      cog.yaml
  4. 335
      predict.py

1
.gitignore vendored

@ -1,4 +1,5 @@
*.pyc
.cog/
.vscode/
cache/
clip-interrogator/

3
README.md

@ -8,9 +8,6 @@ Run Version 2 on Colab, HuggingFace, and Replicate!
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pharmapsychotic/clip-interrogator/blob/main/clip_interrogator.ipynb) [![Generic badge](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue.svg)](https://huggingface.co/spaces/pharma/CLIP-Interrogator) [![Replicate](https://replicate.com/cjwbw/clip-interrogator/badge)](https://replicate.com/cjwbw/clip-interrogator)
[![Replicate](https://replicate.com/cjwbw/clip-interrogator/badge)](https://replicate.com/cjwbw/clip-interrogator)
<br>
Version 1 still available in Colab for comparing different CLIP models

10
cog.yaml

@ -13,16 +13,8 @@ build:
- "torch==1.11.0 --extra-index-url=https://download.pytorch.org/whl/cu113"
- "torchvision==0.12.0 --extra-index-url=https://download.pytorch.org/whl/cu113"
run:
- pip install -e git+https://github.com/pharmapsychotic/BLIP.git@main#egg=blip
- pip install -e git+https://github.com/pharmapsychotic/BLIP.git@lib#egg=blip
- pip install -e git+https://github.com/openai/CLIP.git@main#egg=clip
- mkdir -p /root/.cache/clip && wget --output-document "/root/.cache/clip/RN50.pt" "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"
- mkdir -p /root/.cache/clip && wget --output-document "/root/.cache/clip/RN101.pt" "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"
- mkdir -p /root/.cache/clip && wget --output-document "/root/.cache/clip/RN50x4.pt" "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt"
- mkdir -p /root/.cache/clip && wget --output-document "/root/.cache/clip/RN50x16.pt" "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt"
- mkdir -p /root/.cache/clip && wget --output-document "/root/.cache/clip/RN50x64.pt" "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt"
- mkdir -p /root/.cache/clip && wget --output-document "/root/.cache/clip/ViT-B-32.pt" "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"
- mkdir -p /root/.cache/clip && wget --output-document "/root/.cache/clip/ViT-B-16.pt" "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"
- mkdir -p /root/.cache/clip && wget --output-document "/root/.cache/clip/ViT-L-14.pt" "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"
- mkdir -p /root/.cache/clip && wget --output-document "/root/.cache/clip/ViT-L-14-336px.pt" "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"
predict: "predict.py:Predictor"

335
predict.py

@ -1,341 +1,18 @@
import sys
sys.path.append("src/clip")
sys.path.append("src/blip")
import os
import hashlib
import math
import numpy as np
import pickle
from tqdm import tqdm
from PIL import Image
import torch
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
import clip
from models.blip import blip_decoder
from cog import BasePredictor, Input, Path
sys.path.extend(["src/clip", "src/blip"])
DATA_PATH = "data"
chunk_size = 2048
flavor_intermediate_count = 2048
blip_image_eval_size = 384
from clip_interrogator import Interrogator, Config
class Predictor(BasePredictor):
def setup(self):
"""Load the model into memory to make running multiple predictions efficient"""
self.device = "cuda:0"
print("Loading BLIP model...")
self.blip_model = blip_decoder(
pretrained="weights/model_large_caption.pth", # downloaded with wget https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth
image_size=blip_image_eval_size,
vit="large",
med_config="src/blip/configs/med_config.json",
)
self.blip_model.eval()
self.blip_model = self.blip_model.to(self.device)
print("Loading CLIP model...")
self.clip_models, self.clip_preprocess = {}, {}
for clip_model_name in [
"ViT-B/32",
"ViT-B/16",
"ViT-L/14",
"ViT-L/14@336px",
"RN101",
"RN50",
"RN50x4",
"RN50x16",
"RN50x64",
]:
(
self.clip_models[clip_model_name],
self.clip_preprocess[clip_model_name],
) = clip.load(clip_model_name, device=self.device)
self.clip_models[clip_model_name].cuda().eval()
sites = [
"Artstation",
"behance",
"cg society",
"cgsociety",
"deviantart",
"dribble",
"flickr",
"instagram",
"pexels",
"pinterest",
"pixabay",
"pixiv",
"polycount",
"reddit",
"shutterstock",
"tumblr",
"unsplash",
"zbrush central",
]
self.trending_list = [site for site in sites]
self.trending_list.extend(["trending on " + site for site in sites])
self.trending_list.extend(["featured on " + site for site in sites])
self.trending_list.extend([site + " contest winner" for site in sites])
raw_artists = load_list(f"{DATA_PATH}/artists.txt")
self.artists = [f"by {a}" for a in raw_artists]
self.artists.extend([f"inspired by {a}" for a in raw_artists])
config = Config(device="cuda:0", clip_model_name='ViT-L/14')
self.ci = Interrogator(config)
def predict(
self,
image: Path = Input(description="Input image"),
clip_model_name: str = Input(
default="ViT-L/14",
choices=[
"ViT-B/32",
"ViT-B/16",
"ViT-L/14",
"ViT-L/14@336px",
"RN101",
"RN50",
"RN50x4",
"RN50x16",
"RN50x64",
],
description="Choose a clip model.",
),
) -> str:
def predict(self, image: Path = Input(description="Input image")) -> str:
"""Run a single prediction on the model"""
clip_model = self.clip_models[clip_model_name]
clip_preprocess = self.clip_preprocess[clip_model_name]
artists = LabelTable(self.artists, "artists", clip_model_name, clip_model)
flavors = LabelTable(
load_list(f"{DATA_PATH}/flavors.txt"),
"flavors",
clip_model_name,
clip_model,
)
mediums = LabelTable(
load_list(f"{DATA_PATH}/mediums.txt"),
"mediums",
clip_model_name,
clip_model,
)
movements = LabelTable(
load_list(f"{DATA_PATH}/movements.txt"),
"movements",
clip_model_name,
clip_model,
)
trendings = LabelTable(
self.trending_list, "trendings", clip_model_name, clip_model
)
image = Image.open(str(image)).convert("RGB")
labels = [flavors, mediums, artists, trendings, movements]
prompt = interrogate(
image,
clip_model_name,
clip_preprocess,
clip_model,
self.blip_model,
*labels,
)
return prompt
class LabelTable:
def __init__(self, labels, desc, clip_model_name, clip_model):
self.labels = labels
self.embeds = []
hash = hashlib.sha256(",".join(labels).encode()).hexdigest()
os.makedirs("./cache", exist_ok=True)
cache_filepath = f"./cache/{desc}.pkl"
if desc is not None and os.path.exists(cache_filepath):
with open(cache_filepath, "rb") as f:
data = pickle.load(f)
if data.get("hash") == hash and data.get("model") == clip_model_name:
self.labels = data["labels"]
self.embeds = data["embeds"]
if len(self.labels) != len(self.embeds):
self.embeds = []
chunks = np.array_split(self.labels, max(1, len(self.labels) / chunk_size))
for chunk in tqdm(chunks, desc=f"Preprocessing {desc}" if desc else None):
text_tokens = clip.tokenize(chunk).cuda()
with torch.no_grad():
text_features = clip_model.encode_text(text_tokens).float()
text_features /= text_features.norm(dim=-1, keepdim=True)
text_features = text_features.half().cpu().numpy()
for i in range(text_features.shape[0]):
self.embeds.append(text_features[i])
with open(cache_filepath, "wb") as f:
pickle.dump(
{
"labels": self.labels,
"embeds": self.embeds,
"hash": hash,
"model": clip_model_name,
},
f,
)
def _rank(self, image_features, text_embeds, device="cuda", top_count=1):
top_count = min(top_count, len(text_embeds))
similarity = torch.zeros((1, len(text_embeds))).to(device)
text_embeds = (
torch.stack([torch.from_numpy(t) for t in text_embeds]).float().to(device)
)
for i in range(image_features.shape[0]):
similarity += (image_features[i].unsqueeze(0) @ text_embeds.T).softmax(
dim=-1
)
_, top_labels = similarity.cpu().topk(top_count, dim=-1)
return [top_labels[0][i].numpy() for i in range(top_count)]
def rank(self, image_features, top_count=1):
if len(self.labels) <= chunk_size:
tops = self._rank(image_features, self.embeds, top_count=top_count)
return [self.labels[i] for i in tops]
num_chunks = int(math.ceil(len(self.labels) / chunk_size))
keep_per_chunk = int(chunk_size / num_chunks)
top_labels, top_embeds = [], []
for chunk_idx in tqdm(range(num_chunks)):
start = chunk_idx * chunk_size
stop = min(start + chunk_size, len(self.embeds))
tops = self._rank(
image_features, self.embeds[start:stop], top_count=keep_per_chunk
)
top_labels.extend([self.labels[start + i] for i in tops])
top_embeds.extend([self.embeds[start + i] for i in tops])
tops = self._rank(image_features, top_embeds, top_count=top_count)
return [top_labels[i] for i in tops]
def generate_caption(pil_image, blip_model, device="cuda"):
gpu_image = (
transforms.Compose(
[
transforms.Resize(
(blip_image_eval_size, blip_image_eval_size),
interpolation=InterpolationMode.BICUBIC,
),
transforms.ToTensor(),
transforms.Normalize(
(0.48145466, 0.4578275, 0.40821073),
(0.26862954, 0.26130258, 0.27577711),
),
]
)(pil_image)
.unsqueeze(0)
.to(device)
)
with torch.no_grad():
caption = blip_model.generate(
gpu_image, sample=False, num_beams=3, max_length=20, min_length=5
)
return caption[0]
def rank_top(image_features, text_array, clip_model, device="cuda"):
text_tokens = clip.tokenize([text for text in text_array]).cuda()
with torch.no_grad():
text_features = clip_model.encode_text(text_tokens).float()
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = torch.zeros((1, len(text_array)), device=device)
for i in range(image_features.shape[0]):
similarity += (image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
_, top_labels = similarity.cpu().topk(1, dim=-1)
return text_array[top_labels[0][0].numpy()]
def similarity(image_features, text, clip_model):
text_tokens = clip.tokenize([text]).cuda()
with torch.no_grad():
text_features = clip_model.encode_text(text_tokens).float()
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T
return similarity[0][0]
def load_list(filename):
with open(filename, "r", encoding="utf-8", errors="replace") as f:
items = [line.strip() for line in f.readlines()]
return items
def interrogate(image, clip_model_name, clip_preprocess, clip_model, blip_model, *args):
flavors, mediums, artists, trendings, movements = args
caption = generate_caption(image, blip_model)
images = clip_preprocess(image).unsqueeze(0).cuda()
with torch.no_grad():
image_features = clip_model.encode_image(images).float()
image_features /= image_features.norm(dim=-1, keepdim=True)
flaves = flavors.rank(image_features, flavor_intermediate_count)
best_medium = mediums.rank(image_features, 1)[0]
best_artist = artists.rank(image_features, 1)[0]
best_trending = trendings.rank(image_features, 1)[0]
best_movement = movements.rank(image_features, 1)[0]
best_prompt = caption
best_sim = similarity(image_features, best_prompt, clip_model)
def check(addition):
nonlocal best_prompt, best_sim
prompt = best_prompt + ", " + addition
sim = similarity(image_features, prompt, clip_model)
if sim > best_sim:
best_sim = sim
best_prompt = prompt
return True
return False
def check_multi_batch(opts):
nonlocal best_prompt, best_sim
prompts = []
for i in range(2 ** len(opts)):
prompt = best_prompt
for bit in range(len(opts)):
if i & (1 << bit):
prompt += ", " + opts[bit]
prompts.append(prompt)
t = LabelTable(prompts, None, clip_model_name, clip_model)
best_prompt = t.rank(image_features, 1)[0]
best_sim = similarity(image_features, best_prompt, clip_model)
check_multi_batch([best_medium, best_artist, best_trending, best_movement])
extended_flavors = set(flaves)
for _ in tqdm(range(25), desc="Flavor chain"):
try:
best = rank_top(
image_features,
[f"{best_prompt}, {f}" for f in extended_flavors],
clip_model,
)
flave = best[len(best_prompt) + 2 :]
if not check(flave):
break
extended_flavors.remove(flave)
except:
# exceeded max prompt length
break
return best_prompt
return self.ci.interrogate(image)

Loading…
Cancel
Save