Browse Source

Update Replicate cog to use clip_interrogator library

replicate
pharmapsychotic 2 years ago
parent
commit
55b1770386
  1. 1
      .gitignore
  2. 3
      README.md
  3. 10
      cog.yaml
  4. 335
      predict.py

1
.gitignore vendored

@ -1,4 +1,5 @@
*.pyc *.pyc
.cog/
.vscode/ .vscode/
cache/ cache/
clip-interrogator/ clip-interrogator/

3
README.md

@ -8,9 +8,6 @@ Run Version 2 on Colab, HuggingFace, and Replicate!
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pharmapsychotic/clip-interrogator/blob/main/clip_interrogator.ipynb) [![Generic badge](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue.svg)](https://huggingface.co/spaces/pharma/CLIP-Interrogator) [![Replicate](https://replicate.com/cjwbw/clip-interrogator/badge)](https://replicate.com/cjwbw/clip-interrogator) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pharmapsychotic/clip-interrogator/blob/main/clip_interrogator.ipynb) [![Generic badge](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue.svg)](https://huggingface.co/spaces/pharma/CLIP-Interrogator) [![Replicate](https://replicate.com/cjwbw/clip-interrogator/badge)](https://replicate.com/cjwbw/clip-interrogator)
[![Replicate](https://replicate.com/cjwbw/clip-interrogator/badge)](https://replicate.com/cjwbw/clip-interrogator)
<br> <br>
Version 1 still available in Colab for comparing different CLIP models Version 1 still available in Colab for comparing different CLIP models

10
cog.yaml

@ -13,16 +13,8 @@ build:
- "torch==1.11.0 --extra-index-url=https://download.pytorch.org/whl/cu113" - "torch==1.11.0 --extra-index-url=https://download.pytorch.org/whl/cu113"
- "torchvision==0.12.0 --extra-index-url=https://download.pytorch.org/whl/cu113" - "torchvision==0.12.0 --extra-index-url=https://download.pytorch.org/whl/cu113"
run: run:
- pip install -e git+https://github.com/pharmapsychotic/BLIP.git@main#egg=blip - pip install -e git+https://github.com/pharmapsychotic/BLIP.git@lib#egg=blip
- pip install -e git+https://github.com/openai/CLIP.git@main#egg=clip - pip install -e git+https://github.com/openai/CLIP.git@main#egg=clip
- mkdir -p /root/.cache/clip && wget --output-document "/root/.cache/clip/RN50.pt" "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"
- mkdir -p /root/.cache/clip && wget --output-document "/root/.cache/clip/RN101.pt" "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"
- mkdir -p /root/.cache/clip && wget --output-document "/root/.cache/clip/RN50x4.pt" "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt"
- mkdir -p /root/.cache/clip && wget --output-document "/root/.cache/clip/RN50x16.pt" "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt"
- mkdir -p /root/.cache/clip && wget --output-document "/root/.cache/clip/RN50x64.pt" "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt"
- mkdir -p /root/.cache/clip && wget --output-document "/root/.cache/clip/ViT-B-32.pt" "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"
- mkdir -p /root/.cache/clip && wget --output-document "/root/.cache/clip/ViT-B-16.pt" "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"
- mkdir -p /root/.cache/clip && wget --output-document "/root/.cache/clip/ViT-L-14.pt" "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt" - mkdir -p /root/.cache/clip && wget --output-document "/root/.cache/clip/ViT-L-14.pt" "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"
- mkdir -p /root/.cache/clip && wget --output-document "/root/.cache/clip/ViT-L-14-336px.pt" "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"
predict: "predict.py:Predictor" predict: "predict.py:Predictor"

335
predict.py

@ -1,341 +1,18 @@
import sys import sys
sys.path.append("src/clip")
sys.path.append("src/blip")
import os
import hashlib
import math
import numpy as np
import pickle
from tqdm import tqdm
from PIL import Image from PIL import Image
import torch
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
import clip
from models.blip import blip_decoder
from cog import BasePredictor, Input, Path from cog import BasePredictor, Input, Path
sys.path.extend(["src/clip", "src/blip"])
DATA_PATH = "data" from clip_interrogator import Interrogator, Config
chunk_size = 2048
flavor_intermediate_count = 2048
blip_image_eval_size = 384
class Predictor(BasePredictor): class Predictor(BasePredictor):
def setup(self): def setup(self):
"""Load the model into memory to make running multiple predictions efficient""" config = Config(device="cuda:0", clip_model_name='ViT-L/14')
self.ci = Interrogator(config)
self.device = "cuda:0"
print("Loading BLIP model...")
self.blip_model = blip_decoder(
pretrained="weights/model_large_caption.pth", # downloaded with wget https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth
image_size=blip_image_eval_size,
vit="large",
med_config="src/blip/configs/med_config.json",
)
self.blip_model.eval()
self.blip_model = self.blip_model.to(self.device)
print("Loading CLIP model...")
self.clip_models, self.clip_preprocess = {}, {}
for clip_model_name in [
"ViT-B/32",
"ViT-B/16",
"ViT-L/14",
"ViT-L/14@336px",
"RN101",
"RN50",
"RN50x4",
"RN50x16",
"RN50x64",
]:
(
self.clip_models[clip_model_name],
self.clip_preprocess[clip_model_name],
) = clip.load(clip_model_name, device=self.device)
self.clip_models[clip_model_name].cuda().eval()
sites = [
"Artstation",
"behance",
"cg society",
"cgsociety",
"deviantart",
"dribble",
"flickr",
"instagram",
"pexels",
"pinterest",
"pixabay",
"pixiv",
"polycount",
"reddit",
"shutterstock",
"tumblr",
"unsplash",
"zbrush central",
]
self.trending_list = [site for site in sites]
self.trending_list.extend(["trending on " + site for site in sites])
self.trending_list.extend(["featured on " + site for site in sites])
self.trending_list.extend([site + " contest winner" for site in sites])
raw_artists = load_list(f"{DATA_PATH}/artists.txt")
self.artists = [f"by {a}" for a in raw_artists]
self.artists.extend([f"inspired by {a}" for a in raw_artists])
def predict( def predict(self, image: Path = Input(description="Input image")) -> str:
self,
image: Path = Input(description="Input image"),
clip_model_name: str = Input(
default="ViT-L/14",
choices=[
"ViT-B/32",
"ViT-B/16",
"ViT-L/14",
"ViT-L/14@336px",
"RN101",
"RN50",
"RN50x4",
"RN50x16",
"RN50x64",
],
description="Choose a clip model.",
),
) -> str:
"""Run a single prediction on the model""" """Run a single prediction on the model"""
clip_model = self.clip_models[clip_model_name]
clip_preprocess = self.clip_preprocess[clip_model_name]
artists = LabelTable(self.artists, "artists", clip_model_name, clip_model)
flavors = LabelTable(
load_list(f"{DATA_PATH}/flavors.txt"),
"flavors",
clip_model_name,
clip_model,
)
mediums = LabelTable(
load_list(f"{DATA_PATH}/mediums.txt"),
"mediums",
clip_model_name,
clip_model,
)
movements = LabelTable(
load_list(f"{DATA_PATH}/movements.txt"),
"movements",
clip_model_name,
clip_model,
)
trendings = LabelTable(
self.trending_list, "trendings", clip_model_name, clip_model
)
image = Image.open(str(image)).convert("RGB") image = Image.open(str(image)).convert("RGB")
return self.ci.interrogate(image)
labels = [flavors, mediums, artists, trendings, movements]
prompt = interrogate(
image,
clip_model_name,
clip_preprocess,
clip_model,
self.blip_model,
*labels,
)
return prompt
class LabelTable:
def __init__(self, labels, desc, clip_model_name, clip_model):
self.labels = labels
self.embeds = []
hash = hashlib.sha256(",".join(labels).encode()).hexdigest()
os.makedirs("./cache", exist_ok=True)
cache_filepath = f"./cache/{desc}.pkl"
if desc is not None and os.path.exists(cache_filepath):
with open(cache_filepath, "rb") as f:
data = pickle.load(f)
if data.get("hash") == hash and data.get("model") == clip_model_name:
self.labels = data["labels"]
self.embeds = data["embeds"]
if len(self.labels) != len(self.embeds):
self.embeds = []
chunks = np.array_split(self.labels, max(1, len(self.labels) / chunk_size))
for chunk in tqdm(chunks, desc=f"Preprocessing {desc}" if desc else None):
text_tokens = clip.tokenize(chunk).cuda()
with torch.no_grad():
text_features = clip_model.encode_text(text_tokens).float()
text_features /= text_features.norm(dim=-1, keepdim=True)
text_features = text_features.half().cpu().numpy()
for i in range(text_features.shape[0]):
self.embeds.append(text_features[i])
with open(cache_filepath, "wb") as f:
pickle.dump(
{
"labels": self.labels,
"embeds": self.embeds,
"hash": hash,
"model": clip_model_name,
},
f,
)
def _rank(self, image_features, text_embeds, device="cuda", top_count=1):
top_count = min(top_count, len(text_embeds))
similarity = torch.zeros((1, len(text_embeds))).to(device)
text_embeds = (
torch.stack([torch.from_numpy(t) for t in text_embeds]).float().to(device)
)
for i in range(image_features.shape[0]):
similarity += (image_features[i].unsqueeze(0) @ text_embeds.T).softmax(
dim=-1
)
_, top_labels = similarity.cpu().topk(top_count, dim=-1)
return [top_labels[0][i].numpy() for i in range(top_count)]
def rank(self, image_features, top_count=1):
if len(self.labels) <= chunk_size:
tops = self._rank(image_features, self.embeds, top_count=top_count)
return [self.labels[i] for i in tops]
num_chunks = int(math.ceil(len(self.labels) / chunk_size))
keep_per_chunk = int(chunk_size / num_chunks)
top_labels, top_embeds = [], []
for chunk_idx in tqdm(range(num_chunks)):
start = chunk_idx * chunk_size
stop = min(start + chunk_size, len(self.embeds))
tops = self._rank(
image_features, self.embeds[start:stop], top_count=keep_per_chunk
)
top_labels.extend([self.labels[start + i] for i in tops])
top_embeds.extend([self.embeds[start + i] for i in tops])
tops = self._rank(image_features, top_embeds, top_count=top_count)
return [top_labels[i] for i in tops]
def generate_caption(pil_image, blip_model, device="cuda"):
gpu_image = (
transforms.Compose(
[
transforms.Resize(
(blip_image_eval_size, blip_image_eval_size),
interpolation=InterpolationMode.BICUBIC,
),
transforms.ToTensor(),
transforms.Normalize(
(0.48145466, 0.4578275, 0.40821073),
(0.26862954, 0.26130258, 0.27577711),
),
]
)(pil_image)
.unsqueeze(0)
.to(device)
)
with torch.no_grad():
caption = blip_model.generate(
gpu_image, sample=False, num_beams=3, max_length=20, min_length=5
)
return caption[0]
def rank_top(image_features, text_array, clip_model, device="cuda"):
text_tokens = clip.tokenize([text for text in text_array]).cuda()
with torch.no_grad():
text_features = clip_model.encode_text(text_tokens).float()
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = torch.zeros((1, len(text_array)), device=device)
for i in range(image_features.shape[0]):
similarity += (image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
_, top_labels = similarity.cpu().topk(1, dim=-1)
return text_array[top_labels[0][0].numpy()]
def similarity(image_features, text, clip_model):
text_tokens = clip.tokenize([text]).cuda()
with torch.no_grad():
text_features = clip_model.encode_text(text_tokens).float()
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T
return similarity[0][0]
def load_list(filename):
with open(filename, "r", encoding="utf-8", errors="replace") as f:
items = [line.strip() for line in f.readlines()]
return items
def interrogate(image, clip_model_name, clip_preprocess, clip_model, blip_model, *args):
flavors, mediums, artists, trendings, movements = args
caption = generate_caption(image, blip_model)
images = clip_preprocess(image).unsqueeze(0).cuda()
with torch.no_grad():
image_features = clip_model.encode_image(images).float()
image_features /= image_features.norm(dim=-1, keepdim=True)
flaves = flavors.rank(image_features, flavor_intermediate_count)
best_medium = mediums.rank(image_features, 1)[0]
best_artist = artists.rank(image_features, 1)[0]
best_trending = trendings.rank(image_features, 1)[0]
best_movement = movements.rank(image_features, 1)[0]
best_prompt = caption
best_sim = similarity(image_features, best_prompt, clip_model)
def check(addition):
nonlocal best_prompt, best_sim
prompt = best_prompt + ", " + addition
sim = similarity(image_features, prompt, clip_model)
if sim > best_sim:
best_sim = sim
best_prompt = prompt
return True
return False
def check_multi_batch(opts):
nonlocal best_prompt, best_sim
prompts = []
for i in range(2 ** len(opts)):
prompt = best_prompt
for bit in range(len(opts)):
if i & (1 << bit):
prompt += ", " + opts[bit]
prompts.append(prompt)
t = LabelTable(prompts, None, clip_model_name, clip_model)
best_prompt = t.rank(image_features, 1)[0]
best_sim = similarity(image_features, best_prompt, clip_model)
check_multi_batch([best_medium, best_artist, best_trending, best_movement])
extended_flavors = set(flaves)
for _ in tqdm(range(25), desc="Flavor chain"):
try:
best = rank_top(
image_features,
[f"{best_prompt}, {f}" for f in extended_flavors],
clip_model,
)
flave = best[len(best_prompt) + 2 :]
if not check(flave):
break
extended_flavors.remove(flave)
except:
# exceeded max prompt length
break
return best_prompt

Loading…
Cancel
Save