From ac7490490812cc62eb494cc34687364b8aa6082b Mon Sep 17 00:00:00 2001 From: pharmapsychotic Date: Sun, 19 Mar 2023 20:02:23 -0500 Subject: [PATCH] Expose LabelTable and load_list and give example in README how they can be used to rank your own list of terms. --- README.md | 16 +++++++++- clip_interrogator/__init__.py | 4 +-- clip_interrogator/clip_interrogator.py | 41 +++++++++++++++----------- setup.py | 2 +- 4 files changed, 41 insertions(+), 22 deletions(-) diff --git a/README.md b/README.md index 27984b0..acdce7a 100644 --- a/README.md +++ b/README.md @@ -40,7 +40,7 @@ Install with PIP pip3 install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu117 # install clip-interrogator -pip install clip-interrogator==0.5.4 +pip install clip-interrogator==0.5.5 ``` You can then use it in your script @@ -67,3 +67,17 @@ The `Config` object lets you configure CLIP Interrogator's processing. On systems with low VRAM you can call `config.apply_low_vram_defaults()` to reduce the amount of VRAM needed (at the cost of some speed and quality). The default settings use about 6.3GB of VRAM and the low VRAM settings use about 2.7GB. See the [run_cli.py](https://github.com/pharmapsychotic/clip-interrogator/blob/main/run_cli.py) and [run_gradio.py](https://github.com/pharmapsychotic/clip-interrogator/blob/main/run_gradio.py) for more examples on using Config and Interrogator classes. + + +## Ranking against your own list of terms + +```python +from clip_interrogator import Config, Interrogator, LabelTable, load_list +from PIL import Image + +ci = Interrogator(Config(blip_model_type=None)) +image = Image.open(image_path).convert('RGB') +table = LabelTable(load_list('terms.txt'), 'terms', ci) +best_match = table.rank(ci.image_to_features(image), top_count=1)[0] +print(best_match) +``` \ No newline at end of file diff --git a/clip_interrogator/__init__.py b/clip_interrogator/__init__.py index 4317c31..9a2936a 100644 --- a/clip_interrogator/__init__.py +++ b/clip_interrogator/__init__.py @@ -1,4 +1,4 @@ -from .clip_interrogator import Interrogator, Config +from .clip_interrogator import Config, Interrogator, LabelTable, load_list -__version__ = '0.5.4' +__version__ = '0.5.5' __author__ = 'pharmapsychotic' \ No newline at end of file diff --git a/clip_interrogator/clip_interrogator.py b/clip_interrogator/clip_interrogator.py index 284ef94..5d936fe 100644 --- a/clip_interrogator/clip_interrogator.py +++ b/clip_interrogator/clip_interrogator.py @@ -29,20 +29,20 @@ CACHE_URL_BASE = 'https://huggingface.co/pharma/ci-preprocess/resolve/main/' @dataclass class Config: # models can optionally be passed in directly - blip_model: BLIP_Decoder = None + blip_model: Optional[BLIP_Decoder] = None clip_model = None clip_preprocess = None # blip settings blip_image_eval_size: int = 384 blip_max_length: int = 32 - blip_model_type: str = 'large' # choose between 'base' or 'large' + blip_model_type: Optional[str] = 'large' # use 'base', 'large' or None blip_num_beams: int = 8 blip_offload: bool = False # clip settings clip_model_name: str = 'ViT-L-14/openai' - clip_model_path: str = None + clip_model_path: Optional[str] = None clip_offload: bool = False # interrogator settings @@ -68,7 +68,7 @@ class Interrogator(): self.blip_offloaded = True self.clip_offloaded = True - if config.blip_model is None: + if config.blip_model is None and config.blip_model_type: if not config.quiet: print("Loading BLIP model...") blip_path = os.path.dirname(inspect.getfile(blip_decoder)) @@ -121,17 +121,17 @@ class Interrogator(): trending_list.extend(["featured on "+site for site in sites]) trending_list.extend([site+" contest winner" for site in sites]) - raw_artists = _load_list(config.data_path, 'artists.txt') + raw_artists = load_list(config.data_path, 'artists.txt') artists = [f"by {a}" for a in raw_artists] artists.extend([f"inspired by {a}" for a in raw_artists]) self._prepare_clip() - self.artists = LabelTable(artists, "artists", self.clip_model, self.tokenize, config) - self.flavors = LabelTable(_load_list(config.data_path, 'flavors.txt'), "flavors", self.clip_model, self.tokenize, config) - self.mediums = LabelTable(_load_list(config.data_path, 'mediums.txt'), "mediums", self.clip_model, self.tokenize, config) - self.movements = LabelTable(_load_list(config.data_path, 'movements.txt'), "movements", self.clip_model, self.tokenize, config) - self.trendings = LabelTable(trending_list, "trendings", self.clip_model, self.tokenize, config) - self.negative = LabelTable(_load_list(config.data_path, 'negative.txt'), "negative", self.clip_model, self.tokenize, config) + self.artists = LabelTable(artists, "artists", self) + self.flavors = LabelTable(load_list(config.data_path, 'flavors.txt'), "flavors", self) + self.mediums = LabelTable(load_list(config.data_path, 'mediums.txt'), "mediums", self) + self.movements = LabelTable(load_list(config.data_path, 'movements.txt'), "movements", self) + self.trendings = LabelTable(trending_list, "trendings", self) + self.negative = LabelTable(load_list(config.data_path, 'negative.txt'), "negative", self) end_time = time.time() if not config.quiet: @@ -183,6 +183,7 @@ class Interrogator(): return best_prompt def generate_caption(self, pil_image: Image) -> str: + assert self.blip_model is not None, "No BLIP model loaded." self._prepare_blip() size = self.config.blip_image_eval_size @@ -310,13 +311,14 @@ class Interrogator(): class LabelTable(): - def __init__(self, labels:List[str], desc:str, clip_model, tokenize, config: Config): + def __init__(self, labels:List[str], desc:str, ci: Interrogator): + clip_model, config = ci.clip_model, ci.config self.chunk_size = config.chunk_size self.config = config self.device = config.device self.embeds = [] self.labels = labels - self.tokenize = tokenize + self.tokenize = ci.tokenize hash = hashlib.sha256(",".join(labels).encode()).hexdigest() sanitized_name = self.config.clip_model_name.replace('/', '_').replace('@', '_') @@ -423,11 +425,6 @@ def _download_file(url: str, filepath: str, chunk_size: int = 4*1024*1024, quiet progress.update(len(chunk)) progress.close() -def _load_list(data_path: str, filename: str) -> List[str]: - with open(os.path.join(data_path, filename), 'r', encoding='utf-8', errors='replace') as f: - items = [line.strip() for line in f.readlines()] - return items - def _merge_tables(tables: List[LabelTable], config: Config) -> LabelTable: m = LabelTable([], None, None, None, config) for table in tables: @@ -447,3 +444,11 @@ def _truncate_to_fit(text: str, tokenize) -> str: break new_text += ', ' + part return new_text + +def load_list(data_path: str, filename: Optional[str] = None) -> List[str]: + """Load a list of strings from a file.""" + if filename is not None: + data_path = os.path.join(data_path, filename) + with open(data_path, 'r', encoding='utf-8', errors='replace') as f: + items = [line.strip() for line in f.readlines()] + return items diff --git a/setup.py b/setup.py index 81fcfef..f2806e0 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ from setuptools import setup, find_packages setup( name="clip-interrogator", - version="0.5.4", + version="0.5.5", license='MIT', author='pharmapsychotic', author_email='me@pharmapsychotic.com',