diff --git a/README.md b/README.md
index 7e029c1..ca819c7 100644
--- a/README.md
+++ b/README.md
@@ -37,7 +37,7 @@ pip3 install torch torchvision torchaudio --extra-index-url https://download.pyt
 
 # install blip and clip-interrogator
 pip install -e git+https://github.com/pharmapsychotic/BLIP.git@lib#egg=blip
-pip install clip-interrogator
+pip install clip-interrogator==0.3.2
 ```
 
 You can then use it in your script
diff --git a/clip_interrogator/__init__.py b/clip_interrogator/__init__.py
index 6470b3b..057cf78 100644
--- a/clip_interrogator/__init__.py
+++ b/clip_interrogator/__init__.py
@@ -1,4 +1,4 @@
 from .clip_interrogator import Interrogator, Config
 
-__version__ = '0.3.1'
+__version__ = '0.3.2'
 __author__ = 'pharmapsychotic'
\ No newline at end of file
diff --git a/clip_interrogator/clip_interrogator.py b/clip_interrogator/clip_interrogator.py
index 67fe5b8..a978e43 100644
--- a/clip_interrogator/clip_interrogator.py
+++ b/clip_interrogator/clip_interrogator.py
@@ -9,7 +9,7 @@ import time
 import torch
 
 from dataclasses import dataclass
-from models.blip import blip_decoder
+from models.blip import blip_decoder, BLIP_Decoder
 from PIL import Image
 from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode
@@ -20,7 +20,7 @@ from typing import List
 @dataclass
 class Config:
     # models can optionally be passed in directly
-    blip_model = None
+    blip_model: BLIP_Decoder = None
     clip_model = None
     clip_preprocess = None
@@ -256,8 +256,6 @@ class LabelTable():
                     if data.get('hash') == hash:
                         self.labels = data['labels']
                         self.embeds = data['embeds']
-                        if self.device == 'cpu':
-                            self.embeds = [e.astype(np.float32) for e in self.embeds]
             except Exception as e:
                 print(f"Error loading cached table {desc}: {e}")
@@ -281,6 +279,9 @@ class LabelTable():
                         "hash": hash,
                         "model": config.clip_model_name
                     }, f)
+
+        if self.device == 'cpu' or self.device == torch.device('cpu'):
+            self.embeds = [e.astype(np.float32) for e in self.embeds]
 
     def _rank(self, image_features: torch.Tensor, text_embeds: torch.Tensor, top_count: int=1) -> str:
         top_count = min(top_count, len(text_embeds))
diff --git a/run_cli.py b/run_cli.py
index 8d02f69..fdbba04 100755
--- a/run_cli.py
+++ b/run_cli.py
@@ -20,6 +20,7 @@ def inference(ci, image, mode):
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument('-c', '--clip', default='ViT-L-14/openai', help='name of CLIP model to use')
+    parser.add_argument('-d', '--device', default='auto', help='device to use (auto, cuda or cpu)')
     parser.add_argument('-f', '--folder', help='path to folder of images')
     parser.add_argument('-i', '--image', help='image file or url')
     parser.add_argument('-m', '--mode', default='best', help='best, classic, or fast')
@@ -41,9 +42,12 @@ def main():
         exit(1)
 
     # select device
-    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-    if not torch.cuda.is_available():
-        print("CUDA is not available, using CPU. Warning: this will be very slow!")
+    if args.device == 'auto':
+        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        if not torch.cuda.is_available():
+            print("CUDA is not available, using CPU. Warning: this will be very slow!")
+    else:
+        device = torch.device(args.device)
 
     # generate a nice prompt
     config = Config(device=device, clip_model_name=args.clip)
diff --git a/setup.py b/setup.py
index 12f2d02..19525b5 100644
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@ from setuptools import setup, find_packages
 
 setup(
     name="clip-interrogator",
-    version="0.3.1",
+    version="0.3.2",
     license='MIT',
     author='pharmapsychotic',
     author_email='me@pharmapsychotic.com',
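
For reference, the `Interrogator`/`Config` API that the README snippet above leads into can be used roughly as follows. This is a minimal sketch based on the names exported in `clip_interrogator/__init__.py`; the file name `example.jpg` is just a placeholder:

```python
from PIL import Image
from clip_interrogator import Config, Interrogator

# load an image and generate a prompt with the default CLIP model
image = Image.open("example.jpg").convert("RGB")
ci = Interrogator(Config(clip_model_name="ViT-L-14/openai"))
print(ci.interrogate(image))
```

With the `run_cli.py` change above, the device can also be forced from the command line instead of auto-detected, e.g. `python run_cli.py --image example.jpg --device cpu`.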