From ae88b07a65232ae3dac9af3953b0dbe08a8987f2 Mon Sep 17 00:00:00 2001
From: pharmapsychotic
Date: Sat, 18 Feb 2023 14:53:02 -0600
Subject: [PATCH] safetensors!

- store cached embeddings in safetensor format
- updated huggingface ci-preprocess repo
- bumped version to 0.5.0
---
 README.md                              |  2 +-
 clip_interrogator/__init__.py          |  2 +-
 clip_interrogator/clip_interrogator.py | 86 ++++++++++++++++----------
 requirements.txt                       |  1 +
 run_gradio.py                          |  7 ++-
 setup.py                               |  2 +-
 6 files changed, 63 insertions(+), 37 deletions(-)

diff --git a/README.md b/README.md
index bceb99b..28d5c2c 100644
--- a/README.md
+++ b/README.md
@@ -36,7 +36,7 @@ Install with PIP
 pip3 install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu117
 
 # install clip-interrogator
-pip install clip-interrogator==0.4.4
+pip install clip-interrogator==0.5.0
 ```
 
 You can then use it in your script
diff --git a/clip_interrogator/__init__.py b/clip_interrogator/__init__.py
index f85ef1e..0992c67 100644
--- a/clip_interrogator/__init__.py
+++ b/clip_interrogator/__init__.py
@@ -1,4 +1,4 @@
 from .clip_interrogator import Interrogator, Config
 
-__version__ = '0.4.4'
+__version__ = '0.5.0'
 __author__ = 'pharmapsychotic'
\ No newline at end of file
diff --git a/clip_interrogator/clip_interrogator.py b/clip_interrogator/clip_interrogator.py
index 97d32e8..c1d8f99 100644
--- a/clip_interrogator/clip_interrogator.py
+++ b/clip_interrogator/clip_interrogator.py
@@ -17,25 +17,27 @@ from torchvision.transforms.functional import InterpolationMode
 from tqdm import tqdm
 from typing import List
 
+from safetensors.numpy import load_file, save_file
+
 BLIP_MODELS = {
     'base': 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth',
     'large': 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
 }
 
 CACHE_URLS_VITL = [
-    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_artists.pkl',
-    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_flavors.pkl',
-    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_mediums.pkl',
-    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_movements.pkl',
-    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_trendings.pkl',
+    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_artists.safetensors',
+    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_flavors.safetensors',
+    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_mediums.safetensors',
+    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_movements.safetensors',
+    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_trendings.safetensors',
 ]
 
 CACHE_URLS_VITH = [
-    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_artists.pkl',
-    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_flavors.pkl',
-    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_mediums.pkl',
-    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_movements.pkl',
-    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_trendings.pkl',
+    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_artists.safetensors',
+    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_flavors.safetensors',
+    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_mediums.safetensors',
+    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_movements.safetensors',
+    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_trendings.safetensors',
 ]
 
@@ -316,21 +318,8 @@ class LabelTable():
         self.tokenize = tokenize
 
         hash = hashlib.sha256(",".join(labels).encode()).hexdigest()
-
-        cache_filepath = None
-        if config.cache_path is not None and desc is not None:
-            os.makedirs(config.cache_path, exist_ok=True)
-            sanitized_name = config.clip_model_name.replace('/', '_').replace('@', '_')
-            cache_filepath = os.path.join(config.cache_path, f"{sanitized_name}_{desc}.pkl")
-            if desc is not None and os.path.exists(cache_filepath):
-                with open(cache_filepath, 'rb') as f:
-                    try:
-                        data = pickle.load(f)
-                        if data.get('hash') == hash:
-                            self.labels = data['labels']
-                            self.embeds = data['embeds']
-                    except Exception as e:
-                        print(f"Error loading cached table {desc}: {e}")
+        sanitized_name = self.config.clip_model_name.replace('/', '_').replace('@', '_')
+        self._load_cached(desc, hash, sanitized_name)
 
         if len(self.labels) != len(self.embeds):
             self.embeds = []
@@ -344,17 +333,48 @@ class LabelTable():
                 for i in range(text_features.shape[0]):
                     self.embeds.append(text_features[i])
 
-            if cache_filepath is not None:
-                with open(cache_filepath, 'wb') as f:
-                    pickle.dump({
-                        "labels": self.labels,
-                        "embeds": self.embeds,
-                        "hash": hash,
-                        "model": config.clip_model_name
-                    }, f)
+            if desc and self.config.cache_path:
+                os.makedirs(self.config.cache_path, exist_ok=True)
+                cache_filepath = os.path.join(self.config.cache_path, f"{sanitized_name}_{desc}.safetensors")
+                tensors = {
+                    "embeds": np.stack(self.embeds),
+                    "hash": np.array([ord(c) for c in hash], dtype=np.int8)
+                }
+                save_file(tensors, cache_filepath)
 
         if self.device == 'cpu' or self.device == torch.device('cpu'):
             self.embeds = [e.astype(np.float32) for e in self.embeds]
+
+    def _load_cached(self, desc:str, hash:str, sanitized_name:str) -> bool:
+        if self.config.cache_path is None or desc is None:
+            return False
+
+        # load from old pkl format if it exists
+        cached_pkl = os.path.join(self.config.cache_path, f"{sanitized_name}_{desc}.pkl")
+        if os.path.exists(cached_pkl):
+            with open(cached_pkl, 'rb') as f:
+                try:
+                    data = pickle.load(f)
+                    if data.get('hash') == hash:
+                        self.labels = data['labels']
+                        self.embeds = data['embeds']
+                        return True
+                except Exception as e:
+                    print(f"Error loading cached table {desc}: {e}")
+
+        # load from new safetensors format if it exists
+        cached_safetensors = os.path.join(self.config.cache_path, f"{sanitized_name}_{desc}.safetensors")
+        if os.path.exists(cached_safetensors):
+            tensors = load_file(cached_safetensors)
+            if 'hash' in tensors and 'embeds' in tensors:
+                if np.array_equal(tensors['hash'], np.array([ord(c) for c in hash], dtype=np.int8)):
+                    self.embeds = tensors['embeds']
+                    if len(self.embeds.shape) == 2:
+                        self.embeds = [self.embeds[i] for i in range(self.embeds.shape[0])]
+                    return True
+
+        return False
+
     def _rank(self, image_features: torch.Tensor, text_embeds: torch.Tensor, top_count: int=1, reverse: bool=False) -> str:
         top_count = min(top_count, len(text_embeds))
diff --git a/requirements.txt b/requirements.txt
index 0eb27eb..6f4f9f3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,6 +2,7 @@ torch
 torchvision
 Pillow
 requests
+safetensors
 tqdm
 open_clip_torch
 blip-ci
\ No newline at end of file
diff --git a/run_gradio.py b/run_gradio.py
index c8f1597..64568ba 100755
--- a/run_gradio.py
+++ b/run_gradio.py
@@ -1,10 +1,15 @@
 #!/usr/bin/env python3
 import argparse
-import gradio as gr
 import open_clip
 import torch
 from clip_interrogator import Config, Interrogator
 
+try:
+    import gradio as gr
+except ImportError:
+    print("Gradio is not installed, please install it with 'pip install gradio'")
+    exit(1)
+
 parser = argparse.ArgumentParser()
 parser.add_argument('-s', '--share', action='store_true', help='Create a public link')
 args = parser.parse_args()
diff --git a/setup.py b/setup.py
index e2fbd61..918cef9 100644
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@ from setuptools import setup, find_packages
 
 setup(
     name="clip-interrogator",
-    version="0.4.4",
+    version="0.5.0",
     license='MIT',
     author='pharmapsychotic',
     author_email='me@pharmapsychotic.com',
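--
A minimal sketch of the cache round-trip this patch implements, standalone and outside
the LabelTable class. The label list and file name below are hypothetical stand-ins, and
the random arrays stand in for real CLIP text embeddings. Since safetensors stores only
tensors, the SHA-256 hex digest of the comma-joined labels is encoded as an int8 array
and compared on load, mirroring the scheme used in _load_cached above:

import hashlib
import numpy as np
from safetensors.numpy import load_file, save_file

labels = ["oil painting", "watercolor"]  # hypothetical label table contents
embeds = [np.random.rand(768).astype(np.float32) for _ in labels]  # stand-in embeddings

# hash the comma-joined labels and store the hex digest as an int8 tensor
# alongside the stacked embeddings, as the patch does when saving the cache
digest = hashlib.sha256(",".join(labels).encode()).hexdigest()
save_file({
    "embeds": np.stack(embeds),
    "hash": np.array([ord(c) for c in digest], dtype=np.int8),
}, "example_cache.safetensors")

# on load, the cache is trusted only if the stored digest matches the digest
# recomputed from the current labels; otherwise the embeddings are rebuilt
loaded = load_file("example_cache.safetensors")
if np.array_equal(loaded["hash"], np.array([ord(c) for c in digest], dtype=np.int8)):
    embeds = [loaded["embeds"][i] for i in range(loaded["embeds"].shape[0])]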