Browse Source

Make the BLIP model configurable, can set config.blip_model_type now to 'base' or 'large'

pull/34/merge
pharmapsychotic 2 years ago
parent
commit
99c8d45e86
  1. 13
      clip_interrogator/clip_interrogator.py

13
clip_interrogator/clip_interrogator.py

@ -16,6 +16,10 @@ from torchvision.transforms.functional import InterpolationMode
from tqdm import tqdm from tqdm import tqdm
from typing import List from typing import List
BLIP_MODELS = {
'base': 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth',
'large': 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
}
@dataclass @dataclass
class Config: class Config:
@ -27,7 +31,7 @@ class Config:
# blip settings # blip settings
blip_image_eval_size: int = 384 blip_image_eval_size: int = 384
blip_max_length: int = 32 blip_max_length: int = 32
blip_model_url: str = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth' blip_model_type: str = 'large' # choose between 'base' or 'large'
blip_num_beams: int = 8 blip_num_beams: int = 8
blip_offload: bool = False blip_offload: bool = False
@ -39,11 +43,10 @@ class Config:
cache_path: str = 'cache' cache_path: str = 'cache'
chunk_size: int = 2048 chunk_size: int = 2048
data_path: str = os.path.join(os.path.dirname(__file__), 'data') data_path: str = os.path.join(os.path.dirname(__file__), 'data')
device: str = 'cuda' if torch.cuda.is_available() else 'cpu' device: str = ("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
flavor_intermediate_count: int = 2048 flavor_intermediate_count: int = 2048
quiet: bool = False # when quiet progress bars are not shown quiet: bool = False # when quiet progress bars are not shown
class Interrogator(): class Interrogator():
def __init__(self, config: Config): def __init__(self, config: Config):
self.config = config self.config = config
@ -56,9 +59,9 @@ class Interrogator():
configs_path = os.path.join(os.path.dirname(blip_path), 'configs') configs_path = os.path.join(os.path.dirname(blip_path), 'configs')
med_config = os.path.join(configs_path, 'med_config.json') med_config = os.path.join(configs_path, 'med_config.json')
blip_model = blip_decoder( blip_model = blip_decoder(
pretrained=config.blip_model_url, pretrained=BLIP_MODELS[config.blip_model_type],
image_size=config.blip_image_eval_size, image_size=config.blip_image_eval_size,
vit='large', vit=config.blip_model_type,
med_config=med_config med_config=med_config
) )
blip_model.eval() blip_model.eval()

Loading…
Cancel
Save