@@ -15,7 +15,7 @@ from PIL import Image
 from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode
 from tqdm import tqdm
-from typing import List
+from typing import List, Union

 BLIP_MODELS = {
     "base": "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth",
@@ -64,7 +64,7 @@ class Config:
     )
     chunk_size: int = 2048  # batch size for CLIP, use smaller for lower VRAM
     data_path: str = os.path.join(os.path.dirname(__file__), "data")
-    device: str = (
+    device: Union[str, torch.device] = (
         "mps"
         if torch.backends.mps.is_available()
         else "cuda"
@@ -89,6 +89,7 @@ class Interrogator:
         blip_path = os.path.dirname(inspect.getfile(blip_decoder))
         configs_path = os.path.join(os.path.dirname(blip_path), "configs")
         med_config = os.path.join(configs_path, "med_config.json")
         blip_model = blip_decoder(
             pretrained=BLIP_MODELS[config.blip_model_type],
             image_size=config.blip_image_eval_size,
@@ -137,7 +138,14 @@ class Interrogator:
         ) = open_clip.create_model_and_transforms(
             clip_model_name,
             pretrained=clip_model_pretrained_name,
-            precision="fp16" if config.device == "cuda" else "fp32",
+            precision="fp16"
+            if (
+                config.device.type
+                if isinstance(config.device, torch.device)
+                else config.device
+            )
+            == "cuda"
+            else "fp32",
             device="cpu",
             jit=False,
             cache_dir=config.clip_model_path,
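
The reworked `precision` argument normalizes `config.device` to its device-type string before the `== "cuda"` comparison, since `torch.device("cuda:1").type` is `"cuda"`. The same logic as a standalone helper (a sketch; `device_type` is a hypothetical name, not part of the diff):

    import torch
    from typing import Union

    def device_type(device: Union[str, torch.device]) -> str:
        # torch.device carries an explicit .type; plain strings pass through as-is
        return device.type if isinstance(device, torch.device) else device

    assert device_type(torch.device("cuda:0")) == "cuda"
    assert device_type("mps") == "mps"

Note that a plain string such as "cuda:1" passes through unchanged and compares unequal to "cuda", so only a bare "cuda" string or a torch.device with CUDA type selects fp16.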
@@ -480,7 +488,9 @@ class Interrogator:
         )
         fast_prompt = self._interrogate_fast(caption, image_features, max_flavours)
-        classic_prompt = self.interrogate_classic(caption, image_features, max_flavours)
+        classic_prompt = self._interrogate_classic(
+            caption, image_features, max_flavours
+        )
         candidates = [caption, classic_prompt, fast_prompt, best_prompt]
         return candidates[np.argmax(self.similarities(image_features, candidates))]
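
The final hunk keeps the existing selection pattern: score every candidate prompt against the image features and return the argmax. A self-contained sketch of that pattern (assuming unit-normalized CLIP features, so cosine similarity reduces to a dot product; `pick_best` is a hypothetical helper, not the class method):

    import numpy as np
    import torch

    def pick_best(image_features: torch.Tensor,
                  text_features: torch.Tensor,
                  candidates: list) -> str:
        # image_features: (1, d); text_features: (len(candidates), d); both L2-normalized
        sims = (text_features @ image_features.T).squeeze(-1)  # one score per candidate
        return candidates[int(np.argmax(sims.detach().cpu().numpy()))]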