Browse Source

0.4.2:

- upgrade chain to take a min_count parameter so it won't early out until it has considered at least min_count flavors
- interrogate method ("best" mode) also checks against classic and fast to use their output if it's better
- fix bug of config.download_cache option not being used!
- add notes on Config object to readme
pull/43/head
pharmapsychotic 2 years ago
parent
commit
78287e17e1
  1. 12
      README.md
  2. 2
      clip_interrogator/__init__.py
  3. 41
      clip_interrogator/clip_interrogator.py
  4. 2
      setup.py

12
README.md

@ -36,7 +36,7 @@ Install with PIP
pip3 install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu117 pip3 install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu117
# install clip-interrogator # install clip-interrogator
pip install clip-interrogator==0.4.1 pip install clip-interrogator==0.4.2
``` ```
You can then use it in your script You can then use it in your script
@ -51,3 +51,13 @@ print(ci.interrogate(image))
CLIP Interrogator uses OpenCLIP which supports many different pretrained CLIP models. For the best prompts for CLIP Interrogator uses OpenCLIP which supports many different pretrained CLIP models. For the best prompts for
Stable Diffusion 1.X use `ViT-L-14/openai` for clip_model_name. For Stable Diffusion 2.0 use `ViT-H-14/laion2b_s32b_b79k` Stable Diffusion 1.X use `ViT-L-14/openai` for clip_model_name. For Stable Diffusion 2.0 use `ViT-H-14/laion2b_s32b_b79k`
## Configuration
The `Config` object lets you configure CLIP Interrogator's processing.
* `clip_model_name`: which of the OpenCLIP pretrained CLIP models to use
* `cache_path`: path where to save precomputed text embeddings
* `download_cache`: when True will download the precomputed embeddings from huggingface
* `chunk_size`: batch size for CLIP, use smaller for lower VRAM
* `quiet`: when True no progress bars or text output will be displayed
See the [run_cli.py](https://github.com/pharmapsychotic/clip-interrogator/blob/main/run_cli.py) and [run_gradio.py](https://github.com/pharmapsychotic/clip-interrogator/blob/main/run_gradio.py) for more examples on using Config and Interrogator classes.

2
clip_interrogator/__init__.py

@ -1,4 +1,4 @@
from .clip_interrogator import Interrogator, Config from .clip_interrogator import Interrogator, Config
__version__ = '0.4.1' __version__ = '0.4.2'
__author__ = 'pharmapsychotic' __author__ = 'pharmapsychotic'

41
clip_interrogator/clip_interrogator.py

@ -140,7 +140,8 @@ class Interrogator():
artists = [f"by {a}" for a in raw_artists] artists = [f"by {a}" for a in raw_artists]
artists.extend([f"inspired by {a}" for a in raw_artists]) artists.extend([f"inspired by {a}" for a in raw_artists])
self.download_cache(config.clip_model_name) if config.download_cache:
self.download_cache(config.clip_model_name)
self.artists = LabelTable(artists, "artists", self.clip_model, self.tokenize, config) self.artists = LabelTable(artists, "artists", self.clip_model, self.tokenize, config)
self.flavors = LabelTable(_load_list(config.data_path, 'flavors.txt'), "flavors", self.clip_model, self.tokenize, config) self.flavors = LabelTable(_load_list(config.data_path, 'flavors.txt'), "flavors", self.clip_model, self.tokenize, config)
@ -159,6 +160,7 @@ class Interrogator():
phrases: List[str], phrases: List[str],
best_prompt: str="", best_prompt: str="",
best_sim: float=0, best_sim: float=0,
min_count: int=8,
max_count: int=32, max_count: int=32,
desc="Chaining", desc="Chaining",
reverse: bool=False reverse: bool=False
@ -168,25 +170,28 @@ class Interrogator():
best_prompt = self.rank_top(image_features, [f for f in phrases], reverse=reverse) best_prompt = self.rank_top(image_features, [f for f in phrases], reverse=reverse)
best_sim = self.similarity(image_features, best_prompt) best_sim = self.similarity(image_features, best_prompt)
phrases.remove(best_prompt) phrases.remove(best_prompt)
curr_prompt, curr_sim = best_prompt, best_sim
def check(addition: str) -> bool:
nonlocal best_prompt, best_sim def check(addition: str, idx: int) -> bool:
prompt = best_prompt + ", " + addition nonlocal best_prompt, best_sim, curr_prompt, curr_sim
prompt = curr_prompt + ", " + addition
sim = self.similarity(image_features, prompt) sim = self.similarity(image_features, prompt)
if reverse: if reverse:
sim = -sim sim = -sim
if sim > best_sim: if sim > best_sim:
best_sim = sim best_prompt, best_sim = prompt, sim
best_prompt = prompt if sim > curr_sim or idx < min_count:
curr_prompt, curr_sim = prompt, sim
return True return True
return False return False
for _ in tqdm(range(max_count), desc=desc, disable=self.config.quiet): for idx in tqdm(range(max_count), desc=desc, disable=self.config.quiet):
best = self.rank_top(image_features, [f"{best_prompt}, {f}" for f in phrases], reverse=reverse) best = self.rank_top(image_features, [f"{curr_prompt}, {f}" for f in phrases], reverse=reverse)
flave = best[len(best_prompt)+2:] flave = best[len(curr_prompt)+2:]
if not check(flave): if not check(flave, idx):
break break
if _prompt_at_max_len(best_prompt, self.tokenize): if _prompt_at_max_len(curr_prompt, self.tokenize):
break break
phrases.remove(flave) phrases.remove(flave)
@ -259,22 +264,26 @@ class Interrogator():
flaves = flaves + self.negative.labels flaves = flaves + self.negative.labels
return self.chain(image_features, flaves, max_count=max_flavors, reverse=True, desc="Negative chain") return self.chain(image_features, flaves, max_count=max_flavors, reverse=True, desc="Negative chain")
def interrogate(self, image: Image, max_flavors: int=32) -> str: def interrogate(self, image: Image, min_flavors: int=8, max_flavors: int=32) -> str:
caption = self.generate_caption(image) caption = self.generate_caption(image)
image_features = self.image_to_features(image) image_features = self.image_to_features(image)
merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config) merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config)
flaves = merged.rank(image_features, self.config.flavor_intermediate_count) flaves = merged.rank(image_features, self.config.flavor_intermediate_count)
best_prompt = caption best_prompt, best_sim = caption, self.similarity(image_features, best_prompt)
best_sim = self.similarity(image_features, best_prompt) best_prompt = self.chain(image_features, flaves, best_prompt, best_sim, min_count=min_flavors, max_count=max_flavors, desc="Flavor chain")
return self.chain(image_features, flaves, best_prompt, best_sim, max_count=max_flavors, desc="Flavor chain") fast_prompt = self.interrogate_fast(image, max_flavors)
classic_prompt = self.interrogate_classic(image, max_flavors)
candidates = [caption, classic_prompt, fast_prompt, best_prompt]
return candidates[np.argmax(self.similarities(image_features, candidates))]
def rank_top(self, image_features: torch.Tensor, text_array: List[str], reverse: bool=False) -> str: def rank_top(self, image_features: torch.Tensor, text_array: List[str], reverse: bool=False) -> str:
text_tokens = self.tokenize([text for text in text_array]).to(self.device) text_tokens = self.tokenize([text for text in text_array]).to(self.device)
with torch.no_grad(), torch.cuda.amp.autocast(): with torch.no_grad(), torch.cuda.amp.autocast():
text_features = self.clip_model.encode_text(text_tokens) text_features = self.clip_model.encode_text(text_tokens)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = text_features @ image_features.T similarity = text_features @ image_features.T
if reverse: if reverse:
similarity = -similarity similarity = -similarity

2
setup.py

@ -5,7 +5,7 @@ from setuptools import setup, find_packages
setup( setup(
name="clip-interrogator", name="clip-interrogator",
version="0.4.1", version="0.4.2",
license='MIT', license='MIT',
author='pharmapsychotic', author='pharmapsychotic',
author_email='me@pharmapsychotic.com', author_email='me@pharmapsychotic.com',

Loading…
Cancel
Save