diff --git a/clip_interrogator/clip_interrogator.py b/clip_interrogator/clip_interrogator.py
index 30545a8..c47aa93 100644
--- a/clip_interrogator/clip_interrogator.py
+++ b/clip_interrogator/clip_interrogator.py
@@ -15,7 +15,7 @@ from PIL import Image
 from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode
 from tqdm import tqdm
-from typing import List
+from typing import List, Union
 
 BLIP_MODELS = {
     "base": "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth",
@@ -64,7 +64,7 @@ class Config:
     )
     chunk_size: int = 2048  # batch size for CLIP, use smaller for lower VRAM
     data_path: str = os.path.join(os.path.dirname(__file__), "data")
-    device: str = (
+    device: Union[str, torch.device] = (
         "mps"
         if torch.backends.mps.is_available()
         else "cuda"
@@ -89,6 +89,7 @@ class Interrogator:
         blip_path = os.path.dirname(inspect.getfile(blip_decoder))
         configs_path = os.path.join(os.path.dirname(blip_path), "configs")
         med_config = os.path.join(configs_path, "med_config.json")
+
         blip_model = blip_decoder(
             pretrained=BLIP_MODELS[config.blip_model_type],
             image_size=config.blip_image_eval_size,
@@ -137,7 +138,14 @@ class Interrogator:
         ) = open_clip.create_model_and_transforms(
             clip_model_name,
             pretrained=clip_model_pretrained_name,
-            precision="fp16" if config.device == "cuda" else "fp32",
+            precision="fp16"
+            if (
+                config.device.type
+                if isinstance(config.device, torch.device)
+                else config.device
+            )
+            == "cuda"
+            else "fp32",
             device="cpu",
             jit=False,
             cache_dir=config.clip_model_path,
@@ -480,7 +488,9 @@ class Interrogator:
         )
 
         fast_prompt = self._interrogate_fast(caption, image_features, max_flavours)
-        classic_prompt = self.interrogate_classic(caption, image_features, max_flavours)
+        classic_prompt = self._interrogate_classic(
+            caption, image_features, max_flavours
+        )
         candidates = [caption, classic_prompt, fast_prompt, best_prompt]
         return candidates[np.argmax(self.similarities(image_features, candidates))]
 
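
Note on the precision change above: because Config.device may now hold either a plain string or a torch.device, the patch reduces the device to its type string before comparing it with "cuda". A minimal sketch of that normalization, assuming only stock PyTorch (the resolve_precision helper name is illustrative and not part of the patch):

import torch
from typing import Union


def resolve_precision(device: Union[str, torch.device]) -> str:
    # Hypothetical helper: mirrors the inline expression the patch passes
    # as open_clip's `precision` argument.
    device_type = device.type if isinstance(device, torch.device) else device
    return "fp16" if device_type == "cuda" else "fp32"


# Both spellings of a CUDA device resolve to fp16; CPU and MPS fall back to fp32.
assert resolve_precision("cuda") == "fp16"
assert resolve_precision(torch.device("cuda:0")) == "fp16"
assert resolve_precision("mps") == "fp32"

The patch inlines this same expression directly into the open_clip.create_model_and_transforms call rather than introducing a helper.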