Browse Source

Merge branch 'main' into main

pull/46/head
bolshoytoster 2 years ago committed by GitHub
parent
commit
96fc82f5a9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      README.md
  2. 2
      clip_interrogator/__init__.py
  3. 108
      clip_interrogator/clip_interrogator.py
  4. 3
      requirements.txt
  5. 7
      run_gradio.py
  6. 4
      setup.py

2
README.md

@ -88,7 +88,7 @@ Install with PIP
pip3 install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu117
# install clip-interrogator
pip install clip-interrogator==0.4.3
pip install clip-interrogator==0.5.1
```
You can then use it in your script

2
clip_interrogator/__init__.py

@ -1,4 +1,4 @@
from .clip_interrogator import Interrogator, Config
__version__ = '0.4.3'
__version__ = '0.5.1'
__author__ = 'pharmapsychotic'

108
clip_interrogator/clip_interrogator.py

@ -17,25 +17,29 @@ from torchvision.transforms.functional import InterpolationMode
from tqdm import tqdm
from typing import List, Union
from safetensors.numpy import load_file, save_file
BLIP_MODELS = {
"base": "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth",
"large": "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth",
}
CACHE_URLS_VITL = [
"https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_artists.pkl",
"https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_flavors.pkl",
"https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_mediums.pkl",
"https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_movements.pkl",
"https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_trendings.pkl",
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_artists.safetensors',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_flavors.safetensors',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_mediums.safetensors',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_movements.safetensors',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_negative.safetensors',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_trendings.safetensors',
]
CACHE_URLS_VITH = [
"https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_artists.pkl",
"https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_flavors.pkl",
"https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_mediums.pkl",
"https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_movements.pkl",
"https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_trendings.pkl",
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_artists.safetensors',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_flavors.safetensors',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_mediums.safetensors',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_movements.safetensors',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_negative.safetensors',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_trendings.safetensors',
]
@ -97,7 +101,8 @@ class Interrogator:
med_config=med_config,
)
blip_model.eval()
blip_model = blip_model.to(config.device)
if not self.config.blip_offload:
blip_model = blip_model.to(config.device)
self.blip_model = blip_model
else:
self.blip_model = config.blip_model
@ -558,23 +563,8 @@ class LabelTable:
self.tokenize = tokenize
hash = hashlib.sha256(",".join(labels).encode()).hexdigest()
cache_filepath = None
if config.cache_path is not None and desc is not None:
os.makedirs(config.cache_path, exist_ok=True)
sanitized_name = config.clip_model_name.replace("/", "_").replace("@", "_")
cache_filepath = os.path.join(
config.cache_path, f"{sanitized_name}_{desc}.pkl"
)
if desc is not None and os.path.exists(cache_filepath):
with open(cache_filepath, "rb") as f:
try:
data = pickle.load(f)
if data.get("hash") == hash:
self.labels = data["labels"]
self.embeds = data["embeds"]
except Exception as e:
print(f"Error loading cached table {desc}: {e}")
sanitized_name = self.config.clip_model_name.replace('/', '_').replace('@', '_')
self._load_cached(desc, hash, sanitized_name)
if len(self.labels) != len(self.embeds):
self.embeds = []
@ -594,28 +584,50 @@ class LabelTable:
for i in range(text_features.shape[0]):
self.embeds.append(text_features[i])
if cache_filepath is not None:
with open(cache_filepath, "wb") as f:
pickle.dump(
{
"labels": self.labels,
"embeds": self.embeds,
"hash": hash,
"model": config.clip_model_name,
},
f,
)
if self.device == "cpu" or self.device == torch.device("cpu"):
if desc and self.config.cache_path:
os.makedirs(self.config.cache_path, exist_ok=True)
cache_filepath = os.path.join(self.config.cache_path, f"{sanitized_name}_{desc}.safetensors")
tensors = {
"embeds": np.stack(self.embeds),
"hash": np.array([ord(c) for c in hash], dtype=np.int8)
}
save_file(tensors, cache_filepath)
if self.device == 'cpu' or self.device == torch.device('cpu'):
self.embeds = [e.astype(np.float32) for e in self.embeds]
def _rank(
self,
image_features: torch.Tensor,
text_embeds: torch.Tensor,
top_count: int = 1,
reverse: bool = False,
) -> str:
def _load_cached(self, desc:str, hash:str, sanitized_name:str) -> bool:
if self.config.cache_path is None or desc is None:
return False
# load from old pkl format if it exists
cached_pkl = os.path.join(self.config.cache_path, f"{sanitized_name}_{desc}.pkl")
if os.path.exists(cached_pkl):
with open(cached_pkl, 'rb') as f:
try:
data = pickle.load(f)
if data.get('hash') == hash:
self.labels = data['labels']
self.embeds = data['embeds']
return True
except Exception as e:
print(f"Error loading cached table {desc}: {e}")
# load from new safetensors format if it exists
cached_safetensors = os.path.join(self.config.cache_path, f"{sanitized_name}_{desc}.safetensors")
if os.path.exists(cached_safetensors):
tensors = load_file(cached_safetensors)
if 'hash' in tensors and 'embeds' in tensors:
if np.array_equal(tensors['hash'], np.array([ord(c) for c in hash], dtype=np.int8)):
self.embeds = tensors['embeds']
if len(self.embeds.shape) == 2:
self.embeds = [self.embeds[i] for i in range(self.embeds.shape[0])]
return True
return False
def _rank(self, image_features: torch.Tensor, text_embeds: torch.Tensor, top_count: int=1, reverse: bool=False) -> str:
top_count = min(top_count, len(text_embeds))
text_embeds = torch.stack([torch.from_numpy(t) for t in text_embeds]).to(
self.device

3
requirements.txt

@ -2,6 +2,7 @@ torch
torchvision
Pillow
requests
safetensors
tqdm
open_clip_torch
blip-vit
blip-ci

7
run_gradio.py

@ -1,10 +1,15 @@
#!/usr/bin/env python3
import argparse
import gradio as gr
import open_clip
import torch
from clip_interrogator import Config, Interrogator
try:
import gradio as gr
except ImportError:
print("Gradio is not installed, please install it with 'pip install gradio'")
exit(1)
parser = argparse.ArgumentParser()
parser.add_argument('-s', '--share', action='store_true', help='Create a public link')
args = parser.parse_args()

4
setup.py

@ -5,13 +5,13 @@ from setuptools import setup, find_packages
setup(
name="clip-interrogator",
version="0.4.3",
version="0.5.1",
license='MIT',
author='pharmapsychotic',
author_email='me@pharmapsychotic.com',
url='https://github.com/pharmapsychotic/clip-interrogator',
description="Generate a prompt from an image",
long_description=open('README.md').read(),
long_description=open('README.md', encoding='utf-8').read(),
long_description_content_type="text/markdown",
packages=find_packages(),
install_requires=[

Loading…
Cancel
Save