@ -18,24 +18,24 @@ from tqdm import tqdm
from typing import List
from typing import List
BLIP_MODELS = {
BLIP_MODELS = {
' base ' : ' https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth' ,
" base " : " https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth" ,
' large ' : ' https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
" large " : " https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth" ,
}
}
CACHE_URLS_VITL = [
CACHE_URLS_VITL = [
' https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_artists.pkl' ,
" https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_artists.pkl" ,
' https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_flavors.pkl' ,
" https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_flavors.pkl" ,
' https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_mediums.pkl' ,
" https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_mediums.pkl" ,
' https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_movements.pkl' ,
" https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_movements.pkl" ,
' https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_trendings.pkl' ,
" https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_trendings.pkl" ,
]
]
CACHE_URLS_VITH = [
CACHE_URLS_VITH = [
' https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_artists.pkl' ,
" https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_artists.pkl" ,
' https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_flavors.pkl' ,
" https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_flavors.pkl" ,
' https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_mediums.pkl' ,
" https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_mediums.pkl" ,
' https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_movements.pkl' ,
" https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_movements.pkl" ,
' https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_trendings.pkl' ,
" https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_trendings.pkl" ,
]
]
@ -49,40 +49,51 @@ class Config:
# blip settings
# blip settings
blip_image_eval_size : int = 384
blip_image_eval_size : int = 384
blip_max_length : int = 32
blip_max_length : int = 32
blip_model_type : str = ' large ' # choose between 'base' or 'large'
blip_model_type : str = " large " # choose between 'base' or 'large'
blip_num_beams : int = 8
blip_num_beams : int = 8
blip_offload : bool = False
blip_offload : bool = False
# clip settings
# clip settings
clip_model_name : str = ' ViT-L-14/openai '
clip_model_name : str = " ViT-L-14/openai "
clip_model_path : str = None
clip_model_path : str = None
# interrogator settings
# interrogator settings
cache_path : str = ' cache ' # path to store cached text embeddings
cache_path : str = " cache " # path to store cached text embeddings
download_cache : bool = True # when true, cached embeds are downloaded from huggingface
download_cache : bool = (
True # when true, cached embeds are downloaded from huggingface
)
chunk_size : int = 2048 # batch size for CLIP, use smaller for lower VRAM
chunk_size : int = 2048 # batch size for CLIP, use smaller for lower VRAM
data_path : str = os . path . join ( os . path . dirname ( __file__ ) , ' data ' )
data_path : str = os . path . join ( os . path . dirname ( __file__ ) , " data " )
device : str = ( " mps " if torch . backends . mps . is_available ( ) else " cuda " if torch . cuda . is_available ( ) else " cpu " )
device : str = (
" mps "
if torch . backends . mps . is_available ( )
else " cuda "
if torch . cuda . is_available ( )
else " cpu "
)
flavor_intermediate_count : int = 2048
flavor_intermediate_count : int = 2048
quiet : bool = False # when quiet progress bars are not shown
quiet : bool = False # when quiet progress bars are not shown
class Interrogator ( ) :
class Interrogator :
def __init__ ( self , config : Config ) :
def __init__ ( self , config : Config ) :
self . config = config
self . config = config
self . device = config . device
self . device = config . device
# Record which model is on the target device
self . blip_loaded = True
# Load BLIP model (to intended device)
if config . blip_model is None :
if config . blip_model is None :
if not config . quiet :
if not config . quiet :
print ( " Loading BLIP model... " )
print ( " Loading BLIP model... " )
blip_path = os . path . dirname ( inspect . getfile ( blip_decoder ) )
blip_path = os . path . dirname ( inspect . getfile ( blip_decoder ) )
configs_path = os . path . join ( os . path . dirname ( blip_path ) , ' configs ' )
configs_path = os . path . join ( os . path . dirname ( blip_path ) , " configs " )
med_config = os . path . join ( configs_path , ' med_config.json ' )
med_config = os . path . join ( configs_path , " med_config.json " )
blip_model = blip_decoder (
blip_model = blip_decoder (
pretrained = BLIP_MODELS [ config . blip_model_type ] ,
pretrained = BLIP_MODELS [ config . blip_model_type ] ,
image_size = config . blip_image_eval_size ,
image_size = config . blip_image_eval_size ,
vit = config . blip_model_type ,
vit = config . blip_model_type ,
med_config = med_config
med_config = med_config ,
)
)
blip_model . eval ( )
blip_model . eval ( )
blip_model = blip_model . to ( config . device )
blip_model = blip_model . to ( config . device )
@ -90,12 +101,13 @@ class Interrogator():
else :
else :
self . blip_model = config . blip_model
self . blip_model = config . blip_model
# Load CLIP (to CPU)
self . load_clip_model ( )
self . load_clip_model ( )
def download_cache ( self , clip_model_name : str ) :
def download_cache ( self , clip_model_name : str ) :
if clip_model_name == ' ViT-L-14/openai ' :
if clip_model_name == " ViT-L-14/openai " :
cache_urls = CACHE_URLS_VITL
cache_urls = CACHE_URLS_VITL
elif clip_model_name == ' ViT-H-14/laion2b_s32b_b79k ' :
elif clip_model_name == " ViT-H-14/laion2b_s32b_b79k " :
cache_urls = CACHE_URLS_VITH
cache_urls = CACHE_URLS_VITH
else :
else :
# text embeddings will be precomputed and cached locally
# text embeddings will be precomputed and cached locally
@ -103,7 +115,7 @@ class Interrogator():
os . makedirs ( self . config . cache_path , exist_ok = True )
os . makedirs ( self . config . cache_path , exist_ok = True )
for url in cache_urls :
for url in cache_urls :
filepath = os . path . join ( self . config . cache_path , url . split ( ' / ' ) [ - 1 ] )
filepath = os . path . join ( self . config . cache_path , url . split ( " / " ) [ - 1 ] )
if not os . path . exists ( filepath ) :
if not os . path . exists ( filepath ) :
_download_file ( url , filepath , quiet = self . config . quiet )
_download_file ( url , filepath , quiet = self . config . quiet )
@ -115,40 +127,93 @@ class Interrogator():
if not config . quiet :
if not config . quiet :
print ( " Loading CLIP model... " )
print ( " Loading CLIP model... " )
clip_model_name , clip_model_pretrained_name = config . clip_model_name . split ( ' / ' , 2 )
clip_model_name , clip_model_pretrained_name = config . clip_model_name . split (
self . clip_model , _ , self . clip_preprocess = open_clip . create_model_and_transforms (
" / " , 2
)
(
self . clip_model ,
_ ,
self . clip_preprocess ,
) = open_clip . create_model_and_transforms (
clip_model_name ,
clip_model_name ,
pretrained = clip_model_pretrained_name ,
pretrained = clip_model_pretrained_name ,
precision = ' fp16 ' if config . device == ' cuda ' else ' fp32 ' ,
precision = " fp16 " if config . device == " cuda " else " fp32 " ,
device = config . device ,
device = " cpu " ,
jit = False ,
jit = False ,
cache_dir = config . clip_model_path
cache_dir = config . clip_model_path ,
)
)
self . clip_model . to ( config . device ) . eval ( )
self . clip_model . eval ( )
else :
else :
self . clip_model = config . clip_model
self . clip_model = config . clip_model
self . clip_preprocess = config . clip_preprocess
self . clip_preprocess = config . clip_preprocess
self . tokenize = open_clip . get_tokenizer ( clip_model_name )
self . tokenize = open_clip . get_tokenizer ( clip_model_name )
sites = [ ' Artstation ' , ' behance ' , ' cg society ' , ' cgsociety ' , ' deviantart ' , ' dribble ' , ' flickr ' , ' instagram ' , ' pexels ' , ' pinterest ' , ' pixabay ' , ' pixiv ' , ' polycount ' , ' reddit ' , ' shutterstock ' , ' tumblr ' , ' unsplash ' , ' zbrush central ' ]
sites = [
" Artstation " ,
" behance " ,
" cg society " ,
" cgsociety " ,
" deviantart " ,
" dribble " ,
" flickr " ,
" instagram " ,
" pexels " ,
" pinterest " ,
" pixabay " ,
" pixiv " ,
" polycount " ,
" reddit " ,
" shutterstock " ,
" tumblr " ,
" unsplash " ,
" zbrush central " ,
]
trending_list = [ site for site in sites ]
trending_list = [ site for site in sites ]
trending_list . extend ( [ " trending on " + site for site in sites ] )
trending_list . extend ( [ " trending on " + site for site in sites ] )
trending_list . extend ( [ " featured on " + site for site in sites ] )
trending_list . extend ( [ " featured on " + site for site in sites ] )
trending_list . extend ( [ site + " contest winner " for site in sites ] )
trending_list . extend ( [ site + " contest winner " for site in sites ] )
raw_artists = _load_list ( config . data_path , ' artists.txt ' )
raw_artists = _load_list ( config . data_path , " artists.txt " )
artists = [ f " by { a } " for a in raw_artists ]
artists = [ f " by { a } " for a in raw_artists ]
artists . extend ( [ f " inspired by { a } " for a in raw_artists ] )
artists . extend ( [ f " inspired by { a } " for a in raw_artists ] )
if config . download_cache :
if config . download_cache :
self . download_cache ( config . clip_model_name )
self . download_cache ( config . clip_model_name )
self . artists = LabelTable ( artists , " artists " , self . clip_model , self . tokenize , config )
self . artists = LabelTable (
self . flavors = LabelTable ( _load_list ( config . data_path , ' flavors.txt ' ) , " flavors " , self . clip_model , self . tokenize , config )
artists , " artists " , self . clip_model , self . tokenize , config
self . mediums = LabelTable ( _load_list ( config . data_path , ' mediums.txt ' ) , " mediums " , self . clip_model , self . tokenize , config )
)
self . movements = LabelTable ( _load_list ( config . data_path , ' movements.txt ' ) , " movements " , self . clip_model , self . tokenize , config )
self . flavors = LabelTable (
self . trendings = LabelTable ( trending_list , " trendings " , self . clip_model , self . tokenize , config )
_load_list ( config . data_path , " flavors.txt " ) ,
self . negative = LabelTable ( _load_list ( config . data_path , ' negative.txt ' ) , " negative " , self . clip_model , self . tokenize , config )
" flavors " ,
self . clip_model ,
self . tokenize ,
config ,
)
self . mediums = LabelTable (
_load_list ( config . data_path , " mediums.txt " ) ,
" mediums " ,
self . clip_model ,
self . tokenize ,
config ,
)
self . movements = LabelTable (
_load_list ( config . data_path , " movements.txt " ) ,
" movements " ,
self . clip_model ,
self . tokenize ,
config ,
)
self . trendings = LabelTable (
trending_list , " trendings " , self . clip_model , self . tokenize , config
)
self . negative = LabelTable (
_load_list ( config . data_path , " negative.txt " ) ,
" negative " ,
self . clip_model ,
self . tokenize ,
config ,
)
end_time = time . time ( )
end_time = time . time ( )
if not config . quiet :
if not config . quiet :
@ -158,16 +223,18 @@ class Interrogator():
self ,
self ,
image_features : torch . Tensor ,
image_features : torch . Tensor ,
phrases : List [ str ] ,
phrases : List [ str ] ,
best_prompt : str = " " ,
best_prompt : str = " " ,
best_sim : float = 0 ,
best_sim : float = 0 ,
min_count : int = 8 ,
min_count : int = 8 ,
max_count : int = 32 ,
max_count : int = 32 ,
desc = " Chaining " ,
desc = " Chaining " ,
reverse : bool = False
reverse : bool = False ,
) - > str :
) - > str :
phrases = set ( phrases )
phrases = set ( phrases )
if not best_prompt :
if not best_prompt :
best_prompt = self . rank_top ( image_features , [ f for f in phrases ] , reverse = reverse )
best_prompt = self . rank_top (
image_features , [ f for f in phrases ] , reverse = reverse
)
best_sim = self . similarity ( image_features , best_prompt )
best_sim = self . similarity ( image_features , best_prompt )
phrases . remove ( best_prompt )
phrases . remove ( best_prompt )
curr_prompt , curr_sim = best_prompt , best_sim
curr_prompt , curr_sim = best_prompt , best_sim
@ -187,8 +254,12 @@ class Interrogator():
return False
return False
for idx in tqdm ( range ( max_count ) , desc = desc , disable = self . config . quiet ) :
for idx in tqdm ( range ( max_count ) , desc = desc , disable = self . config . quiet ) :
best = self . rank_top ( image_features , [ f " { curr_prompt } , { f } " for f in phrases ] , reverse = reverse )
best = self . rank_top (
flave = best [ len ( curr_prompt ) + 2 : ]
image_features ,
[ f " { curr_prompt } , { f } " for f in phrases ] ,
reverse = reverse ,
)
flave = best [ len ( curr_prompt ) + 2 : ]
if not check ( flave , idx ) :
if not check ( flave , idx ) :
break
break
if _prompt_at_max_len ( curr_prompt , self . tokenize ) :
if _prompt_at_max_len ( curr_prompt , self . tokenize ) :
@ -201,11 +272,22 @@ class Interrogator():
if self . config . blip_offload :
if self . config . blip_offload :
self . blip_model = self . blip_model . to ( self . device )
self . blip_model = self . blip_model . to ( self . device )
size = self . config . blip_image_eval_size
size = self . config . blip_image_eval_size
gpu_image = transforms . Compose ( [
gpu_image = (
transforms . Resize ( ( size , size ) , interpolation = InterpolationMode . BICUBIC ) ,
transforms . Compose (
[
transforms . Resize (
( size , size ) , interpolation = InterpolationMode . BICUBIC
) ,
transforms . ToTensor ( ) ,
transforms . ToTensor ( ) ,
transforms . Normalize ( ( 0.48145466 , 0.4578275 , 0.40821073 ) , ( 0.26862954 , 0.26130258 , 0.27577711 ) )
transforms . Normalize (
] ) ( pil_image ) . unsqueeze ( 0 ) . to ( self . device )
( 0.48145466 , 0.4578275 , 0.40821073 ) ,
( 0.26862954 , 0.26130258 , 0.27577711 ) ,
) ,
]
) ( pil_image )
. unsqueeze ( 0 )
. to ( self . device )
)
with torch . no_grad ( ) :
with torch . no_grad ( ) :
caption = self . blip_model . generate (
caption = self . blip_model . generate (
@ -213,7 +295,7 @@ class Interrogator():
sample = False ,
sample = False ,
num_beams = self . config . blip_num_beams ,
num_beams = self . config . blip_num_beams ,
max_length = self . config . blip_max_length ,
max_length = self . config . blip_max_length ,
min_length = 5
min_length = 5 ,
)
)
if self . config . blip_offload :
if self . config . blip_offload :
self . blip_model = self . blip_model . to ( " cpu " )
self . blip_model = self . blip_model . to ( " cpu " )
@ -226,17 +308,65 @@ class Interrogator():
image_features / = image_features . norm ( dim = - 1 , keepdim = True )
image_features / = image_features . norm ( dim = - 1 , keepdim = True )
return image_features
return image_features
def interrogate_classic ( self , image : Image , max_flavors : int = 3 ) - > str :
def _first_bit ( self , image : Image ) - > ( str , torch . Tensor ) :
""" Classic mode creates a prompt in a standard format first describing the image,
if self . blip_loaded :
then listing the artist , trending , movement , and flavor text modifiers . """
caption = self . generate_caption ( image )
caption = self . generate_caption ( image )
# Move BLIP to RAM
self . blip_model . to ( " cpu " )
# Move CLIP to intended device
self . clip_model . to ( self . device )
image_features = self . image_to_features ( image )
else : # CLIP is loaded
image_features = self . image_to_features ( image )
image_features = self . image_to_features ( image )
# Move CLIP to RAM
self . clip_model . to ( " cpu " )
# Move BLIP to intended device
self . blip_model . to ( self . device )
caption = self . generate_caption ( image )
# Toggle `blip_loaded`
self . blip_loaded ^ = True
return caption , image_features
def _first_bit_batch ( self , images : list [ Image ] ) - > ( list [ str ] , list [ torch . Tensor ] ) :
image_features : list [ torch . Tensor ] = [ ]
if self . blip_loaded :
captions = [ self . generate_caption ( img ) for img in images ]
# Move BLIP to RAM
self . blip_model . to ( " cpu " )
# Move CLIP to intended device
self . clip_model . to ( self . device )
image_features = [ self . image_to_features ( img ) for img in images ]
else : # CLIP is loaded
image_features = [ self . image_to_features ( img ) for img in images ]
# Move CLIP to RAM
self . clip_model . to ( " cpu " )
# Move BLIP to intended device
self . blip_model . to ( self . device )
captions = [ self . generate_caption ( img ) for img in images ]
# Toggle `blip_loaded`
self . blip_loaded ^ = True
return captions , image_features
def _interrogate_classic (
self , caption : str , image_features : torch . Tensor , max_flavours : int = 3
) - > str :
medium = self . mediums . rank ( image_features , 1 ) [ 0 ]
medium = self . mediums . rank ( image_features , 1 ) [ 0 ]
artist = self . artists . rank ( image_features , 1 ) [ 0 ]
artist = self . artists . rank ( image_features , 1 ) [ 0 ]
trending = self . trendings . rank ( image_features , 1 ) [ 0 ]
trending = self . trendings . rank ( image_features , 1 ) [ 0 ]
movement = self . movements . rank ( image_features , 1 ) [ 0 ]
movement = self . movements . rank ( image_features , 1 ) [ 0 ]
flaves = " , " . join ( self . flavors . rank ( image_features , max_flavors ) )
flaves = " , " . join ( self . flavors . rank ( image_features , max_flavou rs ) )
if caption . startswith ( medium ) :
if caption . startswith ( medium ) :
prompt = f " { caption } { artist } , { trending } , { movement } , { flaves } "
prompt = f " { caption } { artist } , { trending } , { movement } , { flaves } "
@ -245,41 +375,138 @@ class Interrogator():
return _truncate_to_fit ( prompt , self . tokenize )
return _truncate_to_fit ( prompt , self . tokenize )
def interrogate_classic ( self , image : Image , max_flavors : int = 3 ) - > str :
""" Classic mode creates a prompt in a standard format first describing the image,
then listing the artist , trending , movement , and flavor text modifiers . """
caption , image_features = self . _first_bit ( image )
return self . _interrogate_classic ( caption , image_features , max_flavors )
def interrogate_classic_batch (
self , images : list [ Image ] , max_flavors : int = 3
) - > list [ str ] :
""" Classic mode creates a prompt in a standard format first describing the image,
then listing the artist , trending , movement , and flavor text modifiers .
This function interrogates a batch of images ( more efficient than doing
it individually ) . """
captions , image_features = self . _first_bit_batch ( images )
returns : list [ str ] = [
self . _interrogate_classic ( caption , feature , max_flavors )
for caption , feature in zip ( captions , image_features )
]
return returns
def _interrogate_fast (
self , caption : str , image_features : torch . Tensor , max_flavours : int = 32
) - > str :
merged = _merge_tables (
[ self . artists , self . flavors , self . mediums , self . movements , self . trendings ] ,
self . config ,
)
tops = merged . rank ( image_features , max_flavours )
return _truncate_to_fit ( caption + " , " + " , " . join ( tops ) , self . tokenize )
def interrogate_fast ( self , image : Image , max_flavors : int = 32 ) - > str :
def interrogate_fast ( self , image : Image , max_flavors : int = 32 ) - > str :
""" Fast mode simply adds the top ranked terms after a caption. It generally results in
""" Fast mode simply adds the top ranked terms after a caption. It generally results in
better similarity between generated prompt and image than classic mode , but the prompts
better similarity between generated prompt and image than classic mode , but the prompts
are less readable . """
are less readable . """
caption = self . generate_caption ( image )
caption , image_features = self . _first_bit ( image )
image_features = self . image_to_features ( image )
merged = _merge_tables ( [ self . artists , self . flavors , self . mediums , self . movements , self . trendings ] , self . config )
return self . _interrogate_fast ( caption , image_features , max_flavors )
tops = merged . rank ( image_features , max_flavors )
return _truncate_to_fit ( caption + " , " + " , " . join ( tops ) , self . tokenize )
def interrogate_fast_batch ( self , images : list [ Image ] , max_flavors : int = 32 ) - > str :
""" Fast mode simply adds the top ranked terms after a caption. It generally results in
better similarity between generated prompt and image than classic mode , but the prompts
are less readable .
This function interrogates a batch of images ( more efficient than doing
it individually ) . """
captions , image_features = self . _first_bit_batch ( images )
returns : list [ str ] = [
self . _interrogate_fast ( caption , feature , max_flavors )
for caption , feature in zip ( captions , image_features )
]
return returns
def interrogate_negative ( self , image : Image , max_flavors : int = 32 ) - > str :
def interrogate_negative ( self , image : Image , max_flavors : int = 32 ) - > str :
""" Negative mode chains together the most dissimilar terms to the image. It can be used
""" Negative mode chains together the most dissimilar terms to the image. It can be used
to help build a negative prompt to pair with the regular positive prompt and often
to help build a negative prompt to pair with the regular positive prompt and often
improve the results of generated images particularly with Stable Diffusion 2. """
improve the results of generated images particularly with Stable Diffusion 2. """
image_features = self . image_to_features ( image )
if self . blip_loaded : # Move CLIP to intended device
flaves = self . flavors . rank ( image_features , self . config . flavor_intermediate_count , reverse = True )
self . blip_model . to ( " cpu " )
flaves = flaves + self . negative . labels
self . cli_model . to ( self . device )
return self . chain ( image_features , flaves , max_count = max_flavors , reverse = True , desc = " Negative chain " )
def interrogate ( self , image : Image , min_flavors : int = 8 , max_flavors : int = 32 ) - > str :
caption = self . generate_caption ( image )
image_features = self . image_to_features ( image )
image_features = self . image_to_features ( image )
flaves = self . flavors . rank (
image_features , self . config . flavor_intermediate_count , reverse = True
)
flaves = flaves + self . negative . labels
return self . chain (
image_features ,
flaves ,
max_count = max_flavors ,
reverse = True ,
desc = " Negative chain " ,
)
merged = _merge_tables ( [ self . artists , self . flavors , self . mediums , self . movements , self . trendings ] , self . config )
def _interrogate (
self ,
caption : str ,
image_features : torch . Tensor ,
min_flavours : int = 8 ,
max_flavours : int = 32 ,
) - > str :
merged = _merge_tables (
[ self . artists , self . flavors , self . mediums , self . movements , self . trendings ] ,
self . config ,
)
flaves = merged . rank ( image_features , self . config . flavor_intermediate_count )
flaves = merged . rank ( image_features , self . config . flavor_intermediate_count )
best_prompt , best_sim = caption , self . similarity ( image_features , caption )
best_prompt , best_sim = caption , self . similarity ( image_features , caption )
best_prompt = self . chain ( image_features , flaves , best_prompt , best_sim , min_count = min_flavors , max_count = max_flavors , desc = " Flavor chain " )
best_prompt = self . chain (
image_features ,
flaves ,
best_prompt ,
best_sim ,
min_count = min_flavours ,
max_count = max_flavours ,
desc = " Flavor chain " ,
)
fast_prompt = self . interrogate_fast ( image , max_flavors )
fast_prompt = self . _ interrogate_fast( caption , image_features , max_flavou rs )
classic_prompt = self . interrogate_classic ( image , max_flavors )
classic_prompt = self . interrogate_classic ( caption , image_features , max_flavou rs )
candidates = [ caption , classic_prompt , fast_prompt , best_prompt ]
candidates = [ caption , classic_prompt , fast_prompt , best_prompt ]
return candidates [ np . argmax ( self . similarities ( image_features , candidates ) ) ]
return candidates [ np . argmax ( self . similarities ( image_features , candidates ) ) ]
def rank_top ( self , image_features : torch . Tensor , text_array : List [ str ] , reverse : bool = False ) - > str :
def interrogate (
self , image : Image , min_flavors : int = 8 , max_flavors : int = 32
) - > str :
caption , image_features = self . _first_bit ( image )
return self . _interrogate ( caption , image_features , min_flavors , max_flavors )
def interrogate_batch (
self , images : list [ Image ] , min_flavors : int = 8 , max_flavors : int = 32
) - > list [ str ] :
""" This function interrogates a batch of images (more efficient than doing
it individually ) . """
captions , image_features = self . _first_bit_batch ( images )
returns : list [ str ] = [
self . _interrogate ( caption , features , min_flavors , max_flavors )
for caption , features in zip ( captions , image_features )
]
return returns
def rank_top (
self , image_features : torch . Tensor , text_array : List [ str ] , reverse : bool = False
) - > str :
text_tokens = self . tokenize ( [ text for text in text_array ] ) . to ( self . device )
text_tokens = self . tokenize ( [ text for text in text_array ] ) . to ( self . device )
with torch . no_grad ( ) , torch . cuda . amp . autocast ( ) :
with torch . no_grad ( ) , torch . cuda . amp . autocast ( ) :
text_features = self . clip_model . encode_text ( text_tokens )
text_features = self . clip_model . encode_text ( text_tokens )
@ -297,7 +524,9 @@ class Interrogator():
similarity = text_features @ image_features . T
similarity = text_features @ image_features . T
return similarity [ 0 ] [ 0 ] . item ( )
return similarity [ 0 ] [ 0 ] . item ( )
def similarities ( self , image_features : torch . Tensor , text_array : List [ str ] ) - > List [ float ] :
def similarities (
self , image_features : torch . Tensor , text_array : List [ str ]
) - > List [ float ] :
text_tokens = self . tokenize ( [ text for text in text_array ] ) . to ( self . device )
text_tokens = self . tokenize ( [ text for text in text_array ] ) . to ( self . device )
with torch . no_grad ( ) , torch . cuda . amp . autocast ( ) :
with torch . no_grad ( ) , torch . cuda . amp . autocast ( ) :
text_features = self . clip_model . encode_text ( text_tokens )
text_features = self . clip_model . encode_text ( text_tokens )
@ -306,8 +535,10 @@ class Interrogator():
return similarity . T [ 0 ] . tolist ( )
return similarity . T [ 0 ] . tolist ( )
class LabelTable ( ) :
class LabelTable :
def __init__ ( self , labels : List [ str ] , desc : str , clip_model , tokenize , config : Config ) :
def __init__ (
self , labels : List [ str ] , desc : str , clip_model , tokenize , config : Config
) :
self . chunk_size = config . chunk_size
self . chunk_size = config . chunk_size
self . config = config
self . config = config
self . device = config . device
self . device = config . device
@ -320,22 +551,30 @@ class LabelTable():
cache_filepath = None
cache_filepath = None
if config . cache_path is not None and desc is not None :
if config . cache_path is not None and desc is not None :
os . makedirs ( config . cache_path , exist_ok = True )
os . makedirs ( config . cache_path , exist_ok = True )
sanitized_name = config . clip_model_name . replace ( ' / ' , ' _ ' ) . replace ( ' @ ' , ' _ ' )
sanitized_name = config . clip_model_name . replace ( " / " , " _ " ) . replace ( " @ " , " _ " )
cache_filepath = os . path . join ( config . cache_path , f " { sanitized_name } _ { desc } .pkl " )
cache_filepath = os . path . join (
config . cache_path , f " { sanitized_name } _ { desc } .pkl "
)
if desc is not None and os . path . exists ( cache_filepath ) :
if desc is not None and os . path . exists ( cache_filepath ) :
with open ( cache_filepath , ' rb ' ) as f :
with open ( cache_filepath , " rb " ) as f :
try :
try :
data = pickle . load ( f )
data = pickle . load ( f )
if data . get ( ' hash ' ) == hash :
if data . get ( " hash " ) == hash :
self . labels = data [ ' labels ' ]
self . labels = data [ " labels " ]
self . embeds = data [ ' embeds ' ]
self . embeds = data [ " embeds " ]
except Exception as e :
except Exception as e :
print ( f " Error loading cached table { desc } : { e } " )
print ( f " Error loading cached table { desc } : { e } " )
if len ( self . labels ) != len ( self . embeds ) :
if len ( self . labels ) != len ( self . embeds ) :
self . embeds = [ ]
self . embeds = [ ]
chunks = np . array_split ( self . labels , max ( 1 , len ( self . labels ) / config . chunk_size ) )
chunks = np . array_split (
for chunk in tqdm ( chunks , desc = f " Preprocessing { desc } " if desc else None , disable = self . config . quiet ) :
self . labels , max ( 1 , len ( self . labels ) / config . chunk_size )
)
for chunk in tqdm (
chunks ,
desc = f " Preprocessing { desc } " if desc else None ,
disable = self . config . quiet ,
) :
text_tokens = self . tokenize ( chunk ) . to ( self . device )
text_tokens = self . tokenize ( chunk ) . to ( self . device )
with torch . no_grad ( ) , torch . cuda . amp . autocast ( ) :
with torch . no_grad ( ) , torch . cuda . amp . autocast ( ) :
text_features = clip_model . encode_text ( text_tokens )
text_features = clip_model . encode_text ( text_tokens )
@ -345,20 +584,31 @@ class LabelTable():
self . embeds . append ( text_features [ i ] )
self . embeds . append ( text_features [ i ] )
if cache_filepath is not None :
if cache_filepath is not None :
with open ( cache_filepath , ' wb ' ) as f :
with open ( cache_filepath , " wb " ) as f :
pickle . dump ( {
pickle . dump (
{
" labels " : self . labels ,
" labels " : self . labels ,
" embeds " : self . embeds ,
" embeds " : self . embeds ,
" hash " : hash ,
" hash " : hash ,
" model " : config . clip_model_name
" model " : config . clip_model_name ,
} , f )
} ,
f ,
)
if self . device == ' cpu ' or self . device == torch . device ( ' cpu ' ) :
if self . device == " cpu " or self . device == torch . device ( " cpu " ) :
self . embeds = [ e . astype ( np . float32 ) for e in self . embeds ]
self . embeds = [ e . astype ( np . float32 ) for e in self . embeds ]
def _rank ( self , image_features : torch . Tensor , text_embeds : torch . Tensor , top_count : int = 1 , reverse : bool = False ) - > str :
def _rank (
self ,
image_features : torch . Tensor ,
text_embeds : torch . Tensor ,
top_count : int = 1 ,
reverse : bool = False ,
) - > str :
top_count = min ( top_count , len ( text_embeds ) )
top_count = min ( top_count , len ( text_embeds ) )
text_embeds = torch . stack ( [ torch . from_numpy ( t ) for t in text_embeds ] ) . to ( self . device )
text_embeds = torch . stack ( [ torch . from_numpy ( t ) for t in text_embeds ] ) . to (
self . device
)
with torch . cuda . amp . autocast ( ) :
with torch . cuda . amp . autocast ( ) :
similarity = image_features @ text_embeds . T
similarity = image_features @ text_embeds . T
if reverse :
if reverse :
@ -366,31 +616,44 @@ class LabelTable():
_ , top_labels = similarity . float ( ) . cpu ( ) . topk ( top_count , dim = - 1 )
_ , top_labels = similarity . float ( ) . cpu ( ) . topk ( top_count , dim = - 1 )
return [ top_labels [ 0 ] [ i ] . numpy ( ) for i in range ( top_count ) ]
return [ top_labels [ 0 ] [ i ] . numpy ( ) for i in range ( top_count ) ]
def rank ( self , image_features : torch . Tensor , top_count : int = 1 , reverse : bool = False ) - > List [ str ] :
def rank (
self , image_features : torch . Tensor , top_count : int = 1 , reverse : bool = False
) - > List [ str ] :
if len ( self . labels ) < = self . chunk_size :
if len ( self . labels ) < = self . chunk_size :
tops = self . _rank ( image_features , self . embeds , top_count = top_count , reverse = reverse )
tops = self . _rank (
image_features , self . embeds , top_count = top_count , reverse = reverse
)
return [ self . labels [ i ] for i in tops ]
return [ self . labels [ i ] for i in tops ]
num_chunks = int ( math . ceil ( len ( self . labels ) / self . chunk_size ) )
num_chunks = int ( math . ceil ( len ( self . labels ) / self . chunk_size ) )
keep_per_chunk = int ( self . chunk_size / num_chunks )
keep_per_chunk = int ( self . chunk_size / num_chunks )
top_labels , top_embeds = [ ] , [ ]
top_labels , top_embeds = [ ] , [ ]
for chunk_idx in tqdm ( range ( num_chunks ) , disable = self . config . quiet ) :
for chunk_idx in tqdm ( range ( num_chunks ) , disable = self . config . quiet ) :
start = chunk_idx * self . chunk_size
start = chunk_idx * self . chunk_size
stop = min ( start + self . chunk_size , len ( self . embeds ) )
stop = min ( start + self . chunk_size , len ( self . embeds ) )
tops = self . _rank ( image_features , self . embeds [ start : stop ] , top_count = keep_per_chunk , reverse = reverse )
tops = self . _rank (
top_labels . extend ( [ self . labels [ start + i ] for i in tops ] )
image_features ,
top_embeds . extend ( [ self . embeds [ start + i ] for i in tops ] )
self . embeds [ start : stop ] ,
top_count = keep_per_chunk ,
reverse = reverse ,
)
top_labels . extend ( [ self . labels [ start + i ] for i in tops ] )
top_embeds . extend ( [ self . embeds [ start + i ] for i in tops ] )
tops = self . _rank ( image_features , top_embeds , top_count = top_count )
tops = self . _rank ( image_features , top_embeds , top_count = top_count )
return [ top_labels [ i ] for i in tops ]
return [ top_labels [ i ] for i in tops ]
def _download_file ( url : str , filepath : str , chunk_size : int = 64 * 1024 , quiet : bool = False ) :
def _download_file (
url : str , filepath : str , chunk_size : int = 64 * 1024 , quiet : bool = False
) :
r = requests . get ( url , stream = True )
r = requests . get ( url , stream = True )
file_size = int ( r . headers . get ( " Content-Length " , 0 ) )
file_size = int ( r . headers . get ( " Content-Length " , 0 ) )
filename = url . split ( " / " ) [ - 1 ]
filename = url . split ( " / " ) [ - 1 ]
progress = tqdm ( total = file_size , unit = " B " , unit_scale = True , desc = filename , disable = quiet )
progress = tqdm (
total = file_size , unit = " B " , unit_scale = True , desc = filename , disable = quiet
)
with open ( filepath , " wb " ) as f :
with open ( filepath , " wb " ) as f :
for chunk in r . iter_content ( chunk_size = chunk_size ) :
for chunk in r . iter_content ( chunk_size = chunk_size ) :
if chunk :
if chunk :
@ -398,11 +661,15 @@ def _download_file(url: str, filepath: str, chunk_size: int = 64*1024, quiet: bo
progress . update ( len ( chunk ) )
progress . update ( len ( chunk ) )
progress . close ( )
progress . close ( )
def _load_list ( data_path : str , filename : str ) - > List [ str ] :
def _load_list ( data_path : str , filename : str ) - > List [ str ] :
with open ( os . path . join ( data_path , filename ) , ' r ' , encoding = ' utf-8 ' , errors = ' replace ' ) as f :
with open (
os . path . join ( data_path , filename ) , " r " , encoding = " utf-8 " , errors = " replace "
) as f :
items = [ line . strip ( ) for line in f . readlines ( ) ]
items = [ line . strip ( ) for line in f . readlines ( ) ]
return items
return items
def _merge_tables ( tables : List [ LabelTable ] , config : Config ) - > LabelTable :
def _merge_tables ( tables : List [ LabelTable ] , config : Config ) - > LabelTable :
m = LabelTable ( [ ] , None , None , None , config )
m = LabelTable ( [ ] , None , None , None , config )
for table in tables :
for table in tables :
@ -410,15 +677,17 @@ def _merge_tables(tables: List[LabelTable], config: Config) -> LabelTable:
m . embeds . extend ( table . embeds )
m . embeds . extend ( table . embeds )
return m
return m
def _prompt_at_max_len ( text : str , tokenize ) - > bool :
def _prompt_at_max_len ( text : str , tokenize ) - > bool :
tokens = tokenize ( [ text ] )
tokens = tokenize ( [ text ] )
return tokens [ 0 ] [ - 1 ] != 0
return tokens [ 0 ] [ - 1 ] != 0
def _truncate_to_fit ( text : str , tokenize ) - > str :
def _truncate_to_fit ( text : str , tokenize ) - > str :
parts = text . split ( ' , ' )
parts = text . split ( " , " )
new_text = parts [ 0 ]
new_text = parts [ 0 ]
for part in parts [ 1 : ] :
for part in parts [ 1 : ] :
if _prompt_at_max_len ( new_text + part , tokenize ) :
if _prompt_at_max_len ( new_text + part , tokenize ) :
break
break
new_text + = ' , ' + part
new_text + = " , " + part
return new_text
return new_text