Browse Source

safetensors!

- store cached embeddings in safetensor format
- updated huggingface ci-preprocess repo
- bumped version to 0.5.0
pull/46/merge
pharmapsychotic 2 years ago
parent
commit
ae88b07a65
  1. 2
      README.md
  2. 2
      clip_interrogator/__init__.py
  3. 86
      clip_interrogator/clip_interrogator.py
  4. 1
      requirements.txt
  5. 7
      run_gradio.py
  6. 2
      setup.py

2
README.md

@ -36,7 +36,7 @@ Install with PIP
pip3 install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu117
# install clip-interrogator
pip install clip-interrogator==0.4.4
pip install clip-interrogator==0.5.0
```
You can then use it in your script

2
clip_interrogator/__init__.py

@ -1,4 +1,4 @@
from .clip_interrogator import Interrogator, Config
__version__ = '0.4.4'
__version__ = '0.5.0'
__author__ = 'pharmapsychotic'

86
clip_interrogator/clip_interrogator.py

@ -17,25 +17,27 @@ from torchvision.transforms.functional import InterpolationMode
from tqdm import tqdm
from typing import List
from safetensors.numpy import load_file, save_file
BLIP_MODELS = {
'base': 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth',
'large': 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
}
CACHE_URLS_VITL = [
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_artists.pkl',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_flavors.pkl',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_mediums.pkl',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_movements.pkl',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_trendings.pkl',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_artists.safetensors',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_flavors.safetensors',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_mediums.safetensors',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_movements.safetensors',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_trendings.safetensors',
]
CACHE_URLS_VITH = [
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_artists.pkl',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_flavors.pkl',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_mediums.pkl',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_movements.pkl',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_trendings.pkl',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_artists.safetensors',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_flavors.safetensors',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_mediums.safetensors',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_movements.safetensors',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_trendings.safetensors',
]
@ -316,21 +318,8 @@ class LabelTable():
self.tokenize = tokenize
hash = hashlib.sha256(",".join(labels).encode()).hexdigest()
cache_filepath = None
if config.cache_path is not None and desc is not None:
os.makedirs(config.cache_path, exist_ok=True)
sanitized_name = config.clip_model_name.replace('/', '_').replace('@', '_')
cache_filepath = os.path.join(config.cache_path, f"{sanitized_name}_{desc}.pkl")
if desc is not None and os.path.exists(cache_filepath):
with open(cache_filepath, 'rb') as f:
try:
data = pickle.load(f)
if data.get('hash') == hash:
self.labels = data['labels']
self.embeds = data['embeds']
except Exception as e:
print(f"Error loading cached table {desc}: {e}")
sanitized_name = self.config.clip_model_name.replace('/', '_').replace('@', '_')
self._load_cached(desc, hash, sanitized_name)
if len(self.labels) != len(self.embeds):
self.embeds = []
@ -344,17 +333,48 @@ class LabelTable():
for i in range(text_features.shape[0]):
self.embeds.append(text_features[i])
if cache_filepath is not None:
with open(cache_filepath, 'wb') as f:
pickle.dump({
"labels": self.labels,
"embeds": self.embeds,
"hash": hash,
"model": config.clip_model_name
}, f)
if desc and self.config.cache_path:
os.makedirs(self.config.cache_path, exist_ok=True)
cache_filepath = os.path.join(self.config.cache_path, f"{sanitized_name}_{desc}.safetensors")
tensors = {
"embeds": np.stack(self.embeds),
"hash": np.array([ord(c) for c in hash], dtype=np.int8)
}
save_file(tensors, cache_filepath)
if self.device == 'cpu' or self.device == torch.device('cpu'):
self.embeds = [e.astype(np.float32) for e in self.embeds]
def _load_cached(self, desc:str, hash:str, sanitized_name:str) -> bool:
if self.config.cache_path is None or desc is None:
return False
# load from old pkl format if it exists
cached_pkl = os.path.join(self.config.cache_path, f"{sanitized_name}_{desc}.pkl")
if os.path.exists(cached_pkl):
with open(cached_pkl, 'rb') as f:
try:
data = pickle.load(f)
if data.get('hash') == hash:
self.labels = data['labels']
self.embeds = data['embeds']
return True
except Exception as e:
print(f"Error loading cached table {desc}: {e}")
# load from new safetensors format if it exists
cached_safetensors = os.path.join(self.config.cache_path, f"{sanitized_name}_{desc}.safetensors")
if os.path.exists(cached_safetensors):
tensors = load_file(cached_safetensors)
if 'hash' in tensors and 'embeds' in tensors:
if np.array_equal(tensors['hash'], np.array([ord(c) for c in hash], dtype=np.int8)):
self.embeds = tensors['embeds']
if len(self.embeds.shape) == 2:
self.embeds = [self.embeds[i] for i in range(self.embeds.shape[0])]
return True
return False
def _rank(self, image_features: torch.Tensor, text_embeds: torch.Tensor, top_count: int=1, reverse: bool=False) -> str:
top_count = min(top_count, len(text_embeds))

1
requirements.txt

@ -2,6 +2,7 @@ torch
torchvision
Pillow
requests
safetensors
tqdm
open_clip_torch
blip-ci

7
run_gradio.py

@ -1,10 +1,15 @@
#!/usr/bin/env python3
import argparse
import gradio as gr
import open_clip
import torch
from clip_interrogator import Config, Interrogator
try:
import gradio as gr
except ImportError:
print("Gradio is not installed, please install it with 'pip install gradio'")
exit(1)
parser = argparse.ArgumentParser()
parser.add_argument('-s', '--share', action='store_true', help='Create a public link')
args = parser.parse_args()

2
setup.py

@ -5,7 +5,7 @@ from setuptools import setup, find_packages
setup(
name="clip-interrogator",
version="0.4.4",
version="0.5.0",
license='MIT',
author='pharmapsychotic',
author_email='me@pharmapsychotic.com',

Loading…
Cancel
Save