
Lower VRAM usage by only having one model loaded at a time

This changes `Interrogator` to load only BLIP into VRAM on init and keep CLIP in
RAM until it's needed.

When `interrogate` is first called, it runs BLIP inference, unloads BLIP, loads
CLIP, then runs CLIP inference. 'Unloaded' in this case just means 'in RAM'.

Using this, I can run classic/fast interrogation on 4GB of VRAM; 'best' is still
a little too big, however.

This commit also includes automatic `black` formatting and extra type hints,
which can be removed if you want.
pull/46/head
bolshoytoster, 2 years ago
Commit 8cf9ff3990

README.md (54 changed lines)

@@ -1,9 +1,61 @@
# clip-interrogator
# clip-interrogator-with-less-VRAM
*Want to figure out what a good prompt might be to create new images like an existing one? The **CLIP Interrogator** is here to get you answers!*
This version uses less VRAM than the main repo by only having one model loaded at a time.
When you create an `Interrogator`:
```py
ci = Interrogator(Config())
```
The BLIP and CLIP models are both loaded, but only BLIP is on the GPU; CLIP stays in RAM.
When you actually do inference:
```py
ci.interrogate(image)
# Or:
# ci.interrogate_classic(image)
# ci.interrogate_fast(image)
```
BLIP inference runs first, then BLIP is unloaded, CLIP is loaded, and CLIP inference runs.
If you run it again, CLIP inference runs first and BLIP is loaded afterwards, which avoids pointless loading and unloading.
Because of this, `classic` and `fast` interrogation can run on as little as 4GB of VRAM (`best` doesn't quite fit), whereas the main repo needs at least 6GB.
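Roughly, this is what each call does under the hood (a simplified sketch of the `_first_bit` helper added in this commit; `blip_loaded` tracks which model currently sits on the GPU):
```py
def _first_bit(self, image):
    if self.blip_loaded:
        caption = self.generate_caption(image)  # BLIP inference on the GPU
        self.blip_model.to("cpu")               # "unload" BLIP into RAM
        self.clip_model.to(self.device)         # bring CLIP onto the GPU
        image_features = self.image_to_features(image)
    else:  # CLIP is already on the GPU
        image_features = self.image_to_features(image)
        self.clip_model.to("cpu")
        self.blip_model.to(self.device)
        caption = self.generate_caption(image)
    self.blip_loaded ^= True  # remember which model ended up loaded
    return caption, image_features
```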
> But wouldn't loading a new model every time I want to interrogate an image be terrible for performance?
\- me
Absolutely.
There's little performance overhead for just one interrogation, since it's essentially lazy loading the CLIP model, but for multiple images there will be a noticeable effect.
That's why I made the `interrogate_batch` functions:
```py
# files = Some list of strings
images = [Image.open(f).convert("RGB") for f in files]
ci.interrogate_batch(images)
```
This does BLIP inference on each of the images, *then* loads the CLIP model, so the models are only swapped once for the whole batch.
There are also `interrogate_{classic,fast}_batch` functions.
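For example, assuming `images` is the list built above:
```py
classic_prompts = ci.interrogate_classic_batch(images)
fast_prompts = ci.interrogate_fast_batch(images)
```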
## Run it!
Bash (Linux/Unix):
```sh
$ ./run_cli.py -i input.png -m $MODE
```
Windows:
```cmd
python run_cli.py -i input.png -m $MODE
```
Where `$MODE` is either `best`, `classic` or `fast` (default: `best`).
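You can also point it at a folder of `.jpg`/`.png` images with `-f`; the CLI now interrogates them as a batch and writes the prompts to a `desc.csv` in that folder:
```sh
$ ./run_cli.py -f path/to/images -m fast
```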
Run Version 2 on Colab, HuggingFace, and Replicate!
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pharmapsychotic/clip-interrogator/blob/main/clip_interrogator.ipynb) [![Generic badge](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue.svg)](https://huggingface.co/spaces/pharma/CLIP-Interrogator) [![Replicate](https://replicate.com/pharmapsychotic/clip-interrogator/badge)](https://replicate.com/pharmapsychotic/clip-interrogator)

clip_interrogator/clip_interrogator.py (503 changed lines)

@@ -18,24 +18,24 @@ from tqdm import tqdm
from typing import List
BLIP_MODELS = {
'base': 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth',
'large': 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
"base": "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth",
"large": "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth",
}
CACHE_URLS_VITL = [
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_artists.pkl',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_flavors.pkl',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_mediums.pkl',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_movements.pkl',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_trendings.pkl',
"https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_artists.pkl",
"https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_flavors.pkl",
"https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_mediums.pkl",
"https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_movements.pkl",
"https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_trendings.pkl",
]
CACHE_URLS_VITH = [
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_artists.pkl',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_flavors.pkl',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_mediums.pkl',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_movements.pkl',
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_trendings.pkl',
"https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_artists.pkl",
"https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_flavors.pkl",
"https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_mediums.pkl",
"https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_movements.pkl",
"https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_trendings.pkl",
]
@@ -49,40 +49,51 @@ class Config:
# blip settings
blip_image_eval_size: int = 384
blip_max_length: int = 32
blip_model_type: str = 'large' # choose between 'base' or 'large'
blip_model_type: str = "large" # choose between 'base' or 'large'
blip_num_beams: int = 8
blip_offload: bool = False
# clip settings
clip_model_name: str = 'ViT-L-14/openai'
clip_model_name: str = "ViT-L-14/openai"
clip_model_path: str = None
# interrogator settings
cache_path: str = 'cache' # path to store cached text embeddings
download_cache: bool = True # when true, cached embeds are downloaded from huggingface
chunk_size: int = 2048 # batch size for CLIP, use smaller for lower VRAM
data_path: str = os.path.join(os.path.dirname(__file__), 'data')
device: str = ("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
cache_path: str = "cache" # path to store cached text embeddings
download_cache: bool = (
True # when true, cached embeds are downloaded from huggingface
)
chunk_size: int = 2048 # batch size for CLIP, use smaller for lower VRAM
data_path: str = os.path.join(os.path.dirname(__file__), "data")
device: str = (
"mps"
if torch.backends.mps.is_available()
else "cuda"
if torch.cuda.is_available()
else "cpu"
)
flavor_intermediate_count: int = 2048
quiet: bool = False # when quiet progress bars are not shown
quiet: bool = False # when quiet progress bars are not shown
class Interrogator():
class Interrogator:
def __init__(self, config: Config):
self.config = config
self.device = config.device
# Record which model is on the target device
self.blip_loaded = True
# Load BLIP model (to intended device)
if config.blip_model is None:
if not config.quiet:
print("Loading BLIP model...")
blip_path = os.path.dirname(inspect.getfile(blip_decoder))
configs_path = os.path.join(os.path.dirname(blip_path), 'configs')
med_config = os.path.join(configs_path, 'med_config.json')
configs_path = os.path.join(os.path.dirname(blip_path), "configs")
med_config = os.path.join(configs_path, "med_config.json")
blip_model = blip_decoder(
pretrained=BLIP_MODELS[config.blip_model_type],
image_size=config.blip_image_eval_size,
vit=config.blip_model_type,
med_config=med_config
med_config=med_config,
)
blip_model.eval()
blip_model = blip_model.to(config.device)
@@ -90,12 +101,13 @@ class Interrogator():
else:
self.blip_model = config.blip_model
# Load CLIP (to CPU)
self.load_clip_model()
def download_cache(self, clip_model_name: str):
if clip_model_name == 'ViT-L-14/openai':
if clip_model_name == "ViT-L-14/openai":
cache_urls = CACHE_URLS_VITL
elif clip_model_name == 'ViT-H-14/laion2b_s32b_b79k':
elif clip_model_name == "ViT-H-14/laion2b_s32b_b79k":
cache_urls = CACHE_URLS_VITH
else:
# text embeddings will be precomputed and cached locally
@@ -103,7 +115,7 @@ class Interrogator():
os.makedirs(self.config.cache_path, exist_ok=True)
for url in cache_urls:
filepath = os.path.join(self.config.cache_path, url.split('/')[-1])
filepath = os.path.join(self.config.cache_path, url.split("/")[-1])
if not os.path.exists(filepath):
_download_file(url, filepath, quiet=self.config.quiet)
@@ -115,40 +127,93 @@ class Interrogator():
if not config.quiet:
print("Loading CLIP model...")
clip_model_name, clip_model_pretrained_name = config.clip_model_name.split('/', 2)
self.clip_model, _, self.clip_preprocess = open_clip.create_model_and_transforms(
clip_model_name, clip_model_pretrained_name = config.clip_model_name.split(
"/", 2
)
(
self.clip_model,
_,
self.clip_preprocess,
) = open_clip.create_model_and_transforms(
clip_model_name,
pretrained=clip_model_pretrained_name,
precision='fp16' if config.device == 'cuda' else 'fp32',
device=config.device,
precision="fp16" if config.device == "cuda" else "fp32",
device="cpu",
jit=False,
cache_dir=config.clip_model_path
cache_dir=config.clip_model_path,
)
self.clip_model.to(config.device).eval()
self.clip_model.eval()
else:
self.clip_model = config.clip_model
self.clip_preprocess = config.clip_preprocess
self.tokenize = open_clip.get_tokenizer(clip_model_name)
sites = ['Artstation', 'behance', 'cg society', 'cgsociety', 'deviantart', 'dribble', 'flickr', 'instagram', 'pexels', 'pinterest', 'pixabay', 'pixiv', 'polycount', 'reddit', 'shutterstock', 'tumblr', 'unsplash', 'zbrush central']
sites = [
"Artstation",
"behance",
"cg society",
"cgsociety",
"deviantart",
"dribble",
"flickr",
"instagram",
"pexels",
"pinterest",
"pixabay",
"pixiv",
"polycount",
"reddit",
"shutterstock",
"tumblr",
"unsplash",
"zbrush central",
]
trending_list = [site for site in sites]
trending_list.extend(["trending on "+site for site in sites])
trending_list.extend(["featured on "+site for site in sites])
trending_list.extend([site+" contest winner" for site in sites])
trending_list.extend(["trending on " + site for site in sites])
trending_list.extend(["featured on " + site for site in sites])
trending_list.extend([site + " contest winner" for site in sites])
raw_artists = _load_list(config.data_path, 'artists.txt')
raw_artists = _load_list(config.data_path, "artists.txt")
artists = [f"by {a}" for a in raw_artists]
artists.extend([f"inspired by {a}" for a in raw_artists])
if config.download_cache:
self.download_cache(config.clip_model_name)
self.artists = LabelTable(artists, "artists", self.clip_model, self.tokenize, config)
self.flavors = LabelTable(_load_list(config.data_path, 'flavors.txt'), "flavors", self.clip_model, self.tokenize, config)
self.mediums = LabelTable(_load_list(config.data_path, 'mediums.txt'), "mediums", self.clip_model, self.tokenize, config)
self.movements = LabelTable(_load_list(config.data_path, 'movements.txt'), "movements", self.clip_model, self.tokenize, config)
self.trendings = LabelTable(trending_list, "trendings", self.clip_model, self.tokenize, config)
self.negative = LabelTable(_load_list(config.data_path, 'negative.txt'), "negative", self.clip_model, self.tokenize, config)
self.artists = LabelTable(
artists, "artists", self.clip_model, self.tokenize, config
)
self.flavors = LabelTable(
_load_list(config.data_path, "flavors.txt"),
"flavors",
self.clip_model,
self.tokenize,
config,
)
self.mediums = LabelTable(
_load_list(config.data_path, "mediums.txt"),
"mediums",
self.clip_model,
self.tokenize,
config,
)
self.movements = LabelTable(
_load_list(config.data_path, "movements.txt"),
"movements",
self.clip_model,
self.tokenize,
config,
)
self.trendings = LabelTable(
trending_list, "trendings", self.clip_model, self.tokenize, config
)
self.negative = LabelTable(
_load_list(config.data_path, "negative.txt"),
"negative",
self.clip_model,
self.tokenize,
config,
)
end_time = time.time()
if not config.quiet:
@@ -158,16 +223,18 @@ class Interrogator():
self,
image_features: torch.Tensor,
phrases: List[str],
best_prompt: str="",
best_sim: float=0,
min_count: int=8,
max_count: int=32,
best_prompt: str = "",
best_sim: float = 0,
min_count: int = 8,
max_count: int = 32,
desc="Chaining",
reverse: bool=False
reverse: bool = False,
) -> str:
phrases = set(phrases)
if not best_prompt:
best_prompt = self.rank_top(image_features, [f for f in phrases], reverse=reverse)
best_prompt = self.rank_top(
image_features, [f for f in phrases], reverse=reverse
)
best_sim = self.similarity(image_features, best_prompt)
phrases.remove(best_prompt)
curr_prompt, curr_sim = best_prompt, best_sim
@@ -187,8 +254,12 @@ class Interrogator():
return False
for idx in tqdm(range(max_count), desc=desc, disable=self.config.quiet):
best = self.rank_top(image_features, [f"{curr_prompt}, {f}" for f in phrases], reverse=reverse)
flave = best[len(curr_prompt)+2:]
best = self.rank_top(
image_features,
[f"{curr_prompt}, {f}" for f in phrases],
reverse=reverse,
)
flave = best[len(curr_prompt) + 2 :]
if not check(flave, idx):
break
if _prompt_at_max_len(curr_prompt, self.tokenize):
@@ -201,11 +272,22 @@ class Interrogator():
if self.config.blip_offload:
self.blip_model = self.blip_model.to(self.device)
size = self.config.blip_image_eval_size
gpu_image = transforms.Compose([
transforms.Resize((size, size), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
])(pil_image).unsqueeze(0).to(self.device)
gpu_image = (
transforms.Compose(
[
transforms.Resize(
(size, size), interpolation=InterpolationMode.BICUBIC
),
transforms.ToTensor(),
transforms.Normalize(
(0.48145466, 0.4578275, 0.40821073),
(0.26862954, 0.26130258, 0.27577711),
),
]
)(pil_image)
.unsqueeze(0)
.to(self.device)
)
with torch.no_grad():
caption = self.blip_model.generate(
@@ -213,7 +295,7 @@ class Interrogator():
sample=False,
num_beams=self.config.blip_num_beams,
max_length=self.config.blip_max_length,
min_length=5
min_length=5,
)
if self.config.blip_offload:
self.blip_model = self.blip_model.to("cpu")
@@ -226,17 +308,65 @@ class Interrogator():
image_features /= image_features.norm(dim=-1, keepdim=True)
return image_features
def interrogate_classic(self, image: Image, max_flavors: int=3) -> str:
"""Classic mode creates a prompt in a standard format first describing the image,
then listing the artist, trending, movement, and flavor text modifiers."""
caption = self.generate_caption(image)
image_features = self.image_to_features(image)
def _first_bit(self, image: Image) -> tuple[str, torch.Tensor]:
if self.blip_loaded:
caption = self.generate_caption(image)
# Move BLIP to RAM
self.blip_model.to("cpu")
# Move CLIP to intended device
self.clip_model.to(self.device)
image_features = self.image_to_features(image)
else: # CLIP is loaded
image_features = self.image_to_features(image)
# Move CLIP to RAM
self.clip_model.to("cpu")
# Move BLIP to intended device
self.blip_model.to(self.device)
caption = self.generate_caption(image)
# Toggle `blip_loaded`
self.blip_loaded ^= True
return caption, image_features
def _first_bit_batch(self, images: list[Image]) -> tuple[list[str], list[torch.Tensor]]:
image_features: list[torch.Tensor] = []
if self.blip_loaded:
captions = [self.generate_caption(img) for img in images]
# Move BLIP to RAM
self.blip_model.to("cpu")
# Move CLIP to intended device
self.clip_model.to(self.device)
image_features = [self.image_to_features(img) for img in images]
else: # CLIP is loaded
image_features = [self.image_to_features(img) for img in images]
# Move CLIP to RAM
self.clip_model.to("cpu")
# Move BLIP to intended device
self.blip_model.to(self.device)
captions = [self.generate_caption(img) for img in images]
# Toggle `blip_loaded`
self.blip_loaded ^= True
return captions, image_features
def _interrogate_classic(
self, caption: str, image_features: torch.Tensor, max_flavours: int = 3
) -> str:
medium = self.mediums.rank(image_features, 1)[0]
artist = self.artists.rank(image_features, 1)[0]
trending = self.trendings.rank(image_features, 1)[0]
movement = self.movements.rank(image_features, 1)[0]
flaves = ", ".join(self.flavors.rank(image_features, max_flavors))
flaves = ", ".join(self.flavors.rank(image_features, max_flavours))
if caption.startswith(medium):
prompt = f"{caption} {artist}, {trending}, {movement}, {flaves}"
@@ -245,41 +375,138 @@ class Interrogator():
return _truncate_to_fit(prompt, self.tokenize)
def interrogate_classic(self, image: Image, max_flavors: int = 3) -> str:
"""Classic mode creates a prompt in a standard format first describing the image,
then listing the artist, trending, movement, and flavor text modifiers."""
caption, image_features = self._first_bit(image)
return self._interrogate_classic(caption, image_features, max_flavors)
def interrogate_classic_batch(
self, images: list[Image], max_flavors: int = 3
) -> list[str]:
"""Classic mode creates a prompt in a standard format first describing the image,
then listing the artist, trending, movement, and flavor text modifiers.
This function interrogates a batch of images (more efficient than doing
it individually)."""
captions, image_features = self._first_bit_batch(images)
returns: list[str] = [
self._interrogate_classic(caption, feature, max_flavors)
for caption, feature in zip(captions, image_features)
]
return returns
def _interrogate_fast(
self, caption: str, image_features: torch.Tensor, max_flavours: int = 32
) -> str:
merged = _merge_tables(
[self.artists, self.flavors, self.mediums, self.movements, self.trendings],
self.config,
)
tops = merged.rank(image_features, max_flavours)
return _truncate_to_fit(caption + ", " + ", ".join(tops), self.tokenize)
def interrogate_fast(self, image: Image, max_flavors: int = 32) -> str:
"""Fast mode simply adds the top ranked terms after a caption. It generally results in
better similarity between generated prompt and image than classic mode, but the prompts
are less readable."""
caption = self.generate_caption(image)
image_features = self.image_to_features(image)
merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config)
tops = merged.rank(image_features, max_flavors)
return _truncate_to_fit(caption + ", " + ", ".join(tops), self.tokenize)
caption, image_features = self._first_bit(image)
return self._interrogate_fast(caption, image_features, max_flavors)
def interrogate_fast_batch(self, images: list[Image], max_flavors: int = 32) -> list[str]:
"""Fast mode simply adds the top ranked terms after a caption. It generally results in
better similarity between generated prompt and image than classic mode, but the prompts
are less readable.
This function interrogates a batch of images (more efficient than doing
it individually)."""
captions, image_features = self._first_bit_batch(images)
returns: list[str] = [
self._interrogate_fast(caption, feature, max_flavors)
for caption, feature in zip(captions, image_features)
]
return returns
def interrogate_negative(self, image: Image, max_flavors: int = 32) -> str:
"""Negative mode chains together the most dissimilar terms to the image. It can be used
to help build a negative prompt to pair with the regular positive prompt and often
improve the results of generated images particularly with Stable Diffusion 2."""
image_features = self.image_to_features(image)
flaves = self.flavors.rank(image_features, self.config.flavor_intermediate_count, reverse=True)
flaves = flaves + self.negative.labels
return self.chain(image_features, flaves, max_count=max_flavors, reverse=True, desc="Negative chain")
if self.blip_loaded: # Move BLIP to RAM and CLIP to intended device
self.blip_model.to("cpu")
self.clip_model.to(self.device)
self.blip_loaded = False # CLIP is now the loaded model
def interrogate(self, image: Image, min_flavors: int=8, max_flavors: int=32) -> str:
caption = self.generate_caption(image)
image_features = self.image_to_features(image)
merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config)
flaves = self.flavors.rank(
image_features, self.config.flavor_intermediate_count, reverse=True
)
flaves = flaves + self.negative.labels
return self.chain(
image_features,
flaves,
max_count=max_flavors,
reverse=True,
desc="Negative chain",
)
def _interrogate(
self,
caption: str,
image_features: torch.Tensor,
min_flavours: int = 8,
max_flavours: int = 32,
) -> str:
merged = _merge_tables(
[self.artists, self.flavors, self.mediums, self.movements, self.trendings],
self.config,
)
flaves = merged.rank(image_features, self.config.flavor_intermediate_count)
best_prompt, best_sim = caption, self.similarity(image_features, caption)
best_prompt = self.chain(image_features, flaves, best_prompt, best_sim, min_count=min_flavors, max_count=max_flavors, desc="Flavor chain")
fast_prompt = self.interrogate_fast(image, max_flavors)
classic_prompt = self.interrogate_classic(image, max_flavors)
best_prompt = self.chain(
image_features,
flaves,
best_prompt,
best_sim,
min_count=min_flavours,
max_count=max_flavours,
desc="Flavor chain",
)
fast_prompt = self._interrogate_fast(caption, image_features, max_flavours)
classic_prompt = self._interrogate_classic(caption, image_features, max_flavours)
candidates = [caption, classic_prompt, fast_prompt, best_prompt]
return candidates[np.argmax(self.similarities(image_features, candidates))]
def rank_top(self, image_features: torch.Tensor, text_array: List[str], reverse: bool=False) -> str:
def interrogate(
self, image: Image, min_flavors: int = 8, max_flavors: int = 32
) -> str:
caption, image_features = self._first_bit(image)
return self._interrogate(caption, image_features, min_flavors, max_flavors)
def interrogate_batch(
self, images: list[Image], min_flavors: int = 8, max_flavors: int = 32
) -> list[str]:
"""This function interrogates a batch of images (more efficient than doing
it individually)."""
captions, image_features = self._first_bit_batch(images)
returns: list[str] = [
self._interrogate(caption, features, min_flavors, max_flavors)
for caption, features in zip(captions, image_features)
]
return returns
def rank_top(
self, image_features: torch.Tensor, text_array: List[str], reverse: bool = False
) -> str:
text_tokens = self.tokenize([text for text in text_array]).to(self.device)
with torch.no_grad(), torch.cuda.amp.autocast():
text_features = self.clip_model.encode_text(text_tokens)
@@ -297,7 +524,9 @@ class Interrogator():
similarity = text_features @ image_features.T
return similarity[0][0].item()
def similarities(self, image_features: torch.Tensor, text_array: List[str]) -> List[float]:
def similarities(
self, image_features: torch.Tensor, text_array: List[str]
) -> List[float]:
text_tokens = self.tokenize([text for text in text_array]).to(self.device)
with torch.no_grad(), torch.cuda.amp.autocast():
text_features = self.clip_model.encode_text(text_tokens)
@@ -306,8 +535,10 @@ class Interrogator():
return similarity.T[0].tolist()
class LabelTable():
def __init__(self, labels:List[str], desc:str, clip_model, tokenize, config: Config):
class LabelTable:
def __init__(
self, labels: List[str], desc: str, clip_model, tokenize, config: Config
):
self.chunk_size = config.chunk_size
self.config = config
self.device = config.device
@@ -320,22 +551,30 @@ class LabelTable():
cache_filepath = None
if config.cache_path is not None and desc is not None:
os.makedirs(config.cache_path, exist_ok=True)
sanitized_name = config.clip_model_name.replace('/', '_').replace('@', '_')
cache_filepath = os.path.join(config.cache_path, f"{sanitized_name}_{desc}.pkl")
sanitized_name = config.clip_model_name.replace("/", "_").replace("@", "_")
cache_filepath = os.path.join(
config.cache_path, f"{sanitized_name}_{desc}.pkl"
)
if desc is not None and os.path.exists(cache_filepath):
with open(cache_filepath, 'rb') as f:
with open(cache_filepath, "rb") as f:
try:
data = pickle.load(f)
if data.get('hash') == hash:
self.labels = data['labels']
self.embeds = data['embeds']
if data.get("hash") == hash:
self.labels = data["labels"]
self.embeds = data["embeds"]
except Exception as e:
print(f"Error loading cached table {desc}: {e}")
if len(self.labels) != len(self.embeds):
self.embeds = []
chunks = np.array_split(self.labels, max(1, len(self.labels)/config.chunk_size))
for chunk in tqdm(chunks, desc=f"Preprocessing {desc}" if desc else None, disable=self.config.quiet):
chunks = np.array_split(
self.labels, max(1, len(self.labels) / config.chunk_size)
)
for chunk in tqdm(
chunks,
desc=f"Preprocessing {desc}" if desc else None,
disable=self.config.quiet,
):
text_tokens = self.tokenize(chunk).to(self.device)
with torch.no_grad(), torch.cuda.amp.autocast():
text_features = clip_model.encode_text(text_tokens)
@@ -345,20 +584,31 @@ class LabelTable():
self.embeds.append(text_features[i])
if cache_filepath is not None:
with open(cache_filepath, 'wb') as f:
pickle.dump({
"labels": self.labels,
"embeds": self.embeds,
"hash": hash,
"model": config.clip_model_name
}, f)
if self.device == 'cpu' or self.device == torch.device('cpu'):
with open(cache_filepath, "wb") as f:
pickle.dump(
{
"labels": self.labels,
"embeds": self.embeds,
"hash": hash,
"model": config.clip_model_name,
},
f,
)
if self.device == "cpu" or self.device == torch.device("cpu"):
self.embeds = [e.astype(np.float32) for e in self.embeds]
def _rank(self, image_features: torch.Tensor, text_embeds: torch.Tensor, top_count: int=1, reverse: bool=False) -> str:
def _rank(
self,
image_features: torch.Tensor,
text_embeds: torch.Tensor,
top_count: int = 1,
reverse: bool = False,
) -> str:
top_count = min(top_count, len(text_embeds))
text_embeds = torch.stack([torch.from_numpy(t) for t in text_embeds]).to(self.device)
text_embeds = torch.stack([torch.from_numpy(t) for t in text_embeds]).to(
self.device
)
with torch.cuda.amp.autocast():
similarity = image_features @ text_embeds.T
if reverse:
@@ -366,31 +616,44 @@ class LabelTable():
_, top_labels = similarity.float().cpu().topk(top_count, dim=-1)
return [top_labels[0][i].numpy() for i in range(top_count)]
def rank(self, image_features: torch.Tensor, top_count: int=1, reverse: bool=False) -> List[str]:
def rank(
self, image_features: torch.Tensor, top_count: int = 1, reverse: bool = False
) -> List[str]:
if len(self.labels) <= self.chunk_size:
tops = self._rank(image_features, self.embeds, top_count=top_count, reverse=reverse)
tops = self._rank(
image_features, self.embeds, top_count=top_count, reverse=reverse
)
return [self.labels[i] for i in tops]
num_chunks = int(math.ceil(len(self.labels)/self.chunk_size))
num_chunks = int(math.ceil(len(self.labels) / self.chunk_size))
keep_per_chunk = int(self.chunk_size / num_chunks)
top_labels, top_embeds = [], []
for chunk_idx in tqdm(range(num_chunks), disable=self.config.quiet):
start = chunk_idx*self.chunk_size
stop = min(start+self.chunk_size, len(self.embeds))
tops = self._rank(image_features, self.embeds[start:stop], top_count=keep_per_chunk, reverse=reverse)
top_labels.extend([self.labels[start+i] for i in tops])
top_embeds.extend([self.embeds[start+i] for i in tops])
start = chunk_idx * self.chunk_size
stop = min(start + self.chunk_size, len(self.embeds))
tops = self._rank(
image_features,
self.embeds[start:stop],
top_count=keep_per_chunk,
reverse=reverse,
)
top_labels.extend([self.labels[start + i] for i in tops])
top_embeds.extend([self.embeds[start + i] for i in tops])
tops = self._rank(image_features, top_embeds, top_count=top_count)
return [top_labels[i] for i in tops]
def _download_file(url: str, filepath: str, chunk_size: int = 64*1024, quiet: bool = False):
def _download_file(
url: str, filepath: str, chunk_size: int = 64 * 1024, quiet: bool = False
):
r = requests.get(url, stream=True)
file_size = int(r.headers.get("Content-Length", 0))
filename = url.split("/")[-1]
progress = tqdm(total=file_size, unit="B", unit_scale=True, desc=filename, disable=quiet)
progress = tqdm(
total=file_size, unit="B", unit_scale=True, desc=filename, disable=quiet
)
with open(filepath, "wb") as f:
for chunk in r.iter_content(chunk_size=chunk_size):
if chunk:
@@ -398,11 +661,15 @@ def _download_file(url: str, filepath: str, chunk_size: int = 64*1024, quiet: bool = False):
progress.update(len(chunk))
progress.close()
def _load_list(data_path: str, filename: str) -> List[str]:
with open(os.path.join(data_path, filename), 'r', encoding='utf-8', errors='replace') as f:
with open(
os.path.join(data_path, filename), "r", encoding="utf-8", errors="replace"
) as f:
items = [line.strip() for line in f.readlines()]
return items
def _merge_tables(tables: List[LabelTable], config: Config) -> LabelTable:
m = LabelTable([], None, None, None, config)
for table in tables:
@@ -410,15 +677,17 @@ def _merge_tables(tables: List[LabelTable], config: Config) -> LabelTable:
m.embeds.extend(table.embeds)
return m
def _prompt_at_max_len(text: str, tokenize) -> bool:
tokens = tokenize([text])
return tokens[0][-1] != 0
def _truncate_to_fit(text: str, tokenize) -> str:
parts = text.split(', ')
parts = text.split(", ")
new_text = parts[0]
for part in parts[1:]:
if _prompt_at_max_len(new_text + part, tokenize):
break
new_text += ', ' + part
new_text += ", " + part
return new_text

run_cli.py (75 changed lines)

@@ -8,22 +8,37 @@ import torch
from PIL import Image
from clip_interrogator import Interrogator, Config
def inference(ci, image, mode):
image = image.convert('RGB')
if mode == 'best':
def inference(ci: Interrogator, image: Image, mode: str) -> str:
image = image.convert("RGB")
if mode == "best":
return ci.interrogate(image)
elif mode == 'classic':
elif mode == "classic":
return ci.interrogate_classic(image)
else:
return ci.interrogate_fast(image)
def inference_batch(ci: Interrogator, images: list[Image], mode: str) -> list[str]:
if mode == "best":
return ci.interrogate_batch(images)
elif mode == "classic":
return ci.interrogate_classic_batch(images)
else:
return ci.interrogate_fast_batch(images)
def main():
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--clip', default='ViT-L-14/openai', help='name of CLIP model to use')
parser.add_argument('-d', '--device', default='auto', help='device to use (auto, cuda or cpu)')
parser.add_argument('-f', '--folder', help='path to folder of images')
parser.add_argument('-i', '--image', help='image file or url')
parser.add_argument('-m', '--mode', default='best', help='best, classic, or fast')
parser.add_argument(
"-c", "--clip", default="ViT-L-14/openai", help="name of CLIP model to use"
)
parser.add_argument(
"-d", "--device", default="auto", help="device to use (auto, cuda or cpu)"
)
parser.add_argument("-f", "--folder", help="path to folder of images")
parser.add_argument("-i", "--image", help="image file or url")
parser.add_argument("-m", "--mode", default="best", help="best, classic, or fast")
args = parser.parse_args()
if not args.folder and not args.image:
@@ -35,15 +50,15 @@ def main():
exit(1)
# validate clip model name
models = ['/'.join(x) for x in open_clip.list_pretrained()]
models = ["/".join(x) for x in open_clip.list_pretrained()]
if args.clip not in models:
print(f"Could not find CLIP model {args.clip}!")
print(f" available models: {models}")
exit(1)
# select device
if args.device == 'auto':
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if args.device == "auto":
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if not torch.cuda.is_available():
print("CUDA is not available, using CPU. Warning: this will be very slow!")
else:
@@ -56,38 +71,46 @@ def main():
# process single image
if args.image is not None:
image_path = args.image
if str(image_path).startswith('http://') or str(image_path).startswith('https://'):
image = Image.open(requests.get(image_path, stream=True).raw).convert('RGB')
if str(image_path).startswith("http://") or str(image_path).startswith(
"https://"
):
image = Image.open(requests.get(image_path, stream=True).raw).convert("RGB")
else:
image = Image.open(image_path).convert('RGB')
image = Image.open(image_path).convert("RGB")
if not image:
print(f'Error opening image {image_path}')
print(f"Error opening image {image_path}")
exit(1)
print(inference(ci, image, args.mode))
# process folder of images
elif args.folder is not None:
if not os.path.exists(args.folder):
print(f'The folder {args.folder} does not exist!')
print(f"The folder {args.folder} does not exist!")
exit(1)
files = [f for f in os.listdir(args.folder) if f.endswith('.jpg') or f.endswith('.png')]
prompts = []
for file in files:
image = Image.open(os.path.join(args.folder, file)).convert('RGB')
prompt = inference(ci, image, args.mode)
prompts.append(prompt)
files = [
f
for f in os.listdir(args.folder)
if f.endswith(".jpg") or f.endswith(".png")
]
prompts = inference_batch(
ci,
[Image.open(os.path.join(args.folder, f)).convert("RGB") for f in files],
args.mode,
)
for prompt in prompts:
print(prompt)
if len(prompts):
csv_path = os.path.join(args.folder, 'desc.csv')
with open(csv_path, 'w', encoding='utf-8', newline='') as f:
csv_path = os.path.join(args.folder, "desc.csv")
with open(csv_path, "w", encoding="utf-8", newline="") as f:
w = csv.writer(f, quoting=csv.QUOTE_MINIMAL)
w.writerow(['image', 'prompt'])
w.writerow(["image", "prompt"])
for file, prompt in zip(files, prompts):
w.writerow([file, prompt])
print(f"\n\n\n\nGenerated {len(prompts)} and saved to {csv_path}, enjoy!")
if __name__ == "__main__":
main()
