@@ -16,6 +16,10 @@ from torchvision.transforms.functional import InterpolationMode
 from tqdm import tqdm
 from typing import List
 
+BLIP_MODELS = {
+    'base': 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth',
+    'large': 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
+}
 
 @dataclass
 class Config:
@@ -27,7 +31,7 @@ class Config:
     # blip settings
     blip_image_eval_size: int = 384
     blip_max_length: int = 32
-    blip_model_url: str = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
+    blip_model_type: str = 'large' # choose between 'base' or 'large'
     blip_num_beams: int = 8
     blip_offload: bool = False
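
Together with the BLIP_MODELS table added above, this hunk replaces a hard-coded checkpoint URL with a named model type. A minimal usage sketch, assuming the module's public Config and Interrogator names (the call itself is not part of this diff):

    # Hypothetical caller: pick the smaller BLIP checkpoint by name.
    # 'base' and 'large' are the only keys defined in BLIP_MODELS.
    config = Config()
    config.blip_model_type = 'base'  # resolved via BLIP_MODELS['base'] at load time
    ci = Interrogator(config)
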
@@ -39,11 +43,10 @@ class Config:
     cache_path: str = 'cache'
     chunk_size: int = 2048
     data_path: str = os.path.join(os.path.dirname(__file__), 'data')
-    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
+    device: str = ("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
     flavor_intermediate_count: int = 2048
     quiet: bool = False # when quiet progress bars are not shown
 
-
 class Interrogator():
     def __init__(self, config: Config):
         self.config = config
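
The new device default prefers Apple's Metal (MPS) backend, then CUDA, then CPU. The same precedence, unrolled into a standalone helper for readability (the function name is mine, not the module's):

    import torch

    def pick_device() -> str:
        # Mirrors the new Config default: MPS first, then CUDA, then CPU.
        # torch.backends.mps only exists on PyTorch >= 1.12, so older installs
        # would need a getattr() guard around this check.
        if torch.backends.mps.is_available():
            return "mps"
        if torch.cuda.is_available():
            return "cuda"
        return "cpu"
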
@@ -56,9 +59,9 @@ class Interrogator():
         configs_path = os.path.join(os.path.dirname(blip_path), 'configs')
         med_config = os.path.join(configs_path, 'med_config.json')
         blip_model = blip_decoder(
-            pretrained=config.blip_model_url,
+            pretrained=BLIP_MODELS[config.blip_model_type],
             image_size=config.blip_image_eval_size,
-            vit='large',
+            vit=config.blip_model_type,
             med_config=med_config
         )
         blip_model.eval()
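
With both pretrained and vit derived from the same blip_model_type, the checkpoint URL and the ViT backbone can no longer disagree. The trade-off is that an unsupported type now surfaces as a bare KeyError from the BLIP_MODELS lookup; a hypothetical fail-fast guard (not in this diff) could run just before blip_decoder():

    # Hypothetical guard: turn a bare KeyError from the BLIP_MODELS lookup
    # into an actionable error message.
    if config.blip_model_type not in BLIP_MODELS:
        raise ValueError(
            f"blip_model_type must be one of {sorted(BLIP_MODELS)}, "
            f"got {config.blip_model_type!r}"
        )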