
0.4.2:

- upgrade chain to take a min_count parameter so it won't early-out until it has considered at least min_count flavors
- interrogate method ("best" mode) now also checks the classic and fast results and uses their output if it scores better
- fix bug where the config.download_cache option was not being used!
- add notes on the Config object to the readme
pharmapsychotic committed 2 years ago · commit 78287e17e1 · pull/43/head
4 changed files:

- README.md (12 lines changed)
- clip_interrogator/__init__.py (2 lines changed)
- clip_interrogator/clip_interrogator.py (41 lines changed)
- setup.py (2 lines changed)
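Taken together, the new surface area looks like this from user code. A minimal sketch, assuming a local `example.jpg`; the `min_flavors`/`max_flavors` values shown are just the new defaults from the diff below:

```python
from PIL import Image
from clip_interrogator import Config, Interrogator

# download_cache=True now actually triggers the precomputed-embedding download
ci = Interrogator(Config(clip_model_name="ViT-L-14/openai", download_cache=True))

image = Image.open("example.jpg").convert("RGB")  # placeholder image path
# "best" mode: chain flavors (considering at least min_flavors of them), then
# return whichever of caption / classic / fast / chained prompt scores highest
print(ci.interrogate(image, min_flavors=8, max_flavors=32))
```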

README.md:

@@ -36,7 +36,7 @@ Install with PIP
 pip3 install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu117
 # install clip-interrogator
-pip install clip-interrogator==0.4.1
+pip install clip-interrogator==0.4.2
 ```
 You can then use it in your script
@@ -51,3 +51,13 @@ print(ci.interrogate(image))
 CLIP Interrogator uses OpenCLIP which supports many different pretrained CLIP models. For the best prompts for
 Stable Diffusion 1.X use `ViT-L-14/openai` for clip_model_name. For Stable Diffusion 2.0 use `ViT-H-14/laion2b_s32b_b79k`
+
+## Configuration
+The `Config` object lets you configure CLIP Interrogator's processing.
+* `clip_model_name`: which of the OpenCLIP pretrained CLIP models to use
+* `cache_path`: path where to save precomputed text embeddings
+* `download_cache`: when True will download the precomputed embeddings from huggingface
+* `chunk_size`: batch size for CLIP, use smaller for lower VRAM
+* `quiet`: when True no progress bars or text output will be displayed
+
+See [run_cli.py](https://github.com/pharmapsychotic/clip-interrogator/blob/main/run_cli.py) and [run_gradio.py](https://github.com/pharmapsychotic/clip-interrogator/blob/main/run_gradio.py) for more examples of using the Config and Interrogator classes.
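The new Configuration section corresponds to usage like the following. A hedged sketch: the option names come from the README text above, while the specific values (including the chunk_size of 1024) are arbitrary choices for illustration:

```python
from clip_interrogator import Config, Interrogator

config = Config(
    clip_model_name="ViT-H-14/laion2b_s32b_b79k",  # OpenCLIP model, e.g. for SD 2.0
    cache_path="cache",     # where precomputed text embeddings are saved
    download_cache=True,    # fetch precomputed embeddings from huggingface
    chunk_size=1024,        # CLIP batch size; use a smaller value for lower VRAM
    quiet=True,             # suppress progress bars and text output
)
ci = Interrogator(config)
```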

clip_interrogator/__init__.py:

@@ -1,4 +1,4 @@
 from .clip_interrogator import Interrogator, Config
 
-__version__ = '0.4.1'
+__version__ = '0.4.2'
 __author__ = 'pharmapsychotic'

clip_interrogator/clip_interrogator.py:

@@ -140,7 +140,8 @@ class Interrogator():
         artists = [f"by {a}" for a in raw_artists]
         artists.extend([f"inspired by {a}" for a in raw_artists])

-        self.download_cache(config.clip_model_name)
+        if config.download_cache:
+            self.download_cache(config.clip_model_name)

         self.artists = LabelTable(artists, "artists", self.clip_model, self.tokenize, config)
         self.flavors = LabelTable(_load_list(config.data_path, 'flavors.txt'), "flavors", self.clip_model, self.tokenize, config)
@@ -159,6 +160,7 @@ class Interrogator():
             phrases: List[str],
             best_prompt: str="",
             best_sim: float=0,
+            min_count: int=8,
             max_count: int=32,
             desc="Chaining",
             reverse: bool=False
@@ -168,25 +170,28 @@ class Interrogator():
             best_prompt = self.rank_top(image_features, [f for f in phrases], reverse=reverse)
             best_sim = self.similarity(image_features, best_prompt)
             phrases.remove(best_prompt)

-        def check(addition: str) -> bool:
-            nonlocal best_prompt, best_sim
-            prompt = best_prompt + ", " + addition
+        curr_prompt, curr_sim = best_prompt, best_sim
+
+        def check(addition: str, idx: int) -> bool:
+            nonlocal best_prompt, best_sim, curr_prompt, curr_sim
+            prompt = curr_prompt + ", " + addition
             sim = self.similarity(image_features, prompt)
             if reverse:
                 sim = -sim
             if sim > best_sim:
-                best_sim = sim
-                best_prompt = prompt
+                best_prompt, best_sim = prompt, sim
+            if sim > curr_sim or idx < min_count:
+                curr_prompt, curr_sim = prompt, sim
                 return True
             return False

-        for _ in tqdm(range(max_count), desc=desc, disable=self.config.quiet):
-            best = self.rank_top(image_features, [f"{best_prompt}, {f}" for f in phrases], reverse=reverse)
-            flave = best[len(best_prompt)+2:]
-            if not check(flave):
+        for idx in tqdm(range(max_count), desc=desc, disable=self.config.quiet):
+            best = self.rank_top(image_features, [f"{curr_prompt}, {f}" for f in phrases], reverse=reverse)
+            flave = best[len(curr_prompt)+2:]
+            if not check(flave, idx):
                 break
-            if _prompt_at_max_len(best_prompt, self.tokenize):
+            if _prompt_at_max_len(curr_prompt, self.tokenize):
                 break
             phrases.remove(flave)
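Restating the effect of min_count outside the diff: the chain still tracks the best prompt it has ever seen, but it no longer stops at the first non-improving phrase until min_count iterations have passed. A standalone sketch, where `score` and `phrases` are stand-ins for the CLIP similarity calls and the label table:

```python
def chain_sketch(score, phrases, prompt, min_count=8, max_count=32):
    """Greedy prompt chaining: before min_count iterations, keep growing the
    chain even when a phrase does not improve similarity; the best prompt
    seen so far is retained either way."""
    best_prompt, best_sim = prompt, score(prompt)
    curr_prompt, curr_sim = best_prompt, best_sim
    for idx in range(max_count):
        if not phrases:
            break
        # pick the phrase whose addition scores highest (rank_top stand-in)
        addition = max(phrases, key=lambda p: score(f"{curr_prompt}, {p}"))
        candidate = f"{curr_prompt}, {addition}"
        sim = score(candidate)
        if sim > best_sim:                      # always track the global best
            best_prompt, best_sim = candidate, sim
        if sim > curr_sim or idx < min_count:   # early-out only after min_count
            curr_prompt, curr_sim = candidate, sim
            phrases.remove(addition)
        else:
            break
    return best_prompt

# toy demo: score a prompt by how many target words it covers
target = {"sunset", "watercolor", "mountains"}
score = lambda p: len(target & set(p.replace(",", " ").split()))
print(chain_sketch(score, ["watercolor", "mountains", "oil on canvas"], "sunset"))
# -> "sunset, watercolor, mountains"
```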
@@ -259,22 +264,26 @@ class Interrogator():
             flaves = flaves + self.negative.labels
         return self.chain(image_features, flaves, max_count=max_flavors, reverse=True, desc="Negative chain")

-    def interrogate(self, image: Image, max_flavors: int=32) -> str:
+    def interrogate(self, image: Image, min_flavors: int=8, max_flavors: int=32) -> str:
         caption = self.generate_caption(image)
         image_features = self.image_to_features(image)

         merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config)
         flaves = merged.rank(image_features, self.config.flavor_intermediate_count)
-        best_prompt = caption
-        best_sim = self.similarity(image_features, best_prompt)
-        return self.chain(image_features, flaves, best_prompt, best_sim, max_count=max_flavors, desc="Flavor chain")
+        best_prompt, best_sim = caption, self.similarity(image_features, caption)
+        best_prompt = self.chain(image_features, flaves, best_prompt, best_sim, min_count=min_flavors, max_count=max_flavors, desc="Flavor chain")
+
+        fast_prompt = self.interrogate_fast(image, max_flavors)
+        classic_prompt = self.interrogate_classic(image, max_flavors)
+        candidates = [caption, classic_prompt, fast_prompt, best_prompt]
+        return candidates[np.argmax(self.similarities(image_features, candidates))]

     def rank_top(self, image_features: torch.Tensor, text_array: List[str], reverse: bool=False) -> str:
         text_tokens = self.tokenize([text for text in text_array]).to(self.device)
         with torch.no_grad(), torch.cuda.amp.autocast():
             text_features = self.clip_model.encode_text(text_tokens)
             text_features /= text_features.norm(dim=-1, keepdim=True)
             similarity = text_features @ image_features.T
             if reverse:
                 similarity = -similarity
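The `self.similarities` helper that "best" mode calls is not shown in this hunk; modeled on the `rank_top` context above, it plausibly looks something like the following sketch (the function name, signature, and return handling here are assumptions, not the commit's code):

```python
import numpy as np
import torch

def similarities_sketch(clip_model, tokenize, device, image_features, prompts):
    """One cosine similarity per candidate prompt: tokenize, encode,
    L2-normalize, then dot with the (already normalized) image features."""
    tokens = tokenize(prompts).to(device)
    with torch.no_grad(), torch.cuda.amp.autocast():
        text_features = clip_model.encode_text(tokens)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        sims = text_features @ image_features.T   # shape (len(prompts), 1)
    return sims.squeeze(-1).float().cpu().numpy()

# "best" mode then reduces to a single argmax over four candidates:
#   candidates = [caption, classic_prompt, fast_prompt, best_prompt]
#   candidates[np.argmax(similarities_sketch(model, tok, dev, feats, candidates))]
```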

setup.py:

@@ -5,7 +5,7 @@ from setuptools import setup, find_packages
 setup(
     name="clip-interrogator",
-    version="0.4.1",
+    version="0.4.2",
     license='MIT',
     author='pharmapsychotic',
     author_email='me@pharmapsychotic.com',
