
Expose LabelTable and load_list, and add a README example showing how they can be used to rank your own list of terms.

Commit ac74904908 by pharmapsychotic, 2 years ago (pull/69/head, v0.5.5)
4 changed files:
  1. README.md (16 lines changed)
  2. clip_interrogator/__init__.py (4 lines changed)
  3. clip_interrogator/clip_interrogator.py (41 lines changed)
  4. setup.py (2 lines changed)

README.md (16 lines changed)

@@ -40,7 +40,7 @@ Install with PIP
 pip3 install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu117
 
 # install clip-interrogator
-pip install clip-interrogator==0.5.4
+pip install clip-interrogator==0.5.5
 ```
 
 You can then use it in your script
@@ -67,3 +67,17 @@ The `Config` object lets you configure CLIP Interrogator's processing.
 On systems with low VRAM you can call `config.apply_low_vram_defaults()` to reduce the amount of VRAM needed (at the cost of some speed and quality). The default settings use about 6.3GB of VRAM and the low VRAM settings use about 2.7GB.
 
 See the [run_cli.py](https://github.com/pharmapsychotic/clip-interrogator/blob/main/run_cli.py) and [run_gradio.py](https://github.com/pharmapsychotic/clip-interrogator/blob/main/run_gradio.py) for more examples on using Config and Interrogator classes.
+
+## Ranking against your own list of terms
+
+```python
+from clip_interrogator import Config, Interrogator, LabelTable, load_list
+from PIL import Image
+
+ci = Interrogator(Config(blip_model_type=None))
+image = Image.open(image_path).convert('RGB')
+table = LabelTable(load_list('terms.txt'), 'terms', ci)
+best_match = table.rank(ci.image_to_features(image), top_count=1)[0]
+print(best_match)
+```
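Worth noting alongside the new README section: `LabelTable.rank` returns a plain list of labels, so the same table can surface several candidates, not just the best one. A minimal sketch extending the README example above (it reuses that example's `ci`, `image`, and `table`; nothing here is part of the commit itself):

```python
# Top-5 candidate terms instead of the single best match.
# Assumes `ci`, `image`, and `table` were set up as in the README example.
top_terms = table.rank(ci.image_to_features(image), top_count=5)
for i, term in enumerate(top_terms, start=1):
    print(f"{i}. {term}")
```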

clip_interrogator/__init__.py (4 lines changed)

@@ -1,4 +1,4 @@
-from .clip_interrogator import Interrogator, Config
+from .clip_interrogator import Config, Interrogator, LabelTable, load_list
 
-__version__ = '0.5.4'
+__version__ = '0.5.5'
 __author__ = 'pharmapsychotic'

clip_interrogator/clip_interrogator.py (41 lines changed)

@@ -29,20 +29,20 @@ CACHE_URL_BASE = 'https://huggingface.co/pharma/ci-preprocess/resolve/main/'
 @dataclass
 class Config:
     # models can optionally be passed in directly
-    blip_model: BLIP_Decoder = None
+    blip_model: Optional[BLIP_Decoder] = None
     clip_model = None
     clip_preprocess = None
 
     # blip settings
     blip_image_eval_size: int = 384
     blip_max_length: int = 32
-    blip_model_type: str = 'large' # choose between 'base' or 'large'
+    blip_model_type: Optional[str] = 'large' # use 'base', 'large' or None
     blip_num_beams: int = 8
     blip_offload: bool = False
 
     # clip settings
     clip_model_name: str = 'ViT-L-14/openai'
-    clip_model_path: str = None
+    clip_model_path: Optional[str] = None
     clip_offload: bool = False
 
     # interrogator settings
@@ -68,7 +68,7 @@ class Interrogator():
         self.blip_offloaded = True
         self.clip_offloaded = True
 
-        if config.blip_model is None:
+        if config.blip_model is None and config.blip_model_type:
             if not config.quiet:
                 print("Loading BLIP model...")
             blip_path = os.path.dirname(inspect.getfile(blip_decoder))
@@ -121,17 +121,17 @@ class Interrogator():
         trending_list.extend(["featured on "+site for site in sites])
         trending_list.extend([site+" contest winner" for site in sites])
 
-        raw_artists = _load_list(config.data_path, 'artists.txt')
+        raw_artists = load_list(config.data_path, 'artists.txt')
         artists = [f"by {a}" for a in raw_artists]
         artists.extend([f"inspired by {a}" for a in raw_artists])
 
         self._prepare_clip()
-        self.artists = LabelTable(artists, "artists", self.clip_model, self.tokenize, config)
-        self.flavors = LabelTable(_load_list(config.data_path, 'flavors.txt'), "flavors", self.clip_model, self.tokenize, config)
-        self.mediums = LabelTable(_load_list(config.data_path, 'mediums.txt'), "mediums", self.clip_model, self.tokenize, config)
-        self.movements = LabelTable(_load_list(config.data_path, 'movements.txt'), "movements", self.clip_model, self.tokenize, config)
-        self.trendings = LabelTable(trending_list, "trendings", self.clip_model, self.tokenize, config)
-        self.negative = LabelTable(_load_list(config.data_path, 'negative.txt'), "negative", self.clip_model, self.tokenize, config)
+        self.artists = LabelTable(artists, "artists", self)
+        self.flavors = LabelTable(load_list(config.data_path, 'flavors.txt'), "flavors", self)
+        self.mediums = LabelTable(load_list(config.data_path, 'mediums.txt'), "mediums", self)
+        self.movements = LabelTable(load_list(config.data_path, 'movements.txt'), "movements", self)
+        self.trendings = LabelTable(trending_list, "trendings", self)
+        self.negative = LabelTable(load_list(config.data_path, 'negative.txt'), "negative", self)
 
         end_time = time.time()
         if not config.quiet:
@@ -183,6 +183,7 @@ class Interrogator():
         return best_prompt
 
     def generate_caption(self, pil_image: Image) -> str:
+        assert self.blip_model is not None, "No BLIP model loaded."
         self._prepare_blip()
 
         size = self.config.blip_image_eval_size
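The new assert makes the trade-off of `blip_model_type=None` explicit: with BLIP disabled, captioning fails fast with a clear message while the CLIP-only paths keep working. A hedged sketch of the resulting behavior (the image path is hypothetical, not from the commit):

```python
from PIL import Image
from clip_interrogator import Config, Interrogator

ci = Interrogator(Config(blip_model_type=None))   # BLIP is never loaded
image = Image.open('example.jpg').convert('RGB')  # hypothetical path

features = ci.image_to_features(image)  # CLIP-only feature extraction still works
# ci.generate_caption(image)  # would now raise AssertionError: "No BLIP model loaded."
```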
@@ -310,13 +311,14 @@ class Interrogator():
 
 class LabelTable():
-    def __init__(self, labels:List[str], desc:str, clip_model, tokenize, config: Config):
+    def __init__(self, labels:List[str], desc:str, ci: Interrogator):
+        clip_model, config = ci.clip_model, ci.config
         self.chunk_size = config.chunk_size
         self.config = config
         self.device = config.device
         self.embeds = []
         self.labels = labels
-        self.tokenize = tokenize
+        self.tokenize = ci.tokenize
 
         hash = hashlib.sha256(",".join(labels).encode()).hexdigest()
         sanitized_name = self.config.clip_model_name.replace('/', '_').replace('@', '_')
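This constructor change is what the README example relies on: callers hand over the `Interrogator` itself rather than threading `clip_model`, `tokenize`, and `config` through separately. A before/after sketch (assuming `ci` is an `Interrogator` and `labels` is a list of strings; the description string is illustrative):

```python
# 0.5.4, internal-only signature: caller unpacked the Interrogator by hand
# table = LabelTable(labels, "my terms", ci.clip_model, ci.tokenize, ci.config)

# 0.5.5, public signature: pass the Interrogator directly
table = LabelTable(labels, "my terms", ci)
```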
@@ -423,11 +425,6 @@ def _download_file(url: str, filepath: str, chunk_size: int = 4*1024*1024, quiet
             progress.update(len(chunk))
     progress.close()
 
-def _load_list(data_path: str, filename: str) -> List[str]:
-    with open(os.path.join(data_path, filename), 'r', encoding='utf-8', errors='replace') as f:
-        items = [line.strip() for line in f.readlines()]
-    return items
-
 def _merge_tables(tables: List[LabelTable], config: Config) -> LabelTable:
     m = LabelTable([], None, None, None, config)
     for table in tables:
@@ -447,3 +444,11 @@ def _truncate_to_fit(text: str, tokenize) -> str:
             break
         new_text += ', ' + part
     return new_text
+
+def load_list(data_path: str, filename: Optional[str] = None) -> List[str]:
+    """Load a list of strings from a file."""
+    if filename is not None:
+        data_path = os.path.join(data_path, filename)
+    with open(data_path, 'r', encoding='utf-8', errors='replace') as f:
+        items = [line.strip() for line in f.readlines()]
+    return items
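The optional `filename` parameter is what lets both call styles coexist: the internal callers above pass `(data_path, filename)` and get the joined path, while the README example passes a single file path. A quick sketch of the two forms (the paths are illustrative):

```python
from clip_interrogator import load_list

terms = load_list('terms.txt')                 # one argument: read the file directly
terms = load_list('/data/lists', 'terms.txt')  # two arguments: joined to '/data/lists/terms.txt'
```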

setup.py (2 lines changed)

@@ -5,7 +5,7 @@ from setuptools import setup, find_packages
 
 setup(
     name="clip-interrogator",
-    version="0.5.4",
+    version="0.5.5",
     license='MIT',
     author='pharmapsychotic',
     author_email='me@pharmapsychotic.com',
