diff --git a/README.md b/README.md index 9885224..ff1e3bd 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,7 @@ Install with PIP pip3 install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu117 # install clip-interrogator -pip install clip-interrogator==0.4.1 +pip install clip-interrogator==0.4.2 ``` You can then use it in your script @@ -51,3 +51,13 @@ print(ci.interrogate(image)) CLIP Interrogator uses OpenCLIP which supports many different pretrained CLIP models. For the best prompts for Stable Diffusion 1.X use `ViT-L-14/openai` for clip_model_name. For Stable Diffusion 2.0 use `ViT-H-14/laion2b_s32b_b79k` +## Configuration + +The `Config` object lets you configure CLIP Interrogator's processing. +* `clip_model_name`: which of the OpenCLIP pretrained CLIP models to use +* `cache_path`: path where to save precomputed text embeddings +* `download_cache`: when True will download the precomputed embeddings from huggingface +* `chunk_size`: batch size for CLIP, use smaller for lower VRAM +* `quiet`: when True no progress bars or text output will be displayed + +See the [run_cli.py](https://github.com/pharmapsychotic/clip-interrogator/blob/main/run_cli.py) and [run_gradio.py](https://github.com/pharmapsychotic/clip-interrogator/blob/main/run_gradio.py) for more examples on using Config and Interrogator classes. diff --git a/clip_interrogator/__init__.py b/clip_interrogator/__init__.py index 4530f1c..bc18e62 100644 --- a/clip_interrogator/__init__.py +++ b/clip_interrogator/__init__.py @@ -1,4 +1,4 @@ from .clip_interrogator import Interrogator, Config -__version__ = '0.4.1' +__version__ = '0.4.2' __author__ = 'pharmapsychotic' \ No newline at end of file diff --git a/clip_interrogator/clip_interrogator.py b/clip_interrogator/clip_interrogator.py index db212d2..bde295f 100644 --- a/clip_interrogator/clip_interrogator.py +++ b/clip_interrogator/clip_interrogator.py @@ -140,7 +140,8 @@ class Interrogator(): artists = [f"by {a}" for a in raw_artists] artists.extend([f"inspired by {a}" for a in raw_artists]) - self.download_cache(config.clip_model_name) + if config.download_cache: + self.download_cache(config.clip_model_name) self.artists = LabelTable(artists, "artists", self.clip_model, self.tokenize, config) self.flavors = LabelTable(_load_list(config.data_path, 'flavors.txt'), "flavors", self.clip_model, self.tokenize, config) @@ -159,6 +160,7 @@ class Interrogator(): phrases: List[str], best_prompt: str="", best_sim: float=0, + min_count: int=8, max_count: int=32, desc="Chaining", reverse: bool=False @@ -168,25 +170,28 @@ class Interrogator(): best_prompt = self.rank_top(image_features, [f for f in phrases], reverse=reverse) best_sim = self.similarity(image_features, best_prompt) phrases.remove(best_prompt) - - def check(addition: str) -> bool: - nonlocal best_prompt, best_sim - prompt = best_prompt + ", " + addition + curr_prompt, curr_sim = best_prompt, best_sim + + def check(addition: str, idx: int) -> bool: + nonlocal best_prompt, best_sim, curr_prompt, curr_sim + prompt = curr_prompt + ", " + addition sim = self.similarity(image_features, prompt) if reverse: sim = -sim + if sim > best_sim: - best_sim = sim - best_prompt = prompt + best_prompt, best_sim = prompt, sim + if sim > curr_sim or idx < min_count: + curr_prompt, curr_sim = prompt, sim return True return False - for _ in tqdm(range(max_count), desc=desc, disable=self.config.quiet): - best = self.rank_top(image_features, [f"{best_prompt}, {f}" for f in phrases], reverse=reverse) - flave = best[len(best_prompt)+2:] - if not check(flave): + for idx in tqdm(range(max_count), desc=desc, disable=self.config.quiet): + best = self.rank_top(image_features, [f"{curr_prompt}, {f}" for f in phrases], reverse=reverse) + flave = best[len(curr_prompt)+2:] + if not check(flave, idx): break - if _prompt_at_max_len(best_prompt, self.tokenize): + if _prompt_at_max_len(curr_prompt, self.tokenize): break phrases.remove(flave) @@ -259,22 +264,26 @@ class Interrogator(): flaves = flaves + self.negative.labels return self.chain(image_features, flaves, max_count=max_flavors, reverse=True, desc="Negative chain") - def interrogate(self, image: Image, max_flavors: int=32) -> str: + def interrogate(self, image: Image, min_flavors: int=8, max_flavors: int=32) -> str: caption = self.generate_caption(image) image_features = self.image_to_features(image) merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config) flaves = merged.rank(image_features, self.config.flavor_intermediate_count) - best_prompt = caption - best_sim = self.similarity(image_features, best_prompt) + best_prompt, best_sim = caption, self.similarity(image_features, best_prompt) + best_prompt = self.chain(image_features, flaves, best_prompt, best_sim, min_count=min_flavors, max_count=max_flavors, desc="Flavor chain") - return self.chain(image_features, flaves, best_prompt, best_sim, max_count=max_flavors, desc="Flavor chain") + fast_prompt = self.interrogate_fast(image, max_flavors) + classic_prompt = self.interrogate_classic(image, max_flavors) + candidates = [caption, classic_prompt, fast_prompt, best_prompt] + return candidates[np.argmax(self.similarities(image_features, candidates))] def rank_top(self, image_features: torch.Tensor, text_array: List[str], reverse: bool=False) -> str: text_tokens = self.tokenize([text for text in text_array]).to(self.device) with torch.no_grad(), torch.cuda.amp.autocast(): text_features = self.clip_model.encode_text(text_tokens) + text_features /= text_features.norm(dim=-1, keepdim=True) similarity = text_features @ image_features.T if reverse: similarity = -similarity diff --git a/setup.py b/setup.py index 449c6e9..e862dc4 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ from setuptools import setup, find_packages setup( name="clip-interrogator", - version="0.4.1", + version="0.4.2", license='MIT', author='pharmapsychotic', author_email='me@pharmapsychotic.com',