diff --git a/clip_interrogator/clip_interrogator.py b/clip_interrogator/clip_interrogator.py
index f3f3d6c..a634436 100644
--- a/clip_interrogator/clip_interrogator.py
+++ b/clip_interrogator/clip_interrogator.py
@@ -16,6 +16,10 @@ from torchvision.transforms.functional import InterpolationMode
 from tqdm import tqdm
 from typing import List
 
+BLIP_MODELS = {
+    'base': 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth',
+    'large': 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
+}
 
 @dataclass
 class Config:
@@ -27,7 +31,7 @@ class Config:
     # blip settings
     blip_image_eval_size: int = 384
     blip_max_length: int = 32
-    blip_model_url: str = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
+    blip_model_type: str = 'large' # choose between 'base' or 'large'
     blip_num_beams: int = 8
     blip_offload: bool = False
 
@@ -39,11 +43,10 @@ class Config:
     cache_path: str = 'cache'
     chunk_size: int = 2048
     data_path: str = os.path.join(os.path.dirname(__file__), 'data')
-    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
+    device: str = ("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
     flavor_intermediate_count: int = 2048
     quiet: bool = False # when quiet progress bars are not shown
 
-
 class Interrogator():
     def __init__(self, config: Config):
         self.config = config
@@ -56,9 +59,9 @@ class Interrogator():
         configs_path = os.path.join(os.path.dirname(blip_path), 'configs')
         med_config = os.path.join(configs_path, 'med_config.json')
         blip_model = blip_decoder(
-            pretrained=config.blip_model_url,
+            pretrained=BLIP_MODELS[config.blip_model_type],
            image_size=config.blip_image_eval_size,
-            vit='large',
+            vit=config.blip_model_type,
             med_config=med_config
         )
         blip_model.eval()
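
For reference, a minimal usage sketch of the changed options: selecting the smaller 'base' BLIP checkpoint via the new blip_model_type field, with the new device default auto-picking 'mps' on Apple Silicon before 'cuda'/'cpu'. This is a sketch, not part of the change; the input filename is hypothetical, and it assumes the package's public Config/Interrogator API.

from PIL import Image
from clip_interrogator import Config, Interrogator

# 'base' resolves to the smaller checkpoint through the new BLIP_MODELS mapping;
# leaving device unset lets Config choose 'mps', then 'cuda', then 'cpu'.
config = Config(blip_model_type='base')
ci = Interrogator(config)

image = Image.open('example.jpg').convert('RGB')  # hypothetical input image
print(ci.interrogate(image))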