diff --git a/README.md b/README.md index 7dcefc8..7608dac 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,7 @@ Install with PIP pip3 install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu117 # install clip-interrogator -pip install clip-interrogator==0.5.1 +pip install clip-interrogator==0.5.2 ``` You can then use it in your script @@ -60,4 +60,6 @@ The `Config` object lets you configure CLIP Interrogator's processing. * `chunk_size`: batch size for CLIP, use smaller for lower VRAM * `quiet`: when True no progress bars or text output will be displayed +On systems with low VRAM you can call `config.apply_low_vram_defaults()` to reduce the amount of VRAM needed (at the cost of some speed and quality). The default settings use about 6.3GB of VRAM and the low VRAM settings use about 2.7GB. + See the [run_cli.py](https://github.com/pharmapsychotic/clip-interrogator/blob/main/run_cli.py) and [run_gradio.py](https://github.com/pharmapsychotic/clip-interrogator/blob/main/run_gradio.py) for more examples on using Config and Interrogator classes. diff --git a/clip_interrogator/__init__.py b/clip_interrogator/__init__.py index 28fc2c6..27bb6f9 100644 --- a/clip_interrogator/__init__.py +++ b/clip_interrogator/__init__.py @@ -1,4 +1,4 @@ from .clip_interrogator import Interrogator, Config -__version__ = '0.5.1' +__version__ = '0.5.2' __author__ = 'pharmapsychotic' \ No newline at end of file diff --git a/clip_interrogator/clip_interrogator.py b/clip_interrogator/clip_interrogator.py index 8a92b6e..f0fca31 100644 --- a/clip_interrogator/clip_interrogator.py +++ b/clip_interrogator/clip_interrogator.py @@ -4,7 +4,6 @@ import math import numpy as np import open_clip import os -import pickle import requests import time import torch @@ -15,7 +14,7 @@ from PIL import Image from torchvision import transforms from torchvision.transforms.functional import InterpolationMode from tqdm import tqdm -from typing import List +from typing import List, Optional from safetensors.numpy import load_file, save_file @@ -24,23 +23,7 @@ BLIP_MODELS = { 'large': 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth' } -CACHE_URLS_VITL = [ - 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_artists.safetensors', - 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_flavors.safetensors', - 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_mediums.safetensors', - 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_movements.safetensors', - 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_negative.safetensors', - 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_trendings.safetensors', -] - -CACHE_URLS_VITH = [ - 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_artists.safetensors', - 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_flavors.safetensors', - 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_mediums.safetensors', - 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_movements.safetensors', - 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_negative.safetensors', - 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_trendings.safetensors', -] +CACHE_URL_BASE = 'https://huggingface.co/pharma/ci-preprocess/resolve/main/' @dataclass @@ -60,6 +43,7 @@ class Config: # clip settings clip_model_name: str = 'ViT-L-14/openai' clip_model_path: str = None + clip_offload: bool = False # interrogator settings cache_path: str = 'cache' # path to store cached text embeddings @@ -70,11 +54,19 @@ class Config: flavor_intermediate_count: int = 2048 quiet: bool = False # when quiet progress bars are not shown + def apply_low_vram_defaults(self): + self.blip_model_type = 'base' + self.blip_offload = True + self.clip_offload = True + self.chunk_size = 1024 + self.flavor_intermediate_count = 1024 class Interrogator(): def __init__(self, config: Config): self.config = config self.device = config.device + self.blip_offloaded = True + self.clip_offloaded = True if config.blip_model is None: if not config.quiet: @@ -97,21 +89,6 @@ class Interrogator(): self.load_clip_model() - def download_cache(self, clip_model_name: str): - if clip_model_name == 'ViT-L-14/openai': - cache_urls = CACHE_URLS_VITL - elif clip_model_name == 'ViT-H-14/laion2b_s32b_b79k': - cache_urls = CACHE_URLS_VITH - else: - # text embeddings will be precomputed and cached locally - return - - os.makedirs(self.config.cache_path, exist_ok=True) - for url in cache_urls: - filepath = os.path.join(self.config.cache_path, url.split('/')[-1]) - if not os.path.exists(filepath): - _download_file(url, filepath, quiet=self.config.quiet) - def load_clip_model(self): start_time = time.time() config = self.config @@ -129,13 +106,15 @@ class Interrogator(): jit=False, cache_dir=config.clip_model_path ) - self.clip_model.to(config.device).eval() + self.clip_model.eval() else: self.clip_model = config.clip_model self.clip_preprocess = config.clip_preprocess self.tokenize = open_clip.get_tokenizer(clip_model_name) - sites = ['Artstation', 'behance', 'cg society', 'cgsociety', 'deviantart', 'dribble', 'flickr', 'instagram', 'pexels', 'pinterest', 'pixabay', 'pixiv', 'polycount', 'reddit', 'shutterstock', 'tumblr', 'unsplash', 'zbrush central'] + sites = ['Artstation', 'behance', 'cg society', 'cgsociety', 'deviantart', 'dribble', + 'flickr', 'instagram', 'pexels', 'pinterest', 'pixabay', 'pixiv', 'polycount', + 'reddit', 'shutterstock', 'tumblr', 'unsplash', 'zbrush central'] trending_list = [site for site in sites] trending_list.extend(["trending on "+site for site in sites]) trending_list.extend(["featured on "+site for site in sites]) @@ -145,9 +124,7 @@ class Interrogator(): artists = [f"by {a}" for a in raw_artists] artists.extend([f"inspired by {a}" for a in raw_artists]) - if config.download_cache: - self.download_cache(config.clip_model_name) - + self._prepare_clip() self.artists = LabelTable(artists, "artists", self.clip_model, self.tokenize, config) self.flavors = LabelTable(_load_list(config.data_path, 'flavors.txt'), "flavors", self.clip_model, self.tokenize, config) self.mediums = LabelTable(_load_list(config.data_path, 'mediums.txt'), "mediums", self.clip_model, self.tokenize, config) @@ -170,6 +147,8 @@ class Interrogator(): desc="Chaining", reverse: bool=False ) -> str: + self._prepare_clip() + phrases = set(phrases) if not best_prompt: best_prompt = self.rank_top(image_features, [f for f in phrases], reverse=reverse) @@ -203,8 +182,8 @@ class Interrogator(): return best_prompt def generate_caption(self, pil_image: Image) -> str: - if self.config.blip_offload: - self.blip_model = self.blip_model.to(self.device) + self._prepare_blip() + size = self.config.blip_image_eval_size gpu_image = transforms.Compose([ transforms.Resize((size, size), interpolation=InterpolationMode.BICUBIC), @@ -220,21 +199,21 @@ class Interrogator(): max_length=self.config.blip_max_length, min_length=5 ) - if self.config.blip_offload: - self.blip_model = self.blip_model.to("cpu") + return caption[0] def image_to_features(self, image: Image) -> torch.Tensor: + self._prepare_clip() images = self.clip_preprocess(image).unsqueeze(0).to(self.device) with torch.no_grad(), torch.cuda.amp.autocast(): image_features = self.clip_model.encode_image(images) image_features /= image_features.norm(dim=-1, keepdim=True) return image_features - def interrogate_classic(self, image: Image, max_flavors: int=3) -> str: + def interrogate_classic(self, image: Image, max_flavors: int=3, caption: Optional[str]=None) -> str: """Classic mode creates a prompt in a standard format first describing the image, then listing the artist, trending, movement, and flavor text modifiers.""" - caption = self.generate_caption(image) + caption = caption or self.generate_caption(image) image_features = self.image_to_features(image) medium = self.mediums.rank(image_features, 1)[0] @@ -250,11 +229,11 @@ class Interrogator(): return _truncate_to_fit(prompt, self.tokenize) - def interrogate_fast(self, image: Image, max_flavors: int = 32) -> str: + def interrogate_fast(self, image: Image, max_flavors: int=32, caption: Optional[str]=None) -> str: """Fast mode simply adds the top ranked terms after a caption. It generally results in better similarity between generated prompt and image than classic mode, but the prompts are less readable.""" - caption = self.generate_caption(image) + caption = caption or self.generate_caption(image) image_features = self.image_to_features(image) merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config) tops = merged.rank(image_features, max_flavors) @@ -269,22 +248,22 @@ class Interrogator(): flaves = flaves + self.negative.labels return self.chain(image_features, flaves, max_count=max_flavors, reverse=True, desc="Negative chain") - def interrogate(self, image: Image, min_flavors: int=8, max_flavors: int=32) -> str: - caption = self.generate_caption(image) + def interrogate(self, image: Image, min_flavors: int=8, max_flavors: int=32, caption: Optional[str]=None) -> str: + caption = caption or self.generate_caption(image) image_features = self.image_to_features(image) merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config) flaves = merged.rank(image_features, self.config.flavor_intermediate_count) - best_prompt, best_sim = caption, self.similarity(image_features, caption) best_prompt = self.chain(image_features, flaves, best_prompt, best_sim, min_count=min_flavors, max_count=max_flavors, desc="Flavor chain") - fast_prompt = self.interrogate_fast(image, max_flavors) - classic_prompt = self.interrogate_classic(image, max_flavors) + fast_prompt = self.interrogate_fast(image, max_flavors, caption=caption) + classic_prompt = self.interrogate_classic(image, max_flavors, caption=caption) candidates = [caption, classic_prompt, fast_prompt, best_prompt] return candidates[np.argmax(self.similarities(image_features, candidates))] def rank_top(self, image_features: torch.Tensor, text_array: List[str], reverse: bool=False) -> str: + self._prepare_clip() text_tokens = self.tokenize([text for text in text_array]).to(self.device) with torch.no_grad(), torch.cuda.amp.autocast(): text_features = self.clip_model.encode_text(text_tokens) @@ -295,6 +274,7 @@ class Interrogator(): return text_array[similarity.argmax().item()] def similarity(self, image_features: torch.Tensor, text: str) -> float: + self._prepare_clip() text_tokens = self.tokenize([text]).to(self.device) with torch.no_grad(), torch.cuda.amp.autocast(): text_features = self.clip_model.encode_text(text_tokens) @@ -303,6 +283,7 @@ class Interrogator(): return similarity[0][0].item() def similarities(self, image_features: torch.Tensor, text_array: List[str]) -> List[float]: + self._prepare_clip() text_tokens = self.tokenize([text for text in text_array]).to(self.device) with torch.no_grad(), torch.cuda.amp.autocast(): text_features = self.clip_model.encode_text(text_tokens) @@ -310,6 +291,22 @@ class Interrogator(): similarity = text_features @ image_features.T return similarity.T[0].tolist() + def _prepare_blip(self): + if self.config.clip_offload and not self.clip_offloaded: + self.clip_model = self.clip_model.to('cpu') + self.clip_offloaded = True + if self.blip_offloaded: + self.blip_model = self.blip_model.to(self.device) + self.blip_offloaded = False + + def _prepare_clip(self): + if self.config.blip_offload: + self.blip_model = self.blip_model.to('cpu') + self.blip_offloaded = True + if self.clip_offloaded: + self.clip_model = self.clip_model.to(self.device) + self.clip_offloaded = False + class LabelTable(): def __init__(self, labels:List[str], desc:str, clip_model, tokenize, config: Config): @@ -352,23 +349,25 @@ class LabelTable(): if self.config.cache_path is None or desc is None: return False - # load from old pkl format if it exists - cached_pkl = os.path.join(self.config.cache_path, f"{sanitized_name}_{desc}.pkl") - if os.path.exists(cached_pkl): - with open(cached_pkl, 'rb') as f: - try: - data = pickle.load(f) - if data.get('hash') == hash: - self.labels = data['labels'] - self.embeds = data['embeds'] - return True - except Exception as e: - print(f"Error loading cached table {desc}: {e}") - - # load from new safetensors format if it exists cached_safetensors = os.path.join(self.config.cache_path, f"{sanitized_name}_{desc}.safetensors") + + if self.config.download_cache and not os.path.exists(cached_safetensors): + download_url = CACHE_URL_BASE + f"{sanitized_name}_{desc}.safetensors" + try: + os.makedirs(self.config.cache_path, exist_ok=True) + _download_file(download_url, cached_safetensors, quiet=self.config.quiet) + except Exception as e: + print(f"Failed to download {download_url}") + print(e) + return False + if os.path.exists(cached_safetensors): - tensors = load_file(cached_safetensors) + try: + tensors = load_file(cached_safetensors) + except Exception as e: + print(f"Failed to load {cached_safetensors}") + print(e) + return False if 'hash' in tensors and 'embeds' in tensors: if np.array_equal(tensors['hash'], np.array([ord(c) for c in hash], dtype=np.int8)): self.embeds = tensors['embeds'] @@ -377,7 +376,6 @@ class LabelTable(): return True return False - def _rank(self, image_features: torch.Tensor, text_embeds: torch.Tensor, top_count: int=1, reverse: bool=False) -> str: top_count = min(top_count, len(text_embeds)) @@ -409,8 +407,11 @@ class LabelTable(): return [top_labels[i] for i in tops] -def _download_file(url: str, filepath: str, chunk_size: int = 64*1024, quiet: bool = False): +def _download_file(url: str, filepath: str, chunk_size: int = 4*1024*1024, quiet: bool = False): r = requests.get(url, stream=True) + if r.status_code != 200: + return + file_size = int(r.headers.get("Content-Length", 0)) filename = url.split("/")[-1] progress = tqdm(total=file_size, unit="B", unit_scale=True, desc=filename, disable=quiet) diff --git a/run_cli.py b/run_cli.py index fdbba04..59d563d 100755 --- a/run_cli.py +++ b/run_cli.py @@ -24,6 +24,7 @@ def main(): parser.add_argument('-f', '--folder', help='path to folder of images') parser.add_argument('-i', '--image', help='image file or url') parser.add_argument('-m', '--mode', default='best', help='best, classic, or fast') + parser.add_argument("--lowvram", action='store_true', help="Optimize settings for low VRAM") args = parser.parse_args() if not args.folder and not args.image: @@ -51,6 +52,8 @@ def main(): # generate a nice prompt config = Config(device=device, clip_model_name=args.clip) + if args.lowvram: + config.apply_low_vram_defaults() ci = Interrogator(config) # process single image diff --git a/run_gradio.py b/run_gradio.py index 64568ba..938171a 100755 --- a/run_gradio.py +++ b/run_gradio.py @@ -11,13 +11,17 @@ except ImportError: exit(1) parser = argparse.ArgumentParser() +parser.add_argument("--lowvram", action='store_true', help="Optimize settings for low VRAM") parser.add_argument('-s', '--share', action='store_true', help='Create a public link') args = parser.parse_args() if not torch.cuda.is_available(): print("CUDA is not available, using CPU. Warning: this will be very slow!") -ci = Interrogator(Config(cache_path="cache", clip_model_path="cache")) +config = Config(cache_path="cache") +if args.lowvram: + config.apply_low_vram_defaults() +ci = Interrogator(config) def image_analysis(image, clip_model_name): if clip_model_name != ci.config.clip_model_name: diff --git a/setup.py b/setup.py index 4b676f7..d6ce654 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ from setuptools import setup, find_packages setup( name="clip-interrogator", - version="0.5.1", + version="0.5.2", license='MIT', author='pharmapsychotic', author_email='me@pharmapsychotic.com',