Browse Source

More safetensor, download, and VRAM improvements

pull/52/head
pharmapsychotic 2 years ago
parent
commit
c4e16359a7
  1. 4
      README.md
  2. 2
      clip_interrogator/__init__.py
  3. 133
      clip_interrogator/clip_interrogator.py
  4. 3
      run_cli.py
  5. 6
      run_gradio.py
  6. 2
      setup.py

4
README.md

@ -36,7 +36,7 @@ Install with PIP
pip3 install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu117 pip3 install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu117
# install clip-interrogator # install clip-interrogator
pip install clip-interrogator==0.5.1 pip install clip-interrogator==0.5.2
``` ```
You can then use it in your script You can then use it in your script
@ -60,4 +60,6 @@ The `Config` object lets you configure CLIP Interrogator's processing.
* `chunk_size`: batch size for CLIP, use smaller for lower VRAM * `chunk_size`: batch size for CLIP, use smaller for lower VRAM
* `quiet`: when True no progress bars or text output will be displayed * `quiet`: when True no progress bars or text output will be displayed
On systems with low VRAM you can call `config.apply_low_vram_defaults()` to reduce the amount of VRAM needed (at the cost of some speed and quality). The default settings use about 6.3GB of VRAM and the low VRAM settings use about 2.7GB.
See the [run_cli.py](https://github.com/pharmapsychotic/clip-interrogator/blob/main/run_cli.py) and [run_gradio.py](https://github.com/pharmapsychotic/clip-interrogator/blob/main/run_gradio.py) for more examples on using Config and Interrogator classes. See the [run_cli.py](https://github.com/pharmapsychotic/clip-interrogator/blob/main/run_cli.py) and [run_gradio.py](https://github.com/pharmapsychotic/clip-interrogator/blob/main/run_gradio.py) for more examples on using Config and Interrogator classes.

2
clip_interrogator/__init__.py

@ -1,4 +1,4 @@
from .clip_interrogator import Interrogator, Config from .clip_interrogator import Interrogator, Config
__version__ = '0.5.1' __version__ = '0.5.2'
__author__ = 'pharmapsychotic' __author__ = 'pharmapsychotic'

133
clip_interrogator/clip_interrogator.py

@ -4,7 +4,6 @@ import math
import numpy as np import numpy as np
import open_clip import open_clip
import os import os
import pickle
import requests import requests
import time import time
import torch import torch
@ -15,7 +14,7 @@ from PIL import Image
from torchvision import transforms from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from tqdm import tqdm from tqdm import tqdm
from typing import List from typing import List, Optional
from safetensors.numpy import load_file, save_file from safetensors.numpy import load_file, save_file
@ -24,23 +23,7 @@ BLIP_MODELS = {
'large': 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth' 'large': 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
} }
CACHE_URLS_VITL = [ CACHE_URL_BASE = 'https://huggingface.co/pharma/ci-preprocess/resolve/main/'
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_artists.safetensors',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_flavors.safetensors',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_mediums.safetensors',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_movements.safetensors',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_negative.safetensors',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_trendings.safetensors',
]
CACHE_URLS_VITH = [
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_artists.safetensors',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_flavors.safetensors',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_mediums.safetensors',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_movements.safetensors',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_negative.safetensors',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_trendings.safetensors',
]
@dataclass @dataclass
@ -60,6 +43,7 @@ class Config:
# clip settings # clip settings
clip_model_name: str = 'ViT-L-14/openai' clip_model_name: str = 'ViT-L-14/openai'
clip_model_path: str = None clip_model_path: str = None
clip_offload: bool = False
# interrogator settings # interrogator settings
cache_path: str = 'cache' # path to store cached text embeddings cache_path: str = 'cache' # path to store cached text embeddings
@ -70,11 +54,19 @@ class Config:
flavor_intermediate_count: int = 2048 flavor_intermediate_count: int = 2048
quiet: bool = False # when quiet progress bars are not shown quiet: bool = False # when quiet progress bars are not shown
def apply_low_vram_defaults(self):
self.blip_model_type = 'base'
self.blip_offload = True
self.clip_offload = True
self.chunk_size = 1024
self.flavor_intermediate_count = 1024
class Interrogator(): class Interrogator():
def __init__(self, config: Config): def __init__(self, config: Config):
self.config = config self.config = config
self.device = config.device self.device = config.device
self.blip_offloaded = True
self.clip_offloaded = True
if config.blip_model is None: if config.blip_model is None:
if not config.quiet: if not config.quiet:
@ -97,21 +89,6 @@ class Interrogator():
self.load_clip_model() self.load_clip_model()
def download_cache(self, clip_model_name: str):
if clip_model_name == 'ViT-L-14/openai':
cache_urls = CACHE_URLS_VITL
elif clip_model_name == 'ViT-H-14/laion2b_s32b_b79k':
cache_urls = CACHE_URLS_VITH
else:
# text embeddings will be precomputed and cached locally
return
os.makedirs(self.config.cache_path, exist_ok=True)
for url in cache_urls:
filepath = os.path.join(self.config.cache_path, url.split('/')[-1])
if not os.path.exists(filepath):
_download_file(url, filepath, quiet=self.config.quiet)
def load_clip_model(self): def load_clip_model(self):
start_time = time.time() start_time = time.time()
config = self.config config = self.config
@ -129,13 +106,15 @@ class Interrogator():
jit=False, jit=False,
cache_dir=config.clip_model_path cache_dir=config.clip_model_path
) )
self.clip_model.to(config.device).eval() self.clip_model.eval()
else: else:
self.clip_model = config.clip_model self.clip_model = config.clip_model
self.clip_preprocess = config.clip_preprocess self.clip_preprocess = config.clip_preprocess
self.tokenize = open_clip.get_tokenizer(clip_model_name) self.tokenize = open_clip.get_tokenizer(clip_model_name)
sites = ['Artstation', 'behance', 'cg society', 'cgsociety', 'deviantart', 'dribble', 'flickr', 'instagram', 'pexels', 'pinterest', 'pixabay', 'pixiv', 'polycount', 'reddit', 'shutterstock', 'tumblr', 'unsplash', 'zbrush central'] sites = ['Artstation', 'behance', 'cg society', 'cgsociety', 'deviantart', 'dribble',
'flickr', 'instagram', 'pexels', 'pinterest', 'pixabay', 'pixiv', 'polycount',
'reddit', 'shutterstock', 'tumblr', 'unsplash', 'zbrush central']
trending_list = [site for site in sites] trending_list = [site for site in sites]
trending_list.extend(["trending on "+site for site in sites]) trending_list.extend(["trending on "+site for site in sites])
trending_list.extend(["featured on "+site for site in sites]) trending_list.extend(["featured on "+site for site in sites])
@ -145,9 +124,7 @@ class Interrogator():
artists = [f"by {a}" for a in raw_artists] artists = [f"by {a}" for a in raw_artists]
artists.extend([f"inspired by {a}" for a in raw_artists]) artists.extend([f"inspired by {a}" for a in raw_artists])
if config.download_cache: self._prepare_clip()
self.download_cache(config.clip_model_name)
self.artists = LabelTable(artists, "artists", self.clip_model, self.tokenize, config) self.artists = LabelTable(artists, "artists", self.clip_model, self.tokenize, config)
self.flavors = LabelTable(_load_list(config.data_path, 'flavors.txt'), "flavors", self.clip_model, self.tokenize, config) self.flavors = LabelTable(_load_list(config.data_path, 'flavors.txt'), "flavors", self.clip_model, self.tokenize, config)
self.mediums = LabelTable(_load_list(config.data_path, 'mediums.txt'), "mediums", self.clip_model, self.tokenize, config) self.mediums = LabelTable(_load_list(config.data_path, 'mediums.txt'), "mediums", self.clip_model, self.tokenize, config)
@ -170,6 +147,8 @@ class Interrogator():
desc="Chaining", desc="Chaining",
reverse: bool=False reverse: bool=False
) -> str: ) -> str:
self._prepare_clip()
phrases = set(phrases) phrases = set(phrases)
if not best_prompt: if not best_prompt:
best_prompt = self.rank_top(image_features, [f for f in phrases], reverse=reverse) best_prompt = self.rank_top(image_features, [f for f in phrases], reverse=reverse)
@ -203,8 +182,8 @@ class Interrogator():
return best_prompt return best_prompt
def generate_caption(self, pil_image: Image) -> str: def generate_caption(self, pil_image: Image) -> str:
if self.config.blip_offload: self._prepare_blip()
self.blip_model = self.blip_model.to(self.device)
size = self.config.blip_image_eval_size size = self.config.blip_image_eval_size
gpu_image = transforms.Compose([ gpu_image = transforms.Compose([
transforms.Resize((size, size), interpolation=InterpolationMode.BICUBIC), transforms.Resize((size, size), interpolation=InterpolationMode.BICUBIC),
@ -220,21 +199,21 @@ class Interrogator():
max_length=self.config.blip_max_length, max_length=self.config.blip_max_length,
min_length=5 min_length=5
) )
if self.config.blip_offload:
self.blip_model = self.blip_model.to("cpu")
return caption[0] return caption[0]
def image_to_features(self, image: Image) -> torch.Tensor: def image_to_features(self, image: Image) -> torch.Tensor:
self._prepare_clip()
images = self.clip_preprocess(image).unsqueeze(0).to(self.device) images = self.clip_preprocess(image).unsqueeze(0).to(self.device)
with torch.no_grad(), torch.cuda.amp.autocast(): with torch.no_grad(), torch.cuda.amp.autocast():
image_features = self.clip_model.encode_image(images) image_features = self.clip_model.encode_image(images)
image_features /= image_features.norm(dim=-1, keepdim=True) image_features /= image_features.norm(dim=-1, keepdim=True)
return image_features return image_features
def interrogate_classic(self, image: Image, max_flavors: int=3) -> str: def interrogate_classic(self, image: Image, max_flavors: int=3, caption: Optional[str]=None) -> str:
"""Classic mode creates a prompt in a standard format first describing the image, """Classic mode creates a prompt in a standard format first describing the image,
then listing the artist, trending, movement, and flavor text modifiers.""" then listing the artist, trending, movement, and flavor text modifiers."""
caption = self.generate_caption(image) caption = caption or self.generate_caption(image)
image_features = self.image_to_features(image) image_features = self.image_to_features(image)
medium = self.mediums.rank(image_features, 1)[0] medium = self.mediums.rank(image_features, 1)[0]
@ -250,11 +229,11 @@ class Interrogator():
return _truncate_to_fit(prompt, self.tokenize) return _truncate_to_fit(prompt, self.tokenize)
def interrogate_fast(self, image: Image, max_flavors: int = 32) -> str: def interrogate_fast(self, image: Image, max_flavors: int=32, caption: Optional[str]=None) -> str:
"""Fast mode simply adds the top ranked terms after a caption. It generally results in """Fast mode simply adds the top ranked terms after a caption. It generally results in
better similarity between generated prompt and image than classic mode, but the prompts better similarity between generated prompt and image than classic mode, but the prompts
are less readable.""" are less readable."""
caption = self.generate_caption(image) caption = caption or self.generate_caption(image)
image_features = self.image_to_features(image) image_features = self.image_to_features(image)
merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config) merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config)
tops = merged.rank(image_features, max_flavors) tops = merged.rank(image_features, max_flavors)
@ -269,22 +248,22 @@ class Interrogator():
flaves = flaves + self.negative.labels flaves = flaves + self.negative.labels
return self.chain(image_features, flaves, max_count=max_flavors, reverse=True, desc="Negative chain") return self.chain(image_features, flaves, max_count=max_flavors, reverse=True, desc="Negative chain")
def interrogate(self, image: Image, min_flavors: int=8, max_flavors: int=32) -> str: def interrogate(self, image: Image, min_flavors: int=8, max_flavors: int=32, caption: Optional[str]=None) -> str:
caption = self.generate_caption(image) caption = caption or self.generate_caption(image)
image_features = self.image_to_features(image) image_features = self.image_to_features(image)
merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config) merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config)
flaves = merged.rank(image_features, self.config.flavor_intermediate_count) flaves = merged.rank(image_features, self.config.flavor_intermediate_count)
best_prompt, best_sim = caption, self.similarity(image_features, caption) best_prompt, best_sim = caption, self.similarity(image_features, caption)
best_prompt = self.chain(image_features, flaves, best_prompt, best_sim, min_count=min_flavors, max_count=max_flavors, desc="Flavor chain") best_prompt = self.chain(image_features, flaves, best_prompt, best_sim, min_count=min_flavors, max_count=max_flavors, desc="Flavor chain")
fast_prompt = self.interrogate_fast(image, max_flavors) fast_prompt = self.interrogate_fast(image, max_flavors, caption=caption)
classic_prompt = self.interrogate_classic(image, max_flavors) classic_prompt = self.interrogate_classic(image, max_flavors, caption=caption)
candidates = [caption, classic_prompt, fast_prompt, best_prompt] candidates = [caption, classic_prompt, fast_prompt, best_prompt]
return candidates[np.argmax(self.similarities(image_features, candidates))] return candidates[np.argmax(self.similarities(image_features, candidates))]
def rank_top(self, image_features: torch.Tensor, text_array: List[str], reverse: bool=False) -> str: def rank_top(self, image_features: torch.Tensor, text_array: List[str], reverse: bool=False) -> str:
self._prepare_clip()
text_tokens = self.tokenize([text for text in text_array]).to(self.device) text_tokens = self.tokenize([text for text in text_array]).to(self.device)
with torch.no_grad(), torch.cuda.amp.autocast(): with torch.no_grad(), torch.cuda.amp.autocast():
text_features = self.clip_model.encode_text(text_tokens) text_features = self.clip_model.encode_text(text_tokens)
@ -295,6 +274,7 @@ class Interrogator():
return text_array[similarity.argmax().item()] return text_array[similarity.argmax().item()]
def similarity(self, image_features: torch.Tensor, text: str) -> float: def similarity(self, image_features: torch.Tensor, text: str) -> float:
self._prepare_clip()
text_tokens = self.tokenize([text]).to(self.device) text_tokens = self.tokenize([text]).to(self.device)
with torch.no_grad(), torch.cuda.amp.autocast(): with torch.no_grad(), torch.cuda.amp.autocast():
text_features = self.clip_model.encode_text(text_tokens) text_features = self.clip_model.encode_text(text_tokens)
@ -303,6 +283,7 @@ class Interrogator():
return similarity[0][0].item() return similarity[0][0].item()
def similarities(self, image_features: torch.Tensor, text_array: List[str]) -> List[float]: def similarities(self, image_features: torch.Tensor, text_array: List[str]) -> List[float]:
self._prepare_clip()
text_tokens = self.tokenize([text for text in text_array]).to(self.device) text_tokens = self.tokenize([text for text in text_array]).to(self.device)
with torch.no_grad(), torch.cuda.amp.autocast(): with torch.no_grad(), torch.cuda.amp.autocast():
text_features = self.clip_model.encode_text(text_tokens) text_features = self.clip_model.encode_text(text_tokens)
@ -310,6 +291,22 @@ class Interrogator():
similarity = text_features @ image_features.T similarity = text_features @ image_features.T
return similarity.T[0].tolist() return similarity.T[0].tolist()
def _prepare_blip(self):
if self.config.clip_offload and not self.clip_offloaded:
self.clip_model = self.clip_model.to('cpu')
self.clip_offloaded = True
if self.blip_offloaded:
self.blip_model = self.blip_model.to(self.device)
self.blip_offloaded = False
def _prepare_clip(self):
if self.config.blip_offload:
self.blip_model = self.blip_model.to('cpu')
self.blip_offloaded = True
if self.clip_offloaded:
self.clip_model = self.clip_model.to(self.device)
self.clip_offloaded = False
class LabelTable(): class LabelTable():
def __init__(self, labels:List[str], desc:str, clip_model, tokenize, config: Config): def __init__(self, labels:List[str], desc:str, clip_model, tokenize, config: Config):
@ -352,23 +349,25 @@ class LabelTable():
if self.config.cache_path is None or desc is None: if self.config.cache_path is None or desc is None:
return False return False
# load from old pkl format if it exists cached_safetensors = os.path.join(self.config.cache_path, f"{sanitized_name}_{desc}.safetensors")
cached_pkl = os.path.join(self.config.cache_path, f"{sanitized_name}_{desc}.pkl")
if os.path.exists(cached_pkl): if self.config.download_cache and not os.path.exists(cached_safetensors):
with open(cached_pkl, 'rb') as f: download_url = CACHE_URL_BASE + f"{sanitized_name}_{desc}.safetensors"
try: try:
data = pickle.load(f) os.makedirs(self.config.cache_path, exist_ok=True)
if data.get('hash') == hash: _download_file(download_url, cached_safetensors, quiet=self.config.quiet)
self.labels = data['labels']
self.embeds = data['embeds']
return True
except Exception as e: except Exception as e:
print(f"Error loading cached table {desc}: {e}") print(f"Failed to download {download_url}")
print(e)
return False
# load from new safetensors format if it exists
cached_safetensors = os.path.join(self.config.cache_path, f"{sanitized_name}_{desc}.safetensors")
if os.path.exists(cached_safetensors): if os.path.exists(cached_safetensors):
try:
tensors = load_file(cached_safetensors) tensors = load_file(cached_safetensors)
except Exception as e:
print(f"Failed to load {cached_safetensors}")
print(e)
return False
if 'hash' in tensors and 'embeds' in tensors: if 'hash' in tensors and 'embeds' in tensors:
if np.array_equal(tensors['hash'], np.array([ord(c) for c in hash], dtype=np.int8)): if np.array_equal(tensors['hash'], np.array([ord(c) for c in hash], dtype=np.int8)):
self.embeds = tensors['embeds'] self.embeds = tensors['embeds']
@ -378,7 +377,6 @@ class LabelTable():
return False return False
def _rank(self, image_features: torch.Tensor, text_embeds: torch.Tensor, top_count: int=1, reverse: bool=False) -> str: def _rank(self, image_features: torch.Tensor, text_embeds: torch.Tensor, top_count: int=1, reverse: bool=False) -> str:
top_count = min(top_count, len(text_embeds)) top_count = min(top_count, len(text_embeds))
text_embeds = torch.stack([torch.from_numpy(t) for t in text_embeds]).to(self.device) text_embeds = torch.stack([torch.from_numpy(t) for t in text_embeds]).to(self.device)
@ -409,8 +407,11 @@ class LabelTable():
return [top_labels[i] for i in tops] return [top_labels[i] for i in tops]
def _download_file(url: str, filepath: str, chunk_size: int = 64*1024, quiet: bool = False): def _download_file(url: str, filepath: str, chunk_size: int = 4*1024*1024, quiet: bool = False):
r = requests.get(url, stream=True) r = requests.get(url, stream=True)
if r.status_code != 200:
return
file_size = int(r.headers.get("Content-Length", 0)) file_size = int(r.headers.get("Content-Length", 0))
filename = url.split("/")[-1] filename = url.split("/")[-1]
progress = tqdm(total=file_size, unit="B", unit_scale=True, desc=filename, disable=quiet) progress = tqdm(total=file_size, unit="B", unit_scale=True, desc=filename, disable=quiet)

3
run_cli.py

@ -24,6 +24,7 @@ def main():
parser.add_argument('-f', '--folder', help='path to folder of images') parser.add_argument('-f', '--folder', help='path to folder of images')
parser.add_argument('-i', '--image', help='image file or url') parser.add_argument('-i', '--image', help='image file or url')
parser.add_argument('-m', '--mode', default='best', help='best, classic, or fast') parser.add_argument('-m', '--mode', default='best', help='best, classic, or fast')
parser.add_argument("--lowvram", action='store_true', help="Optimize settings for low VRAM")
args = parser.parse_args() args = parser.parse_args()
if not args.folder and not args.image: if not args.folder and not args.image:
@ -51,6 +52,8 @@ def main():
# generate a nice prompt # generate a nice prompt
config = Config(device=device, clip_model_name=args.clip) config = Config(device=device, clip_model_name=args.clip)
if args.lowvram:
config.apply_low_vram_defaults()
ci = Interrogator(config) ci = Interrogator(config)
# process single image # process single image

6
run_gradio.py

@ -11,13 +11,17 @@ except ImportError:
exit(1) exit(1)
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--lowvram", action='store_true', help="Optimize settings for low VRAM")
parser.add_argument('-s', '--share', action='store_true', help='Create a public link') parser.add_argument('-s', '--share', action='store_true', help='Create a public link')
args = parser.parse_args() args = parser.parse_args()
if not torch.cuda.is_available(): if not torch.cuda.is_available():
print("CUDA is not available, using CPU. Warning: this will be very slow!") print("CUDA is not available, using CPU. Warning: this will be very slow!")
ci = Interrogator(Config(cache_path="cache", clip_model_path="cache")) config = Config(cache_path="cache")
if args.lowvram:
config.apply_low_vram_defaults()
ci = Interrogator(config)
def image_analysis(image, clip_model_name): def image_analysis(image, clip_model_name):
if clip_model_name != ci.config.clip_model_name: if clip_model_name != ci.config.clip_model_name:

2
setup.py

@ -5,7 +5,7 @@ from setuptools import setup, find_packages
setup( setup(
name="clip-interrogator", name="clip-interrogator",
version="0.5.1", version="0.5.2",
license='MIT', license='MIT',
author='pharmapsychotic', author='pharmapsychotic',
author_email='me@pharmapsychotic.com', author_email='me@pharmapsychotic.com',

Loading…
Cancel
Save