@@ -1,5 +1,4 @@
 import hashlib
-import inspect
 import math
 import numpy as np
 import open_clip
@@ -9,18 +8,21 @@ import time
 import torch
 
 from dataclasses import dataclass
-from blip.models.blip import blip_decoder, BLIP_Decoder
 from PIL import Image
 from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode
+from transformers import AutoProcessor, AutoModelForCausalLM, BlipForConditionalGeneration, Blip2ForConditionalGeneration
 from tqdm import tqdm
 from typing import List, Optional
 
 from safetensors.numpy import load_file, save_file
 
-BLIP_MODELS = {
-    'base': 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth',
-    'large': 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
-}
+CAPTION_MODELS = {
+    'blip-base': 'Salesforce/blip-image-captioning-base',   # 990MB
+    'blip-large': 'Salesforce/blip-image-captioning-large', # 1.9GB
+    'blip2-2.7b': 'Salesforce/blip2-opt-2.7b',              # 15.5GB
+    'blip2-flan-t5-xl': 'Salesforce/blip2-flan-t5-xl',      # 15.77GB
+    'git-large-coco': 'microsoft/git-large-coco',           # 1.58GB
+}
 
 CACHE_URL_BASE = 'https://huggingface.co/pharma/ci-preprocess/resolve/main/'
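The new `CAPTION_MODELS` table maps short names to Hugging Face Hub checkpoints. A minimal standalone sketch of pulling one of these checkpoints with transformers (the model id comes from the table above; everything else here is illustrative):

```python
# Sketch: load one CAPTION_MODELS checkpoint directly via transformers.
from transformers import AutoProcessor, BlipForConditionalGeneration

model_id = 'Salesforce/blip-image-captioning-base'   # CAPTION_MODELS['blip-base']
processor = AutoProcessor.from_pretrained(model_id)  # bundles image preprocessing + tokenizer
model = BlipForConditionalGeneration.from_pretrained(model_id)
```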
@@ -29,16 +29,15 @@ CACHE_URL_BASE = 'https://huggingface.co/pharma/ci-preprocess/resolve/main/'
 @dataclass
 class Config:
     # models can optionally be passed in directly
-    blip_model: Optional[BLIP_Decoder] = None
+    caption_model = None
+    caption_processor = None
     clip_model = None
     clip_preprocess = None
 
     # blip settings
-    blip_image_eval_size: int = 384
-    blip_max_length: int = 32
-    blip_model_type: Optional[str] = 'large' # use 'base', 'large' or None
-    blip_num_beams: int = 8
-    blip_offload: bool = False
+    caption_max_length: int = 32
+    caption_model_name: Optional[str] = 'blip-large' # use a key from CAPTION_MODELS or None
+    caption_offload: bool = False
 
     # clip settings
     clip_model_name: str = 'ViT-L-14/openai'
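With the `blip_*` fields replaced, callers now select a captioner by `CAPTION_MODELS` key. A hedged usage sketch (the `clip_interrogator` import path is assumed):

```python
# Sketch: configure the captioner by name instead of passing a BLIP_Decoder.
from clip_interrogator import Config, Interrogator  # assumed import path

config = Config(caption_model_name='blip2-2.7b', clip_model_name='ViT-L-14/openai')
ci = Interrogator(config)
```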
@@ -55,8 +54,8 @@ class Config:
     quiet: bool = False # when quiet progress bars are not shown
 
     def apply_low_vram_defaults(self):
-        self.blip_model_type = 'base'
-        self.blip_offload = True
+        self.caption_model_name = 'blip-base'
+        self.caption_offload = True
         self.clip_offload = True
         self.chunk_size = 1024
         self.flavor_intermediate_count = 1024
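`apply_low_vram_defaults()` now picks the small captioner and enables offload for both models; for example:

```python
# Sketch: opt into the low-VRAM profile before constructing the Interrogator.
from clip_interrogator import Config, Interrogator  # assumed import path

config = Config()
config.apply_low_vram_defaults()  # blip-base captioner, caption + CLIP offload enabled
ci = Interrogator(config)
```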
@@ -65,29 +64,33 @@ class Interrogator():
     def __init__(self, config: Config):
         self.config = config
         self.device = config.device
-        self.blip_offloaded = True
+        self.dtype = torch.float16 if self.device == 'cuda' else torch.float32
+        self.caption_offloaded = True
         self.clip_offloaded = True
+        self.load_caption_model()
+        self.load_clip_model()
 
-        if config.blip_model is None and config.blip_model_type:
-            if not config.quiet:
-                print("Loading BLIP model...")
-            blip_path = os.path.dirname(inspect.getfile(blip_decoder))
-            configs_path = os.path.join(os.path.dirname(blip_path), 'configs')
-            med_config = os.path.join(configs_path, 'med_config.json')
-            blip_model = blip_decoder(
-                pretrained=BLIP_MODELS[config.blip_model_type],
-                image_size=config.blip_image_eval_size,
-                vit=config.blip_model_type,
-                med_config=med_config
-            )
-            blip_model.eval()
-            if not self.config.blip_offload:
-                blip_model = blip_model.to(config.device)
-            self.blip_model = blip_model
+    def load_caption_model(self):
+        if self.config.caption_model is None and self.config.caption_model_name:
+            if not self.config.quiet:
+                print(f"Loading caption model {self.config.caption_model_name}...")
+
+            model_path = CAPTION_MODELS[self.config.caption_model_name]
+            if self.config.caption_model_name.startswith('git-'):
+                caption_model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float32)
+            elif self.config.caption_model_name.startswith('blip2-'):
+                caption_model = Blip2ForConditionalGeneration.from_pretrained(model_path, torch_dtype=self.dtype)
+            else:
+                caption_model = BlipForConditionalGeneration.from_pretrained(model_path, torch_dtype=self.dtype)
+            self.caption_processor = AutoProcessor.from_pretrained(model_path)
+
+            caption_model.eval()
+            if not self.config.caption_offload:
+                caption_model = caption_model.to(self.config.device)
+            self.caption_model = caption_model
         else:
-            self.blip_model = config.blip_model
-
-        self.load_clip_model()
+            self.caption_model = self.config.caption_model
+            self.caption_processor = self.config.caption_processor
 
     def load_clip_model(self):
         start_time = time.time()
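`load_caption_model()` picks the model class and dtype from the name prefix. The same rule in isolation, as a sketch mirroring the branches above:

```python
# Sketch of the class/dtype dispatch used in load_caption_model().
import torch
from transformers import (AutoModelForCausalLM, Blip2ForConditionalGeneration,
                          BlipForConditionalGeneration)

def caption_class_and_dtype(name: str, runtime_dtype: torch.dtype):
    if name.startswith('git-'):
        # GIT checkpoints load via AutoModelForCausalLM and stay in float32
        return AutoModelForCausalLM, torch.float32
    if name.startswith('blip2-'):
        return Blip2ForConditionalGeneration, runtime_dtype
    return BlipForConditionalGeneration, runtime_dtype
```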
@@ -97,7 +100,7 @@ class Interrogator():
 
         if config.clip_model is None:
             if not config.quiet:
-                print("Loading CLIP model...")
+                print(f"Loading CLIP model {config.clip_model_name}...")
 
             self.clip_model, _, self.clip_preprocess = open_clip.create_model_and_transforms(
                 clip_model_name,
@@ -183,26 +186,13 @@ class Interrogator():
         return best_prompt
 
     def generate_caption(self, pil_image: Image) -> str:
-        assert self.blip_model is not None, "No BLIP model loaded."
-        self._prepare_blip()
-
-        size = self.config.blip_image_eval_size
-        gpu_image = transforms.Compose([
-            transforms.Resize((size, size), interpolation=InterpolationMode.BICUBIC),
-            transforms.ToTensor(),
-            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
-        ])(pil_image).unsqueeze(0).to(self.device)
-
-        with torch.no_grad():
-            caption = self.blip_model.generate(
-                gpu_image,
-                sample=False,
-                num_beams=self.config.blip_num_beams,
-                max_length=self.config.blip_max_length,
-                min_length=5
-            )
-
-        return caption[0]
+        assert self.caption_model is not None, "No caption model loaded."
+        self._prepare_caption()
+        inputs = self.caption_processor(images=pil_image, return_tensors="pt").to(self.device)
+        if not self.config.caption_model_name.startswith('git-'):
+            inputs = inputs.to(self.dtype)
+        tokens = self.caption_model.generate(**inputs, max_new_tokens=self.config.caption_max_length)
+        return self.caption_processor.batch_decode(tokens, skip_special_tokens=True)[0].strip()
 
     def image_to_features(self, image: Image) -> torch.Tensor:
         self._prepare_clip()
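`generate_caption()` now delegates preprocessing and decoding to the processor. The same round trip outside the class (`example.jpg` is a placeholder path):

```python
# Sketch: the processor -> generate -> batch_decode round trip used above.
from PIL import Image
from transformers import AutoProcessor, BlipForConditionalGeneration

model_id = 'Salesforce/blip-image-captioning-base'
processor = AutoProcessor.from_pretrained(model_id)
model = BlipForConditionalGeneration.from_pretrained(model_id)

image = Image.open('example.jpg').convert('RGB')  # placeholder input
inputs = processor(images=image, return_tensors='pt')
tokens = model.generate(**inputs, max_new_tokens=32)
print(processor.batch_decode(tokens, skip_special_tokens=True)[0].strip())
```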
@@ -237,7 +227,7 @@ class Interrogator():
         are less readable."""
         caption = caption or self.generate_caption(image)
         image_features = self.image_to_features(image)
-        merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config)
+        merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self)
         tops = merged.rank(image_features, max_flavors)
         return _truncate_to_fit(caption + ", " + ", ".join(tops), self.tokenize)
@@ -254,7 +244,7 @@ class Interrogator():
         caption = caption or self.generate_caption(image)
         image_features = self.image_to_features(image)
 
-        merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config)
+        merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self)
         flaves = merged.rank(image_features, self.config.flavor_intermediate_count)
         best_prompt, best_sim = caption, self.similarity(image_features, caption)
         best_prompt = self.chain(image_features, flaves, best_prompt, best_sim, min_count=min_flavors, max_count=max_flavors, desc="Flavor chain")
@@ -293,18 +283,18 @@ class Interrogator():
         similarity = text_features @ image_features.T
         return similarity.T[0].tolist()
 
-    def _prepare_blip(self):
+    def _prepare_caption(self):
         if self.config.clip_offload and not self.clip_offloaded:
             self.clip_model = self.clip_model.to('cpu')
             self.clip_offloaded = True
-        if self.blip_offloaded:
-            self.blip_model = self.blip_model.to(self.device)
-            self.blip_offloaded = False
+        if self.caption_offloaded:
+            self.caption_model = self.caption_model.to(self.device)
+            self.caption_offloaded = False
 
     def _prepare_clip(self):
-        if self.config.blip_offload and not self.blip_offloaded:
-            self.blip_model = self.blip_model.to('cpu')
-            self.blip_offloaded = True
+        if self.config.caption_offload and not self.caption_offloaded:
+            self.caption_model = self.caption_model.to('cpu')
+            self.caption_offloaded = True
         if self.clip_offloaded:
             self.clip_model = self.clip_model.to(self.device)
             self.clip_offloaded = False
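`_prepare_caption()` and `_prepare_clip()` are mirror images: each evicts the other model to CPU (when its offload flag is set) before moving its own model onto the device. The pattern in miniature:

```python
# Sketch: keep at most one large model resident on the accelerator.
import torch

def swap_models(idle: torch.nn.Module, active: torch.nn.Module, device: str) -> None:
    idle.to('cpu')     # evict the model that is not needed next
    active.to(device)  # then bring the needed model onto the device
```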
@@ -425,8 +415,8 @@ def _download_file(url: str, filepath: str, chunk_size: int = 4*1024*1024, quiet
                 progress.update(len(chunk))
     progress.close()
 
-def _merge_tables(tables: List[LabelTable], config: Config) -> LabelTable:
-    m = LabelTable([], None, None, None, config)
+def _merge_tables(tables: List[LabelTable], ci: Interrogator) -> LabelTable:
+    m = LabelTable([], None, ci)
     for table in tables:
         m.labels.extend(table.labels)
         m.embeds.extend(table.embeds)
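`_merge_tables` now receives the `Interrogator` itself, matching the slimmed `LabelTable` constructor. A hypothetical caller:

```python
# Hypothetical caller; `ci` is an Interrogator and `image` a PIL.Image (both assumed).
image_features = ci.image_to_features(image)
merged = _merge_tables([ci.artists, ci.flavors, ci.mediums, ci.movements, ci.trendings], ci)
print(merged.rank(image_features, 32))  # top 32 labels by CLIP similarity
```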
@@ -445,6 +435,12 @@ def _truncate_to_fit(text: str, tokenize) -> str:
         new_text += ', ' + part
     return new_text
 
+def list_caption_models() -> List[str]:
+    return list(CAPTION_MODELS.keys())
+
+def list_clip_models() -> List[str]:
+    return ['/'.join(x) for x in open_clip.list_pretrained()]
+
 def load_list(data_path: str, filename: Optional[str] = None) -> List[str]:
     """Load a list of strings from a file."""
     if filename is not None:
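The two new helpers enumerate the selectable caption and CLIP model names; for example:

```python
# Sketch: list the selectable model names (assumed package import path).
from clip_interrogator import list_caption_models, list_clip_models

print(list_caption_models())   # ['blip-base', 'blip-large', 'blip2-2.7b', ...]
print(list_clip_models()[:5])  # first few 'architecture/pretrained_tag' pairs
```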