3 changed files with 371 additions and 0 deletions
@ -0,0 +1,28 @@ |
|||||||
|
build: |
||||||
|
gpu: true |
||||||
|
cuda: "11.3" |
||||||
|
python_version: "3.8" |
||||||
|
system_packages: |
||||||
|
- "libgl1-mesa-glx" |
||||||
|
- "libglib2.0-0" |
||||||
|
python_packages: |
||||||
|
- "ipython==8.4.0" |
||||||
|
- "fairscale==0.4.12" |
||||||
|
- "transformers==4.21.2" |
||||||
|
- "ftfy==6.1.1" |
||||||
|
- "torch==1.11.0 --extra-index-url=https://download.pytorch.org/whl/cu113" |
||||||
|
- "torchvision==0.12.0 --extra-index-url=https://download.pytorch.org/whl/cu113" |
||||||
|
run: |
||||||
|
- pip install -e git+https://github.com/pharmapsychotic/BLIP.git@main#egg=blip |
||||||
|
- pip install -e git+https://github.com/openai/CLIP.git@main#egg=clip |
||||||
|
- mkdir -p /root/.cache/clip && wget --output-document "/root/.cache/clip/RN50.pt" "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt" |
||||||
|
- mkdir -p /root/.cache/clip && wget --output-document "/root/.cache/clip/RN101.pt" "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt" |
||||||
|
- mkdir -p /root/.cache/clip && wget --output-document "/root/.cache/clip/RN50x4.pt" "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt" |
||||||
|
- mkdir -p /root/.cache/clip && wget --output-document "/root/.cache/clip/RN50x16.pt" "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt" |
||||||
|
- mkdir -p /root/.cache/clip && wget --output-document "/root/.cache/clip/RN50x64.pt" "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt" |
||||||
|
- mkdir -p /root/.cache/clip && wget --output-document "/root/.cache/clip/ViT-B-32.pt" "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt" |
||||||
|
- mkdir -p /root/.cache/clip && wget --output-document "/root/.cache/clip/ViT-B-16.pt" "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt" |
||||||
|
- mkdir -p /root/.cache/clip && wget --output-document "/root/.cache/clip/ViT-L-14.pt" "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt" |
||||||
|
- mkdir -p /root/.cache/clip && wget --output-document "/root/.cache/clip/ViT-L-14-336px.pt" "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt" |
||||||
|
|
||||||
|
predict: "predict.py:Predictor" |
@ -0,0 +1,341 @@ |
|||||||
|
import sys |
||||||
|
|
||||||
|
sys.path.append("src/clip") |
||||||
|
sys.path.append("src/blip") |
||||||
|
|
||||||
|
import os |
||||||
|
import hashlib |
||||||
|
import math |
||||||
|
import numpy as np |
||||||
|
import pickle |
||||||
|
from tqdm import tqdm |
||||||
|
from PIL import Image |
||||||
|
import torch |
||||||
|
from torchvision import transforms |
||||||
|
from torchvision.transforms.functional import InterpolationMode |
||||||
|
import clip |
||||||
|
from models.blip import blip_decoder |
||||||
|
from cog import BasePredictor, Input, Path |
||||||
|
|
||||||
|
|
||||||
|
DATA_PATH = "data" |
||||||
|
chunk_size = 2048 |
||||||
|
flavor_intermediate_count = 2048 |
||||||
|
blip_image_eval_size = 384 |
||||||
|
|
||||||
|
|
||||||
|
class Predictor(BasePredictor): |
||||||
|
def setup(self): |
||||||
|
"""Load the model into memory to make running multiple predictions efficient""" |
||||||
|
|
||||||
|
self.device = "cuda:0" |
||||||
|
|
||||||
|
print("Loading BLIP model...") |
||||||
|
self.blip_model = blip_decoder( |
||||||
|
pretrained="weights/model_large_caption.pth", # downloaded with wget https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth |
||||||
|
image_size=blip_image_eval_size, |
||||||
|
vit="large", |
||||||
|
med_config="src/blip/configs/med_config.json", |
||||||
|
) |
||||||
|
self.blip_model.eval() |
||||||
|
self.blip_model = self.blip_model.to(self.device) |
||||||
|
|
||||||
|
print("Loading CLIP model...") |
||||||
|
self.clip_models, self.clip_preprocess = {}, {} |
||||||
|
for clip_model_name in [ |
||||||
|
"ViT-B/32", |
||||||
|
"ViT-B/16", |
||||||
|
"ViT-L/14", |
||||||
|
"ViT-L/14@336px", |
||||||
|
"RN101", |
||||||
|
"RN50", |
||||||
|
"RN50x4", |
||||||
|
"RN50x16", |
||||||
|
"RN50x64", |
||||||
|
]: |
||||||
|
( |
||||||
|
self.clip_models[clip_model_name], |
||||||
|
self.clip_preprocess[clip_model_name], |
||||||
|
) = clip.load(clip_model_name, device=self.device) |
||||||
|
self.clip_models[clip_model_name].cuda().eval() |
||||||
|
|
||||||
|
sites = [ |
||||||
|
"Artstation", |
||||||
|
"behance", |
||||||
|
"cg society", |
||||||
|
"cgsociety", |
||||||
|
"deviantart", |
||||||
|
"dribble", |
||||||
|
"flickr", |
||||||
|
"instagram", |
||||||
|
"pexels", |
||||||
|
"pinterest", |
||||||
|
"pixabay", |
||||||
|
"pixiv", |
||||||
|
"polycount", |
||||||
|
"reddit", |
||||||
|
"shutterstock", |
||||||
|
"tumblr", |
||||||
|
"unsplash", |
||||||
|
"zbrush central", |
||||||
|
] |
||||||
|
self.trending_list = [site for site in sites] |
||||||
|
self.trending_list.extend(["trending on " + site for site in sites]) |
||||||
|
self.trending_list.extend(["featured on " + site for site in sites]) |
||||||
|
self.trending_list.extend([site + " contest winner" for site in sites]) |
||||||
|
raw_artists = load_list(f"{DATA_PATH}/artists.txt") |
||||||
|
self.artists = [f"by {a}" for a in raw_artists] |
||||||
|
self.artists.extend([f"inspired by {a}" for a in raw_artists]) |
||||||
|
|
||||||
|
def predict( |
||||||
|
self, |
||||||
|
image: Path = Input(description="Input image"), |
||||||
|
clip_model_name: str = Input( |
||||||
|
default="ViT-L/14", |
||||||
|
choices=[ |
||||||
|
"ViT-B/32", |
||||||
|
"ViT-B/16", |
||||||
|
"ViT-L/14", |
||||||
|
"ViT-L/14@336px", |
||||||
|
"RN101", |
||||||
|
"RN50", |
||||||
|
"RN50x4", |
||||||
|
"RN50x16", |
||||||
|
"RN50x64", |
||||||
|
], |
||||||
|
description="Choose a clip model.", |
||||||
|
), |
||||||
|
) -> str: |
||||||
|
"""Run a single prediction on the model""" |
||||||
|
clip_model = self.clip_models[clip_model_name] |
||||||
|
clip_preprocess = self.clip_preprocess[clip_model_name] |
||||||
|
|
||||||
|
artists = LabelTable(self.artists, "artists", clip_model_name, clip_model) |
||||||
|
flavors = LabelTable( |
||||||
|
load_list(f"{DATA_PATH}/flavors.txt"), |
||||||
|
"flavors", |
||||||
|
clip_model_name, |
||||||
|
clip_model, |
||||||
|
) |
||||||
|
mediums = LabelTable( |
||||||
|
load_list(f"{DATA_PATH}/mediums.txt"), |
||||||
|
"mediums", |
||||||
|
clip_model_name, |
||||||
|
clip_model, |
||||||
|
) |
||||||
|
movements = LabelTable( |
||||||
|
load_list(f"{DATA_PATH}/movements.txt"), |
||||||
|
"movements", |
||||||
|
clip_model_name, |
||||||
|
clip_model, |
||||||
|
) |
||||||
|
trendings = LabelTable( |
||||||
|
self.trending_list, "trendings", clip_model_name, clip_model |
||||||
|
) |
||||||
|
|
||||||
|
image = Image.open(str(image)).convert("RGB") |
||||||
|
|
||||||
|
labels = [flavors, mediums, artists, trendings, movements] |
||||||
|
|
||||||
|
prompt = interrogate( |
||||||
|
image, |
||||||
|
clip_model_name, |
||||||
|
clip_preprocess, |
||||||
|
clip_model, |
||||||
|
self.blip_model, |
||||||
|
*labels, |
||||||
|
) |
||||||
|
|
||||||
|
return prompt |
||||||
|
|
||||||
|
|
||||||
|
class LabelTable: |
||||||
|
def __init__(self, labels, desc, clip_model_name, clip_model): |
||||||
|
self.labels = labels |
||||||
|
self.embeds = [] |
||||||
|
|
||||||
|
hash = hashlib.sha256(",".join(labels).encode()).hexdigest() |
||||||
|
|
||||||
|
os.makedirs("./cache", exist_ok=True) |
||||||
|
cache_filepath = f"./cache/{desc}.pkl" |
||||||
|
if desc is not None and os.path.exists(cache_filepath): |
||||||
|
with open(cache_filepath, "rb") as f: |
||||||
|
data = pickle.load(f) |
||||||
|
if data.get("hash") == hash and data.get("model") == clip_model_name: |
||||||
|
self.labels = data["labels"] |
||||||
|
self.embeds = data["embeds"] |
||||||
|
|
||||||
|
if len(self.labels) != len(self.embeds): |
||||||
|
self.embeds = [] |
||||||
|
chunks = np.array_split(self.labels, max(1, len(self.labels) / chunk_size)) |
||||||
|
for chunk in tqdm(chunks, desc=f"Preprocessing {desc}" if desc else None): |
||||||
|
text_tokens = clip.tokenize(chunk).cuda() |
||||||
|
with torch.no_grad(): |
||||||
|
text_features = clip_model.encode_text(text_tokens).float() |
||||||
|
text_features /= text_features.norm(dim=-1, keepdim=True) |
||||||
|
text_features = text_features.half().cpu().numpy() |
||||||
|
for i in range(text_features.shape[0]): |
||||||
|
self.embeds.append(text_features[i]) |
||||||
|
|
||||||
|
with open(cache_filepath, "wb") as f: |
||||||
|
pickle.dump( |
||||||
|
{ |
||||||
|
"labels": self.labels, |
||||||
|
"embeds": self.embeds, |
||||||
|
"hash": hash, |
||||||
|
"model": clip_model_name, |
||||||
|
}, |
||||||
|
f, |
||||||
|
) |
||||||
|
|
||||||
|
def _rank(self, image_features, text_embeds, device="cuda", top_count=1): |
||||||
|
top_count = min(top_count, len(text_embeds)) |
||||||
|
similarity = torch.zeros((1, len(text_embeds))).to(device) |
||||||
|
text_embeds = ( |
||||||
|
torch.stack([torch.from_numpy(t) for t in text_embeds]).float().to(device) |
||||||
|
) |
||||||
|
for i in range(image_features.shape[0]): |
||||||
|
similarity += (image_features[i].unsqueeze(0) @ text_embeds.T).softmax( |
||||||
|
dim=-1 |
||||||
|
) |
||||||
|
_, top_labels = similarity.cpu().topk(top_count, dim=-1) |
||||||
|
return [top_labels[0][i].numpy() for i in range(top_count)] |
||||||
|
|
||||||
|
def rank(self, image_features, top_count=1): |
||||||
|
if len(self.labels) <= chunk_size: |
||||||
|
tops = self._rank(image_features, self.embeds, top_count=top_count) |
||||||
|
return [self.labels[i] for i in tops] |
||||||
|
|
||||||
|
num_chunks = int(math.ceil(len(self.labels) / chunk_size)) |
||||||
|
keep_per_chunk = int(chunk_size / num_chunks) |
||||||
|
|
||||||
|
top_labels, top_embeds = [], [] |
||||||
|
for chunk_idx in tqdm(range(num_chunks)): |
||||||
|
start = chunk_idx * chunk_size |
||||||
|
stop = min(start + chunk_size, len(self.embeds)) |
||||||
|
tops = self._rank( |
||||||
|
image_features, self.embeds[start:stop], top_count=keep_per_chunk |
||||||
|
) |
||||||
|
top_labels.extend([self.labels[start + i] for i in tops]) |
||||||
|
top_embeds.extend([self.embeds[start + i] for i in tops]) |
||||||
|
|
||||||
|
tops = self._rank(image_features, top_embeds, top_count=top_count) |
||||||
|
return [top_labels[i] for i in tops] |
||||||
|
|
||||||
|
|
||||||
|
def generate_caption(pil_image, blip_model, device="cuda"): |
||||||
|
gpu_image = ( |
||||||
|
transforms.Compose( |
||||||
|
[ |
||||||
|
transforms.Resize( |
||||||
|
(blip_image_eval_size, blip_image_eval_size), |
||||||
|
interpolation=InterpolationMode.BICUBIC, |
||||||
|
), |
||||||
|
transforms.ToTensor(), |
||||||
|
transforms.Normalize( |
||||||
|
(0.48145466, 0.4578275, 0.40821073), |
||||||
|
(0.26862954, 0.26130258, 0.27577711), |
||||||
|
), |
||||||
|
] |
||||||
|
)(pil_image) |
||||||
|
.unsqueeze(0) |
||||||
|
.to(device) |
||||||
|
) |
||||||
|
|
||||||
|
with torch.no_grad(): |
||||||
|
caption = blip_model.generate( |
||||||
|
gpu_image, sample=False, num_beams=3, max_length=20, min_length=5 |
||||||
|
) |
||||||
|
return caption[0] |
||||||
|
|
||||||
|
|
||||||
|
def rank_top(image_features, text_array, clip_model, device="cuda"): |
||||||
|
text_tokens = clip.tokenize([text for text in text_array]).cuda() |
||||||
|
with torch.no_grad(): |
||||||
|
text_features = clip_model.encode_text(text_tokens).float() |
||||||
|
text_features /= text_features.norm(dim=-1, keepdim=True) |
||||||
|
|
||||||
|
similarity = torch.zeros((1, len(text_array)), device=device) |
||||||
|
for i in range(image_features.shape[0]): |
||||||
|
similarity += (image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1) |
||||||
|
|
||||||
|
_, top_labels = similarity.cpu().topk(1, dim=-1) |
||||||
|
return text_array[top_labels[0][0].numpy()] |
||||||
|
|
||||||
|
|
||||||
|
def similarity(image_features, text, clip_model): |
||||||
|
text_tokens = clip.tokenize([text]).cuda() |
||||||
|
with torch.no_grad(): |
||||||
|
text_features = clip_model.encode_text(text_tokens).float() |
||||||
|
text_features /= text_features.norm(dim=-1, keepdim=True) |
||||||
|
similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T |
||||||
|
return similarity[0][0] |
||||||
|
|
||||||
|
|
||||||
|
def load_list(filename): |
||||||
|
with open(filename, "r", encoding="utf-8", errors="replace") as f: |
||||||
|
items = [line.strip() for line in f.readlines()] |
||||||
|
return items |
||||||
|
|
||||||
|
|
||||||
|
def interrogate(image, clip_model_name, clip_preprocess, clip_model, blip_model, *args): |
||||||
|
flavors, mediums, artists, trendings, movements = args |
||||||
|
caption = generate_caption(image, blip_model) |
||||||
|
|
||||||
|
images = clip_preprocess(image).unsqueeze(0).cuda() |
||||||
|
with torch.no_grad(): |
||||||
|
image_features = clip_model.encode_image(images).float() |
||||||
|
image_features /= image_features.norm(dim=-1, keepdim=True) |
||||||
|
|
||||||
|
flaves = flavors.rank(image_features, flavor_intermediate_count) |
||||||
|
best_medium = mediums.rank(image_features, 1)[0] |
||||||
|
best_artist = artists.rank(image_features, 1)[0] |
||||||
|
best_trending = trendings.rank(image_features, 1)[0] |
||||||
|
best_movement = movements.rank(image_features, 1)[0] |
||||||
|
|
||||||
|
best_prompt = caption |
||||||
|
best_sim = similarity(image_features, best_prompt, clip_model) |
||||||
|
|
||||||
|
def check(addition): |
||||||
|
nonlocal best_prompt, best_sim |
||||||
|
prompt = best_prompt + ", " + addition |
||||||
|
sim = similarity(image_features, prompt, clip_model) |
||||||
|
if sim > best_sim: |
||||||
|
best_sim = sim |
||||||
|
best_prompt = prompt |
||||||
|
return True |
||||||
|
return False |
||||||
|
|
||||||
|
def check_multi_batch(opts): |
||||||
|
nonlocal best_prompt, best_sim |
||||||
|
prompts = [] |
||||||
|
for i in range(2 ** len(opts)): |
||||||
|
prompt = best_prompt |
||||||
|
for bit in range(len(opts)): |
||||||
|
if i & (1 << bit): |
||||||
|
prompt += ", " + opts[bit] |
||||||
|
prompts.append(prompt) |
||||||
|
|
||||||
|
t = LabelTable(prompts, None, clip_model_name, clip_model) |
||||||
|
best_prompt = t.rank(image_features, 1)[0] |
||||||
|
best_sim = similarity(image_features, best_prompt, clip_model) |
||||||
|
|
||||||
|
check_multi_batch([best_medium, best_artist, best_trending, best_movement]) |
||||||
|
|
||||||
|
extended_flavors = set(flaves) |
||||||
|
for _ in tqdm(range(25), desc="Flavor chain"): |
||||||
|
try: |
||||||
|
best = rank_top( |
||||||
|
image_features, |
||||||
|
[f"{best_prompt}, {f}" for f in extended_flavors], |
||||||
|
clip_model, |
||||||
|
) |
||||||
|
flave = best[len(best_prompt) + 2 :] |
||||||
|
if not check(flave): |
||||||
|
break |
||||||
|
extended_flavors.remove(flave) |
||||||
|
except: |
||||||
|
# exceeded max prompt length |
||||||
|
break |
||||||
|
|
||||||
|
return best_prompt |
Loading…
Reference in new issue