Browse Source

Expose LabelTable and load_list and give example in README how they can be used to rank your own list of terms.

pull/69/head v0.5.5
pharmapsychotic 2 years ago
parent
commit
ac74904908
  1. 16
      README.md
  2. 4
      clip_interrogator/__init__.py
  3. 41
      clip_interrogator/clip_interrogator.py
  4. 2
      setup.py

16
README.md

@ -40,7 +40,7 @@ Install with PIP
pip3 install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu117
# install clip-interrogator
pip install clip-interrogator==0.5.4
pip install clip-interrogator==0.5.5
```
You can then use it in your script
@ -67,3 +67,17 @@ The `Config` object lets you configure CLIP Interrogator's processing.
On systems with low VRAM you can call `config.apply_low_vram_defaults()` to reduce the amount of VRAM needed (at the cost of some speed and quality). The default settings use about 6.3GB of VRAM and the low VRAM settings use about 2.7GB.
See the [run_cli.py](https://github.com/pharmapsychotic/clip-interrogator/blob/main/run_cli.py) and [run_gradio.py](https://github.com/pharmapsychotic/clip-interrogator/blob/main/run_gradio.py) for more examples on using Config and Interrogator classes.
## Ranking against your own list of terms
```python
from clip_interrogator import Config, Interrogator, LabelTable, load_list
from PIL import Image
ci = Interrogator(Config(blip_model_type=None))
image = Image.open(image_path).convert('RGB')
table = LabelTable(load_list('terms.txt'), 'terms', ci)
best_match = table.rank(ci.image_to_features(image), top_count=1)[0]
print(best_match)
```

4
clip_interrogator/__init__.py

@ -1,4 +1,4 @@
from .clip_interrogator import Interrogator, Config
from .clip_interrogator import Config, Interrogator, LabelTable, load_list
__version__ = '0.5.4'
__version__ = '0.5.5'
__author__ = 'pharmapsychotic'

41
clip_interrogator/clip_interrogator.py

@ -29,20 +29,20 @@ CACHE_URL_BASE = 'https://huggingface.co/pharma/ci-preprocess/resolve/main/'
@dataclass
class Config:
# models can optionally be passed in directly
blip_model: BLIP_Decoder = None
blip_model: Optional[BLIP_Decoder] = None
clip_model = None
clip_preprocess = None
# blip settings
blip_image_eval_size: int = 384
blip_max_length: int = 32
blip_model_type: str = 'large' # choose between 'base' or 'large'
blip_model_type: Optional[str] = 'large' # use 'base', 'large' or None
blip_num_beams: int = 8
blip_offload: bool = False
# clip settings
clip_model_name: str = 'ViT-L-14/openai'
clip_model_path: str = None
clip_model_path: Optional[str] = None
clip_offload: bool = False
# interrogator settings
@ -68,7 +68,7 @@ class Interrogator():
self.blip_offloaded = True
self.clip_offloaded = True
if config.blip_model is None:
if config.blip_model is None and config.blip_model_type:
if not config.quiet:
print("Loading BLIP model...")
blip_path = os.path.dirname(inspect.getfile(blip_decoder))
@ -121,17 +121,17 @@ class Interrogator():
trending_list.extend(["featured on "+site for site in sites])
trending_list.extend([site+" contest winner" for site in sites])
raw_artists = _load_list(config.data_path, 'artists.txt')
raw_artists = load_list(config.data_path, 'artists.txt')
artists = [f"by {a}" for a in raw_artists]
artists.extend([f"inspired by {a}" for a in raw_artists])
self._prepare_clip()
self.artists = LabelTable(artists, "artists", self.clip_model, self.tokenize, config)
self.flavors = LabelTable(_load_list(config.data_path, 'flavors.txt'), "flavors", self.clip_model, self.tokenize, config)
self.mediums = LabelTable(_load_list(config.data_path, 'mediums.txt'), "mediums", self.clip_model, self.tokenize, config)
self.movements = LabelTable(_load_list(config.data_path, 'movements.txt'), "movements", self.clip_model, self.tokenize, config)
self.trendings = LabelTable(trending_list, "trendings", self.clip_model, self.tokenize, config)
self.negative = LabelTable(_load_list(config.data_path, 'negative.txt'), "negative", self.clip_model, self.tokenize, config)
self.artists = LabelTable(artists, "artists", self)
self.flavors = LabelTable(load_list(config.data_path, 'flavors.txt'), "flavors", self)
self.mediums = LabelTable(load_list(config.data_path, 'mediums.txt'), "mediums", self)
self.movements = LabelTable(load_list(config.data_path, 'movements.txt'), "movements", self)
self.trendings = LabelTable(trending_list, "trendings", self)
self.negative = LabelTable(load_list(config.data_path, 'negative.txt'), "negative", self)
end_time = time.time()
if not config.quiet:
@ -183,6 +183,7 @@ class Interrogator():
return best_prompt
def generate_caption(self, pil_image: Image) -> str:
assert self.blip_model is not None, "No BLIP model loaded."
self._prepare_blip()
size = self.config.blip_image_eval_size
@ -310,13 +311,14 @@ class Interrogator():
class LabelTable():
def __init__(self, labels:List[str], desc:str, clip_model, tokenize, config: Config):
def __init__(self, labels:List[str], desc:str, ci: Interrogator):
clip_model, config = ci.clip_model, ci.config
self.chunk_size = config.chunk_size
self.config = config
self.device = config.device
self.embeds = []
self.labels = labels
self.tokenize = tokenize
self.tokenize = ci.tokenize
hash = hashlib.sha256(",".join(labels).encode()).hexdigest()
sanitized_name = self.config.clip_model_name.replace('/', '_').replace('@', '_')
@ -423,11 +425,6 @@ def _download_file(url: str, filepath: str, chunk_size: int = 4*1024*1024, quiet
progress.update(len(chunk))
progress.close()
def _load_list(data_path: str, filename: str) -> List[str]:
with open(os.path.join(data_path, filename), 'r', encoding='utf-8', errors='replace') as f:
items = [line.strip() for line in f.readlines()]
return items
def _merge_tables(tables: List[LabelTable], config: Config) -> LabelTable:
m = LabelTable([], None, None, None, config)
for table in tables:
@ -447,3 +444,11 @@ def _truncate_to_fit(text: str, tokenize) -> str:
break
new_text += ', ' + part
return new_text
def load_list(data_path: str, filename: Optional[str] = None) -> List[str]:
"""Load a list of strings from a file."""
if filename is not None:
data_path = os.path.join(data_path, filename)
with open(data_path, 'r', encoding='utf-8', errors='replace') as f:
items = [line.strip() for line in f.readlines()]
return items

2
setup.py

@ -5,7 +5,7 @@ from setuptools import setup, find_packages
setup(
name="clip-interrogator",
version="0.5.4",
version="0.5.5",
license='MIT',
author='pharmapsychotic',
author_email='me@pharmapsychotic.com',

Loading…
Cancel
Save