Browse Source

replicate demo

pull/11/head
Chenxi 2 years ago
parent
commit
11a0087004
  1. 2
      README.md
  2. 28
      cog.yaml
  3. 341
      predict.py

2
README.md

@ -4,6 +4,8 @@
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pharmapsychotic/clip-interrogator/blob/v1/clip_interrogator.ipynb) Version 1
[![Replicate](https://replicate.com/cjwbw/clip-interrogator/badge)](https://replicate.com/cjwbw/clip-interrogator)
The CLIP Interrogator uses the OpenAI CLIP models to test a given image against a variety of artists, mediums, and styles to study how the different models see the content of the image. It also combines the results with BLIP caption to suggest a text prompt to create more images similar to what was given.

28
cog.yaml

@ -0,0 +1,28 @@
build:
gpu: true
cuda: "11.3"
python_version: "3.8"
system_packages:
- "libgl1-mesa-glx"
- "libglib2.0-0"
python_packages:
- "ipython==8.4.0"
- "fairscale==0.4.12"
- "transformers==4.21.2"
- "ftfy==6.1.1"
- "torch==1.11.0 --extra-index-url=https://download.pytorch.org/whl/cu113"
- "torchvision==0.12.0 --extra-index-url=https://download.pytorch.org/whl/cu113"
run:
- pip install -e git+https://github.com/pharmapsychotic/BLIP.git@main#egg=blip
- pip install -e git+https://github.com/openai/CLIP.git@main#egg=clip
- mkdir -p /root/.cache/clip && wget --output-document "/root/.cache/clip/RN50.pt" "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"
- mkdir -p /root/.cache/clip && wget --output-document "/root/.cache/clip/RN101.pt" "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"
- mkdir -p /root/.cache/clip && wget --output-document "/root/.cache/clip/RN50x4.pt" "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt"
- mkdir -p /root/.cache/clip && wget --output-document "/root/.cache/clip/RN50x16.pt" "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt"
- mkdir -p /root/.cache/clip && wget --output-document "/root/.cache/clip/RN50x64.pt" "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt"
- mkdir -p /root/.cache/clip && wget --output-document "/root/.cache/clip/ViT-B-32.pt" "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"
- mkdir -p /root/.cache/clip && wget --output-document "/root/.cache/clip/ViT-B-16.pt" "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"
- mkdir -p /root/.cache/clip && wget --output-document "/root/.cache/clip/ViT-L-14.pt" "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"
- mkdir -p /root/.cache/clip && wget --output-document "/root/.cache/clip/ViT-L-14-336px.pt" "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"
predict: "predict.py:Predictor"

341
predict.py

@ -0,0 +1,341 @@
import sys
sys.path.append("src/clip")
sys.path.append("src/blip")
import os
import hashlib
import math
import numpy as np
import pickle
from tqdm import tqdm
from PIL import Image
import torch
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
import clip
from models.blip import blip_decoder
from cog import BasePredictor, Input, Path
DATA_PATH = "data"
chunk_size = 2048
flavor_intermediate_count = 2048
blip_image_eval_size = 384
class Predictor(BasePredictor):
def setup(self):
"""Load the model into memory to make running multiple predictions efficient"""
self.device = "cuda:0"
print("Loading BLIP model...")
self.blip_model = blip_decoder(
pretrained="weights/model_large_caption.pth", # downloaded with wget https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth
image_size=blip_image_eval_size,
vit="large",
med_config="src/blip/configs/med_config.json",
)
self.blip_model.eval()
self.blip_model = self.blip_model.to(self.device)
print("Loading CLIP model...")
self.clip_models, self.clip_preprocess = {}, {}
for clip_model_name in [
"ViT-B/32",
"ViT-B/16",
"ViT-L/14",
"ViT-L/14@336px",
"RN101",
"RN50",
"RN50x4",
"RN50x16",
"RN50x64",
]:
(
self.clip_models[clip_model_name],
self.clip_preprocess[clip_model_name],
) = clip.load(clip_model_name, device=self.device)
self.clip_models[clip_model_name].cuda().eval()
sites = [
"Artstation",
"behance",
"cg society",
"cgsociety",
"deviantart",
"dribble",
"flickr",
"instagram",
"pexels",
"pinterest",
"pixabay",
"pixiv",
"polycount",
"reddit",
"shutterstock",
"tumblr",
"unsplash",
"zbrush central",
]
self.trending_list = [site for site in sites]
self.trending_list.extend(["trending on " + site for site in sites])
self.trending_list.extend(["featured on " + site for site in sites])
self.trending_list.extend([site + " contest winner" for site in sites])
raw_artists = load_list(f"{DATA_PATH}/artists.txt")
self.artists = [f"by {a}" for a in raw_artists]
self.artists.extend([f"inspired by {a}" for a in raw_artists])
def predict(
self,
image: Path = Input(description="Input image"),
clip_model_name: str = Input(
default="ViT-L/14",
choices=[
"ViT-B/32",
"ViT-B/16",
"ViT-L/14",
"ViT-L/14@336px",
"RN101",
"RN50",
"RN50x4",
"RN50x16",
"RN50x64",
],
description="Choose a clip model.",
),
) -> str:
"""Run a single prediction on the model"""
clip_model = self.clip_models[clip_model_name]
clip_preprocess = self.clip_preprocess[clip_model_name]
artists = LabelTable(self.artists, "artists", clip_model_name, clip_model)
flavors = LabelTable(
load_list(f"{DATA_PATH}/flavors.txt"),
"flavors",
clip_model_name,
clip_model,
)
mediums = LabelTable(
load_list(f"{DATA_PATH}/mediums.txt"),
"mediums",
clip_model_name,
clip_model,
)
movements = LabelTable(
load_list(f"{DATA_PATH}/movements.txt"),
"movements",
clip_model_name,
clip_model,
)
trendings = LabelTable(
self.trending_list, "trendings", clip_model_name, clip_model
)
image = Image.open(str(image)).convert("RGB")
labels = [flavors, mediums, artists, trendings, movements]
prompt = interrogate(
image,
clip_model_name,
clip_preprocess,
clip_model,
self.blip_model,
*labels,
)
return prompt
class LabelTable:
def __init__(self, labels, desc, clip_model_name, clip_model):
self.labels = labels
self.embeds = []
hash = hashlib.sha256(",".join(labels).encode()).hexdigest()
os.makedirs("./cache", exist_ok=True)
cache_filepath = f"./cache/{desc}.pkl"
if desc is not None and os.path.exists(cache_filepath):
with open(cache_filepath, "rb") as f:
data = pickle.load(f)
if data.get("hash") == hash and data.get("model") == clip_model_name:
self.labels = data["labels"]
self.embeds = data["embeds"]
if len(self.labels) != len(self.embeds):
self.embeds = []
chunks = np.array_split(self.labels, max(1, len(self.labels) / chunk_size))
for chunk in tqdm(chunks, desc=f"Preprocessing {desc}" if desc else None):
text_tokens = clip.tokenize(chunk).cuda()
with torch.no_grad():
text_features = clip_model.encode_text(text_tokens).float()
text_features /= text_features.norm(dim=-1, keepdim=True)
text_features = text_features.half().cpu().numpy()
for i in range(text_features.shape[0]):
self.embeds.append(text_features[i])
with open(cache_filepath, "wb") as f:
pickle.dump(
{
"labels": self.labels,
"embeds": self.embeds,
"hash": hash,
"model": clip_model_name,
},
f,
)
def _rank(self, image_features, text_embeds, device="cuda", top_count=1):
top_count = min(top_count, len(text_embeds))
similarity = torch.zeros((1, len(text_embeds))).to(device)
text_embeds = (
torch.stack([torch.from_numpy(t) for t in text_embeds]).float().to(device)
)
for i in range(image_features.shape[0]):
similarity += (image_features[i].unsqueeze(0) @ text_embeds.T).softmax(
dim=-1
)
_, top_labels = similarity.cpu().topk(top_count, dim=-1)
return [top_labels[0][i].numpy() for i in range(top_count)]
def rank(self, image_features, top_count=1):
if len(self.labels) <= chunk_size:
tops = self._rank(image_features, self.embeds, top_count=top_count)
return [self.labels[i] for i in tops]
num_chunks = int(math.ceil(len(self.labels) / chunk_size))
keep_per_chunk = int(chunk_size / num_chunks)
top_labels, top_embeds = [], []
for chunk_idx in tqdm(range(num_chunks)):
start = chunk_idx * chunk_size
stop = min(start + chunk_size, len(self.embeds))
tops = self._rank(
image_features, self.embeds[start:stop], top_count=keep_per_chunk
)
top_labels.extend([self.labels[start + i] for i in tops])
top_embeds.extend([self.embeds[start + i] for i in tops])
tops = self._rank(image_features, top_embeds, top_count=top_count)
return [top_labels[i] for i in tops]
def generate_caption(pil_image, blip_model, device="cuda"):
gpu_image = (
transforms.Compose(
[
transforms.Resize(
(blip_image_eval_size, blip_image_eval_size),
interpolation=InterpolationMode.BICUBIC,
),
transforms.ToTensor(),
transforms.Normalize(
(0.48145466, 0.4578275, 0.40821073),
(0.26862954, 0.26130258, 0.27577711),
),
]
)(pil_image)
.unsqueeze(0)
.to(device)
)
with torch.no_grad():
caption = blip_model.generate(
gpu_image, sample=False, num_beams=3, max_length=20, min_length=5
)
return caption[0]
def rank_top(image_features, text_array, clip_model, device="cuda"):
text_tokens = clip.tokenize([text for text in text_array]).cuda()
with torch.no_grad():
text_features = clip_model.encode_text(text_tokens).float()
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = torch.zeros((1, len(text_array)), device=device)
for i in range(image_features.shape[0]):
similarity += (image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
_, top_labels = similarity.cpu().topk(1, dim=-1)
return text_array[top_labels[0][0].numpy()]
def similarity(image_features, text, clip_model):
text_tokens = clip.tokenize([text]).cuda()
with torch.no_grad():
text_features = clip_model.encode_text(text_tokens).float()
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T
return similarity[0][0]
def load_list(filename):
with open(filename, "r", encoding="utf-8", errors="replace") as f:
items = [line.strip() for line in f.readlines()]
return items
def interrogate(image, clip_model_name, clip_preprocess, clip_model, blip_model, *args):
flavors, mediums, artists, trendings, movements = args
caption = generate_caption(image, blip_model)
images = clip_preprocess(image).unsqueeze(0).cuda()
with torch.no_grad():
image_features = clip_model.encode_image(images).float()
image_features /= image_features.norm(dim=-1, keepdim=True)
flaves = flavors.rank(image_features, flavor_intermediate_count)
best_medium = mediums.rank(image_features, 1)[0]
best_artist = artists.rank(image_features, 1)[0]
best_trending = trendings.rank(image_features, 1)[0]
best_movement = movements.rank(image_features, 1)[0]
best_prompt = caption
best_sim = similarity(image_features, best_prompt, clip_model)
def check(addition):
nonlocal best_prompt, best_sim
prompt = best_prompt + ", " + addition
sim = similarity(image_features, prompt, clip_model)
if sim > best_sim:
best_sim = sim
best_prompt = prompt
return True
return False
def check_multi_batch(opts):
nonlocal best_prompt, best_sim
prompts = []
for i in range(2 ** len(opts)):
prompt = best_prompt
for bit in range(len(opts)):
if i & (1 << bit):
prompt += ", " + opts[bit]
prompts.append(prompt)
t = LabelTable(prompts, None, clip_model_name, clip_model)
best_prompt = t.rank(image_features, 1)[0]
best_sim = similarity(image_features, best_prompt, clip_model)
check_multi_batch([best_medium, best_artist, best_trending, best_movement])
extended_flavors = set(flaves)
for _ in tqdm(range(25), desc="Flavor chain"):
try:
best = rank_top(
image_features,
[f"{best_prompt}, {f}" for f in extended_flavors],
clip_model,
)
flave = best[len(best_prompt) + 2 :]
if not check(flave):
break
extended_flavors.remove(flave)
except:
# exceeded max prompt length
break
return best_prompt
Loading…
Cancel
Save