import hashlib
import inspect
import math
import numpy as np
import open_clip
import os
import pickle
import requests
import time
import torch
from dataclasses import dataclass
from blip.models.blip import blip_decoder, BLIP_Decoder
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from tqdm import tqdm
from typing import Callable, List, Optional, Tuple, Union
from safetensors.numpy import load_file, save_file

# Pretrained BLIP caption checkpoints, keyed by Config.blip_model_type.
# NOTE(review): the URLs are empty in this copy of the file. With an empty
# `pretrained` value blip_decoder presumably builds BLIP without trained
# weights -- restore the real checkpoint URLs before relying on captions.
BLIP_MODELS = {
    "base": "",
    "large": "",
}

# Precomputed label-table embeddings fetched by Interrogator.download_cache.
# Blank entries are skipped. NOTE(review): these URLs are empty in this copy;
# fill them in to re-enable downloading (otherwise embeddings are computed
# locally on first use).
CACHE_URLS_VITL = [
    '',
    '',
    '',
    '',
    '',
    '',
]

CACHE_URLS_VITH = [
    '',
    '',
    '',
    '',
    '',
    '',
]


@dataclass
class Config:
    """Settings for :class:`Interrogator` (BLIP captioning + CLIP ranking)."""

    # models can optionally be passed in directly
    blip_model: Optional[BLIP_Decoder] = None

    # blip settings
    blip_image_eval_size: int = 384
    blip_max_length: int = 32
    blip_model_type: str = "large"  # choose between 'base' or 'large'
    blip_num_beams: int = 8
    blip_offload: bool = False  # keep BLIP on CPU except while captioning

    # clip settings
    clip_model_name: str = "ViT-L-14/openai"  # "<open_clip model>/<pretrained tag>"
    clip_model_path: Optional[str] = None  # passed to open_clip as cache_dir

    # interrogator settings
    cache_path: str = "cache"  # path to store cached text embeddings
    download_cache: bool = True  # when true, cached embeds are downloaded from huggingface
    chunk_size: int = 2048  # batch size for CLIP, use smaller for lower VRAM
    data_path: str = os.path.join(os.path.dirname(__file__), "data")
    device: Union[str, torch.device] = (
        "mps"
        if torch.backends.mps.is_available()
        else "cuda"
        if torch.cuda.is_available()
        else "cpu"
    )
    flavor_intermediate_count: int = 2048
    quiet: bool = False  # when quiet progress bars are not shown

    # A CLIP model / preprocess transform can also be passed in directly.
    # These used to be unannotated class attributes, which made
    # Config(clip_model=...) raise TypeError. They are declared last so the
    # positional order of the fields above stays unchanged.
    clip_model: Optional[torch.nn.Module] = None
    clip_preprocess: Optional[Callable] = None


class Interrogator:
    """Build text prompts for an image from a BLIP caption plus CLIP-ranked terms.

    Only one of the two models is kept on ``config.device`` at a time; the
    other is parked on CPU. ``blip_loaded`` records which one: True while CLIP
    is parked on CPU, False while CLIP is on the target device. The helpers
    ``_ensure_blip`` / ``_ensure_clip`` perform the swaps. Every method that
    runs a model calls the matching helper first, so the models can be used in
    any order.
    """

    def __init__(self, config: Config):
        self.config = config
        self.device = config.device
        # CLIP is not on the target device yet; load_clip_model() below moves
        # it there (parking BLIP on CPU) and clears this flag.
        self.blip_loaded = True

        if config.blip_model is None:
            if not config.quiet:
                print("Loading BLIP model...")
            blip_path = os.path.dirname(inspect.getfile(blip_decoder))
            configs_path = os.path.join(os.path.dirname(blip_path), "configs")
            med_config = os.path.join(configs_path, "med_config.json")
            blip_model = blip_decoder(
                pretrained=BLIP_MODELS[config.blip_model_type],
                image_size=config.blip_image_eval_size,
                vit=config.blip_model_type,
                med_config=med_config,
            )
            blip_model.eval()
            # No device move here: load_clip_model() parks BLIP on CPU and
            # _ensure_blip() moves it onto the device when a caption is needed.
            self.blip_model = blip_model
        else:
            self.blip_model = config.blip_model

        self.load_clip_model()

    def download_cache(self, clip_model_name: str):
        """Download precomputed label embeddings for the known CLIP models.

        Best effort: for other models, for blank URLs, or when a download
        fails, LabelTable computes the embeddings locally and caches them.
        """
        if clip_model_name == "ViT-L-14/openai":
            cache_urls = CACHE_URLS_VITL
        elif clip_model_name == "ViT-H-14/laion2b_s32b_b79k":
            cache_urls = CACHE_URLS_VITH
        else:
            # text embeddings will be precomputed and cached locally
            return
        if not self.config.cache_path:
            return  # nowhere to store the files

        os.makedirs(self.config.cache_path, exist_ok=True)
        for url in cache_urls:
            if not url:
                continue
            filepath = os.path.join(self.config.cache_path, url.split("/")[-1])
            if os.path.exists(filepath):
                continue
            try:
                _download_file(url, filepath, quiet=self.config.quiet)
            except requests.RequestException as e:
                # Falling back to local computation is slower but still correct.
                print(f"Failed to download {url}: {e}")

    def load_clip_model(self):
        """Load CLIP (unless supplied in config), its tokenizer and all label tables.

        On return CLIP is on the target device and BLIP is parked on CPU.
        """
        start_time = time.time()
        config = self.config

        # The architecture name is needed for the tokenizer even when a CLIP
        # model is passed in directly, so split it outside the branch below.
        clip_model_name, clip_model_pretrained_name = config.clip_model_name.split(
            "/", 1
        )

        if config.clip_model is None:
            if not config.quiet:
                print("Loading CLIP model...")
            use_fp16 = torch.device(config.device).type == "cuda"
            (
                self.clip_model,
                _,
                self.clip_preprocess,
            ) = open_clip.create_model_and_transforms(
                clip_model_name,
                pretrained=clip_model_pretrained_name,
                precision="fp16" if use_fp16 else "fp32",
                device="cpu",  # moved onto the target device below
                jit=False,
                cache_dir=config.clip_model_path,
            )
            self.clip_model.eval()
        else:
            self.clip_model = config.clip_model
            self.clip_preprocess = config.clip_preprocess

        self.tokenize = open_clip.get_tokenizer(clip_model_name)

        # On a cache miss LabelTable embeds its labels with CLIP after sending
        # the tokens to config.device, so CLIP must be there already. `force`
        # because a freshly created model is on CPU whatever the flag says.
        self._ensure_clip(force=True)

        sites = [
            "Artstation",
            "behance",
            "cg society",
            "cgsociety",
            "deviantart",
            "dribble",
            "flickr",
            "instagram",
            "pexels",
            "pinterest",
            "pixabay",
            "pixiv",
            "polycount",
            "reddit",
            "shutterstock",
            "tumblr",
            "unsplash",
            "zbrush central",
        ]
        trending_list = list(sites)
        trending_list.extend(["trending on " + site for site in sites])
        trending_list.extend(["featured on " + site for site in sites])
        trending_list.extend([site + " contest winner" for site in sites])

        raw_artists = _load_list(config.data_path, "artists.txt")
        artists = [f"by {a}" for a in raw_artists]
        artists.extend([f"inspired by {a}" for a in raw_artists])

        if config.download_cache:
            self.download_cache(config.clip_model_name)

        self.artists = LabelTable(
            artists, "artists", self.clip_model, self.tokenize, config
        )
        self.flavors = LabelTable(
            _load_list(config.data_path, "flavors.txt"),
            "flavors",
            self.clip_model,
            self.tokenize,
            config,
        )
        self.mediums = LabelTable(
            _load_list(config.data_path, "mediums.txt"),
            "mediums",
            self.clip_model,
            self.tokenize,
            config,
        )
        self.movements = LabelTable(
            _load_list(config.data_path, "movements.txt"),
            "movements",
            self.clip_model,
            self.tokenize,
            config,
        )
        self.trendings = LabelTable(
            trending_list, "trendings", self.clip_model, self.tokenize, config
        )
        self.negative = LabelTable(
            _load_list(config.data_path, "negative.txt"),
            "negative",
            self.clip_model,
            self.tokenize,
            config,
        )

        end_time = time.time()
        if not config.quiet:
            print(f"Loaded CLIP model and data in {end_time-start_time:.2f} seconds.")

    def _ensure_blip(self) -> None:
        """Park CLIP on CPU and, unless blip_offload is set, put BLIP on the device."""
        if self.blip_loaded:
            return
        self.clip_model = self.clip_model.to("cpu")
        if not self.config.blip_offload:
            # With blip_offload, generate_caption moves BLIP on/off by itself.
            self.blip_model = self.blip_model.to(self.device)
        self.blip_loaded = True

    def _ensure_clip(self, force: bool = False) -> None:
        """Park BLIP on CPU and put CLIP on the target device.

        Args:
            force: swap even when ``blip_loaded`` says CLIP is already on the
                device (used right after a CLIP model was (re)loaded on CPU).
        """
        if not (self.blip_loaded or force):
            return
        self.blip_model = self.blip_model.to("cpu")
        self.clip_model = self.clip_model.to(self.device)
        self.blip_loaded = False

    def chain(
        self,
        image_features: torch.Tensor,
        phrases: List[str],
        best_prompt: str = "",
        best_sim: float = 0,
        min_count: int = 8,
        max_count: int = 32,
        desc="Chaining",
        reverse: bool = False,
    ) -> str:
        """Greedily append phrases to a prompt while CLIP similarity improves.

        Args:
            image_features: normalized CLIP image features.
            phrases: candidate phrases (deduplicated internally).
            best_prompt: starting prompt; when empty, the single best phrase
                is used as the start.
            best_sim: score of ``best_prompt`` when one is given. In reverse
                mode it must already be negated, like the scores compared here.
            min_count: phrases are appended unconditionally for this many steps.
            max_count: maximum number of appended phrases.
            desc: progress-bar label.
            reverse: chain the *least* similar phrases instead (negative mode).

        Returns:
            The best-scoring prompt seen during the chain.
        """
        phrases = set(phrases)
        if not best_prompt:
            if not phrases:
                return best_prompt
            best_prompt = self.rank_top(image_features, list(phrases), reverse=reverse)
            best_sim = self.similarity(image_features, best_prompt)
            # check() compares negated similarities in reverse mode; the start
            # score must be negated too, or no extended prompt could beat it.
            if reverse:
                best_sim = -best_sim
            phrases.remove(best_prompt)
        curr_prompt, curr_sim = best_prompt, best_sim

        def check(addition: str, idx: int) -> bool:
            nonlocal best_prompt, best_sim, curr_prompt, curr_sim
            prompt = curr_prompt + ", " + addition
            sim = self.similarity(image_features, prompt)
            if reverse:
                sim = -sim

            if sim > best_sim:
                best_prompt, best_sim = prompt, sim
            # Below min_count the phrase is kept even without an improvement.
            if sim > curr_sim or idx < min_count:
                curr_prompt, curr_sim = prompt, sim
                return True
            return False

        for idx in tqdm(range(max_count), desc=desc, disable=self.config.quiet):
            if not phrases:
                break
            best = self.rank_top(
                image_features,
                [f"{curr_prompt}, {f}" for f in phrases],
                reverse=reverse,
            )
            flave = best[len(curr_prompt) + 2 :]  # strip "<curr_prompt>, "
            if not check(flave, idx):
                break
            if _prompt_at_max_len(curr_prompt, self.tokenize):
                break
            phrases.remove(flave)

        return best_prompt

    def generate_caption(self, pil_image: Image.Image) -> str:
        """Caption *pil_image* with BLIP using beam search (no sampling)."""
        self._ensure_blip()
        if self.config.blip_offload:
            self.blip_model = self.blip_model.to(self.device)
        try:
            size = self.config.blip_image_eval_size
            preprocess = transforms.Compose(
                [
                    transforms.Resize(
                        (size, size), interpolation=InterpolationMode.BICUBIC
                    ),
                    transforms.ToTensor(),
                    transforms.Normalize(
                        (0.48145466, 0.4578275, 0.40821073),
                        (0.26862954, 0.26130258, 0.27577711),
                    ),
                ]
            )
            gpu_image = preprocess(pil_image).unsqueeze(0).to(self.device)

            with torch.no_grad():
                caption = self.blip_model.generate(
                    gpu_image,
                    sample=False,
                    num_beams=self.config.blip_num_beams,
                    max_length=self.config.blip_max_length,
                    min_length=5,
                )
        finally:
            # Offload BLIP again even if generation raised.
            if self.config.blip_offload:
                self.blip_model = self.blip_model.to("cpu")
        return caption[0]

    def image_to_features(self, image: Image.Image) -> torch.Tensor:
        """Return L2-normalized CLIP image features for *image* (batch of one)."""
        self._ensure_clip()
        images = self.clip_preprocess(image).unsqueeze(0).to(self.device)
        with torch.no_grad(), torch.cuda.amp.autocast():
            image_features = self.clip_model.encode_image(images)
            image_features /= image_features.norm(dim=-1, keepdim=True)
        return image_features

    def _first_bit(self, image: Image.Image) -> Tuple[str, torch.Tensor]:
        """Return (BLIP caption, CLIP features) for *image*.

        Whichever model is already on the device runs first, so at most one
        model swap happens here.
        """
        if self.blip_loaded:
            caption = self.generate_caption(image)
            image_features = self.image_to_features(image)
        else:
            image_features = self.image_to_features(image)
            caption = self.generate_caption(image)
        return caption, image_features

    def _first_bit_batch(
        self, images: list[Image.Image]
    ) -> Tuple[list[str], list[torch.Tensor]]:
        """Batch version of _first_bit: one model swap for the whole batch."""
        if self.blip_loaded:
            captions = [self.generate_caption(img) for img in images]
            image_features = [self.image_to_features(img) for img in images]
        else:
            image_features = [self.image_to_features(img) for img in images]
            captions = [self.generate_caption(img) for img in images]
        return captions, image_features

    def _interrogate_classic(
        self, caption: str, image_features: torch.Tensor, max_flavours: int = 3
    ) -> str:
        """Build a classic-format prompt from a caption and image features."""
        medium = self.mediums.rank(image_features, 1)[0]
        artist = self.artists.rank(image_features, 1)[0]
        trending = self.trendings.rank(image_features, 1)[0]
        movement = self.movements.rank(image_features, 1)[0]
        flaves = ", ".join(self.flavors.rank(image_features, max_flavours))

        if caption.startswith(medium):
            prompt = f"{caption} {artist}, {trending}, {movement}, {flaves}"
        else:
            prompt = f"{caption}, {medium} {artist}, {trending}, {movement}, {flaves}"

        return _truncate_to_fit(prompt, self.tokenize)

    def interrogate_classic(self, image: Image.Image, max_flavors: int = 3) -> str:
        """Classic mode creates a prompt in a standard format first describing the image,
        then listing the artist, trending, movement, and flavor text modifiers."""
        caption, image_features = self._first_bit(image)
        return self._interrogate_classic(caption, image_features, max_flavors)

    def interrogate_classic_batch(
        self, images: list[Image.Image], max_flavors: int = 3
    ) -> list[str]:
        """Classic mode creates a prompt in a standard format first describing the image,
        then listing the artist, trending, movement, and flavor text modifiers.

        This function interrogates a batch of images (more efficient than doing it individually)."""
        captions, image_features = self._first_bit_batch(images)
        return [
            self._interrogate_classic(caption, feature, max_flavors)
            for caption, feature in zip(captions, image_features)
        ]

    def _interrogate_fast(
        self, caption: str, image_features: torch.Tensor, max_flavours: int = 32
    ) -> str:
        """Append the top-ranked terms from all label tables to *caption*."""
        merged = _merge_tables(
            [self.artists, self.flavors, self.mediums, self.movements, self.trendings],
            self.config,
        )
        tops = merged.rank(image_features, max_flavours)
        return _truncate_to_fit(caption + ", " + ", ".join(tops), self.tokenize)

    def interrogate_fast(self, image: Image.Image, max_flavors: int = 32) -> str:
        """Fast mode simply adds the top ranked terms after a caption. It generally results in
        better similarity between generated prompt and image than classic mode, but the prompts
        are less readable."""
        caption, image_features = self._first_bit(image)
        return self._interrogate_fast(caption, image_features, max_flavors)

    def interrogate_fast_batch(
        self, images: list[Image.Image], max_flavors: int = 32
    ) -> list[str]:
        """Fast mode simply adds the top ranked terms after a caption. It generally results in
        better similarity between generated prompt and image than classic mode, but the prompts
        are less readable.

        This function interrogates a batch of images (more efficient than doing it individually)."""
        captions, image_features = self._first_bit_batch(images)
        return [
            self._interrogate_fast(caption, feature, max_flavors)
            for caption, feature in zip(captions, image_features)
        ]

    def interrogate_negative(self, image: Image.Image, max_flavors: int = 32) -> str:
        """Negative mode chains together the most dissimilar terms to the image. It can be used
        to help build a negative prompt to pair with the regular positive prompt and often
        improve the results of generated images particularly with Stable Diffusion 2."""
        # No caption is needed; image_to_features swaps CLIP onto the device.
        image_features = self.image_to_features(image)
        flaves = self.flavors.rank(
            image_features, self.config.flavor_intermediate_count, reverse=True
        )
        flaves = flaves + self.negative.labels
        return self.chain(
            image_features,
            flaves,
            max_count=max_flavors,
            reverse=True,
            desc="Negative chain",
        )

    def _interrogate(
        self,
        caption: str,
        image_features: torch.Tensor,
        min_flavours: int = 8,
        max_flavours: int = 32,
    ) -> str:
        """Return whichever of caption / classic / fast / chained prompt CLIP scores highest."""
        merged = _merge_tables(
            [self.artists, self.flavors, self.mediums, self.movements, self.trendings],
            self.config,
        )
        flaves = merged.rank(image_features, self.config.flavor_intermediate_count)
        best_prompt, best_sim = caption, self.similarity(image_features, caption)
        best_prompt = self.chain(
            image_features,
            flaves,
            best_prompt,
            best_sim,
            min_count=min_flavours,
            max_count=max_flavours,
            desc="Flavor chain",
        )

        fast_prompt = self._interrogate_fast(caption, image_features, max_flavours)
        classic_prompt = self._interrogate_classic(
            caption, image_features, max_flavours
        )
        candidates = [caption, classic_prompt, fast_prompt, best_prompt]
        return candidates[np.argmax(self.similarities(image_features, candidates))]

    def interrogate(
        self, image: Image.Image, min_flavors: int = 8, max_flavors: int = 32
    ) -> str:
        """Best mode: chain flavors onto the caption and return the best-scoring prompt."""
        caption, image_features = self._first_bit(image)
        return self._interrogate(caption, image_features, min_flavors, max_flavors)

    def interrogate_batch(
        self, images: list[Image.Image], min_flavors: int = 8, max_flavors: int = 32
    ) -> list[str]:
        """This function interrogates a batch of images (more efficient than doing it individually)."""
        captions, image_features = self._first_bit_batch(images)
        return [
            self._interrogate(caption, features, min_flavors, max_flavors)
            for caption, features in zip(captions, image_features)
        ]

    def rank_top(
        self, image_features: torch.Tensor, text_array: List[str], reverse: bool = False
    ) -> str:
        """Return the text in *text_array* most (or, if reverse, least) similar to the image."""
        self._ensure_clip()
        text_tokens = self.tokenize(list(text_array)).to(self.device)
        with torch.no_grad(), torch.cuda.amp.autocast():
            text_features = self.clip_model.encode_text(text_tokens)
            text_features /= text_features.norm(dim=-1, keepdim=True)
            similarity = text_features @ image_features.T
            if reverse:
                similarity = -similarity
        return text_array[similarity.argmax().item()]

    def similarity(self, image_features: torch.Tensor, text: str) -> float:
        """Return the CLIP cosine similarity between *text* and the image."""
        self._ensure_clip()
        text_tokens = self.tokenize([text]).to(self.device)
        with torch.no_grad(), torch.cuda.amp.autocast():
            text_features = self.clip_model.encode_text(text_tokens)
            text_features /= text_features.norm(dim=-1, keepdim=True)
            similarity = text_features @ image_features.T
        return similarity[0][0].item()

    def similarities(
        self, image_features: torch.Tensor, text_array: List[str]
    ) -> List[float]:
        """Return the CLIP cosine similarity of each text in *text_array* to the image."""
        self._ensure_clip()
        text_tokens = self.tokenize(list(text_array)).to(self.device)
        with torch.no_grad(), torch.cuda.amp.autocast():
            text_features = self.clip_model.encode_text(text_tokens)
            text_features /= text_features.norm(dim=-1, keepdim=True)
            similarity = text_features @ image_features.T
        return similarity.T[0].tolist()


class LabelTable:
    """Text labels with their L2-normalized CLIP text embeddings, cached on disk.

    Embeddings are kept as a list of 1-D numpy rows: float16 as computed here,
    converted to float32 when the device is CPU.
    """

    def __init__(
        self, labels: List[str], desc: str, clip_model, tokenize, config: Config
    ):
        self.chunk_size = config.chunk_size
        self.config = config
        self.device = config.device
        self.embeds = []
        self.labels = labels
        self.tokenize = tokenize

        labels_hash = hashlib.sha256(",".join(labels).encode()).hexdigest()
        sanitized_name = self.config.clip_model_name.replace("/", "_").replace("@", "_")
        self._load_cached(desc, labels_hash, sanitized_name)

        if len(self.labels) != len(self.embeds):
            self.embeds = []
            # ceil keeps every batch at most chunk_size labels (truncating
            # division allowed batches of almost twice that).
            num_chunks = max(1, math.ceil(len(self.labels) / config.chunk_size))
            chunks = np.array_split(self.labels, num_chunks)
            for chunk in tqdm(
                chunks,
                desc=f"Preprocessing {desc}" if desc else None,
                disable=self.config.quiet,
            ):
                text_tokens = self.tokenize(chunk).to(self.device)
                with torch.no_grad(), torch.cuda.amp.autocast():
                    text_features = clip_model.encode_text(text_tokens)
                    text_features /= text_features.norm(dim=-1, keepdim=True)
                    text_features = text_features.half().cpu().numpy()
                # Iterating a 2-D array yields its rows.
                self.embeds.extend(text_features)

            if desc and self.config.cache_path and self.embeds:
                os.makedirs(self.config.cache_path, exist_ok=True)
                cache_filepath = os.path.join(
                    self.config.cache_path, f"{sanitized_name}_{desc}.safetensors"
                )
                tensors = {
                    "embeds": np.stack(self.embeds),
                    "hash": np.array([ord(c) for c in labels_hash], dtype=np.int8),
                }
                save_file(tensors, cache_filepath)

        if torch.device(self.device).type == "cpu":
            self.embeds = [e.astype(np.float32) for e in self.embeds]

    def _load_cached(self, desc: str, hash: str, sanitized_name: str) -> bool:
        """Try to fill labels/embeds from a cache file whose label hash matches.

        Returns True when a matching cache was loaded.
        """
        if self.config.cache_path is None or desc is None:
            return False

        # load from old pkl format if it exists
        # NOTE(review): pickle.load executes arbitrary code from the file; the
        # cache directory must only ever contain trusted files.
        cached_pkl = os.path.join(self.config.cache_path, f"{sanitized_name}_{desc}.pkl")
        if os.path.exists(cached_pkl):
            with open(cached_pkl, "rb") as f:
                try:
                    data = pickle.load(f)
                    if data.get("hash") == hash:
                        self.labels = data["labels"]
                        self.embeds = data["embeds"]
                        return True
                except Exception as e:
                    print(f"Error loading cached table {desc}: {e}")

        # load from new safetensors format if it exists
        cached_safetensors = os.path.join(
            self.config.cache_path, f"{sanitized_name}_{desc}.safetensors"
        )
        if os.path.exists(cached_safetensors):
            try:
                tensors = load_file(cached_safetensors)
            except Exception as e:
                # A corrupt file is recomputed and overwritten by the caller.
                print(f"Error loading cached table {desc}: {e}")
                return False
            if "hash" in tensors and "embeds" in tensors:
                expected = np.array([ord(c) for c in hash], dtype=np.int8)
                if np.array_equal(tensors["hash"], expected):
                    self.embeds = tensors["embeds"]
                    if len(self.embeds.shape) == 2:
                        self.embeds = [self.embeds[i] for i in range(self.embeds.shape[0])]
                    return True

        return False

    def _rank(
        self,
        image_features: torch.Tensor,
        text_embeds: List[np.ndarray],
        top_count: int = 1,
        reverse: bool = False,
    ) -> List[int]:
        """Return indices into *text_embeds* of the top matches, best first."""
        top_count = min(top_count, len(text_embeds))
        # Match the image features' dtype: cached embeds are float16 while the
        # features may be float32 (e.g. fp32 CLIP on MPS), and matmul rejects
        # mixed dtypes outside CUDA autocast.
        text_embeds = torch.stack([torch.from_numpy(t) for t in text_embeds]).to(
            device=self.device, dtype=image_features.dtype
        )
        with torch.cuda.amp.autocast():
            similarity = image_features @ text_embeds.T
            if reverse:
                similarity = -similarity
        _, top_labels = similarity.float().cpu().topk(top_count, dim=-1)
        return top_labels[0].tolist()

    def rank(
        self, image_features: torch.Tensor, top_count: int = 1, reverse: bool = False
    ) -> List[str]:
        """Return the *top_count* labels most (or, if reverse, least) similar to the image.

        Tables larger than chunk_size are ranked per chunk, keeping
        ``chunk_size // num_chunks`` candidates from each chunk, so the result
        is approximate when top_count exceeds that.
        """
        if not self.labels:
            return []
        if len(self.labels) <= self.chunk_size:
            tops = self._rank(
                image_features, self.embeds, top_count=top_count, reverse=reverse
            )
            return [self.labels[i] for i in tops]

        num_chunks = math.ceil(len(self.labels) / self.chunk_size)
        keep_per_chunk = max(1, self.chunk_size // num_chunks)

        top_labels, top_embeds = [], []
        for chunk_idx in tqdm(range(num_chunks), disable=self.config.quiet):
            start = chunk_idx * self.chunk_size
            stop = min(start + self.chunk_size, len(self.embeds))
            tops = self._rank(
                image_features,
                self.embeds[start:stop],
                top_count=keep_per_chunk,
                reverse=reverse,
            )
            top_labels.extend([self.labels[start + i] for i in tops])
            top_embeds.extend([self.embeds[start + i] for i in tops])

        # The final pass must use the same direction as the per-chunk passes.
        tops = self._rank(image_features, top_embeds, top_count=top_count, reverse=reverse)
        return [top_labels[i] for i in tops]


def _download_file(
    url: str,
    filepath: str,
    chunk_size: int = 64 * 1024,
    quiet: bool = False,
    timeout: float = 60.0,
):
    """Stream *url* to *filepath* with a tqdm progress bar.

    Data is written to ``filepath + ".part"`` and renamed only after the whole
    body has arrived. An interrupted or failed transfer therefore never leaves
    a truncated file that download_cache would treat as complete (it only
    checks that the file exists).

    Raises:
        requests.RequestException: on connection errors, timeouts, or a
            non-2xx HTTP status.
    """
    filename = url.split("/")[-1]
    tmp_path = filepath + ".part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as r:
            r.raise_for_status()
            file_size = int(r.headers.get("Content-Length", 0))
            with open(tmp_path, "wb") as f, tqdm(
                total=file_size, unit="B", unit_scale=True, desc=filename, disable=quiet
            ) as progress:
                for chunk in r.iter_content(chunk_size=chunk_size):
                    if chunk:
                        f.write(chunk)
                        progress.update(len(chunk))
        os.replace(tmp_path, filepath)
    except BaseException:
        try:
            os.remove(tmp_path)
        except OSError:
            pass  # nothing was written, or it is already gone
        raise


def _load_list(data_path: str, filename: str) -> List[str]:
    """Read *filename* under *data_path* as UTF-8 and return its stripped lines."""
    with open(
        os.path.join(data_path, filename), "r", encoding="utf-8", errors="replace"
    ) as f:
        items = [line.strip() for line in f.readlines()]
    return items


def _merge_tables(tables: List[LabelTable], config: Config) -> LabelTable:
    """Concatenate the labels and embeddings of *tables* into a new LabelTable.

    The new table has no description, so nothing is embedded or cached.
    """
    m = LabelTable([], None, None, None, config)
    for table in tables:
        m.labels.extend(table.labels)
        m.embeds.extend(table.embeds)
    return m


def _prompt_at_max_len(text: str, tokenize) -> bool:
    """True when *text* fills the tokenizer's whole context (last slot non-padding)."""
    tokens = tokenize([text])
    return bool(tokens[0][-1] != 0)


def _truncate_to_fit(text: str, tokenize) -> str:
    """Drop trailing ", "-separated parts of *text* until it fits the context."""
    parts = text.split(", ")
    new_text = parts[0]
    for part in parts[1:]:
        candidate = new_text + ", " + part
        # Check exactly the string that would be kept, separator included.
        if _prompt_at_max_len(candidate, tokenize):
            break
        new_text = candidate
    return new_text