@ -18,24 +18,24 @@ from tqdm import tqdm
from typing import List
BLIP_MODELS = {
' base ' : ' https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth' ,
' large ' : ' https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
" base " : " https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth" ,
" large " : " https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth" ,
}
CACHE_URLS_VITL = [
' https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_artists.pkl' ,
' https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_flavors.pkl' ,
' https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_mediums.pkl' ,
' https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_movements.pkl' ,
' https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_trendings.pkl' ,
" https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_artists.pkl" ,
" https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_flavors.pkl" ,
" https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_mediums.pkl" ,
" https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_movements.pkl" ,
" https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_trendings.pkl" ,
]
CACHE_URLS_VITH = [
' https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_artists.pkl' ,
' https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_flavors.pkl' ,
' https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_mediums.pkl' ,
' https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_movements.pkl' ,
' https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_trendings.pkl' ,
" https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_artists.pkl" ,
" https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_flavors.pkl" ,
" https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_mediums.pkl" ,
" https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_movements.pkl" ,
" https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_trendings.pkl" ,
]
@ -49,40 +49,51 @@ class Config:
# blip settings
blip_image_eval_size : int = 384
blip_max_length : int = 32
blip_model_type : str = ' large ' # choose between 'base' or 'large'
blip_model_type : str = " large " # choose between 'base' or 'large'
blip_num_beams : int = 8
blip_offload : bool = False
# clip settings
clip_model_name : str = ' ViT-L-14/openai '
clip_model_name : str = " ViT-L-14/openai "
clip_model_path : str = None
# interrogator settings
cache_path : str = ' cache ' # path to store cached text embeddings
download_cache : bool = True # when true, cached embeds are downloaded from huggingface
chunk_size : int = 2048 # batch size for CLIP, use smaller for lower VRAM
data_path : str = os . path . join ( os . path . dirname ( __file__ ) , ' data ' )
device : str = ( " mps " if torch . backends . mps . is_available ( ) else " cuda " if torch . cuda . is_available ( ) else " cpu " )
cache_path : str = " cache " # path to store cached text embeddings
download_cache : bool = (
True # when true, cached embeds are downloaded from huggingface
)
chunk_size : int = 2048 # batch size for CLIP, use smaller for lower VRAM
data_path : str = os . path . join ( os . path . dirname ( __file__ ) , " data " )
device : str = (
" mps "
if torch . backends . mps . is_available ( )
else " cuda "
if torch . cuda . is_available ( )
else " cpu "
)
flavor_intermediate_count : int = 2048
quiet : bool = False # when quiet progress bars are not shown
quiet : bool = False # when quiet progress bars are not shown
class Interrogator ( ) :
class Interrogator :
def __init__ ( self , config : Config ) :
self . config = config
self . device = config . device
# Record which model is on the target device
self . blip_loaded = True
# Load BLIP model (to intended device)
if config . blip_model is None :
if not config . quiet :
print ( " Loading BLIP model... " )
blip_path = os . path . dirname ( inspect . getfile ( blip_decoder ) )
configs_path = os . path . join ( os . path . dirname ( blip_path ) , ' configs ' )
med_config = os . path . join ( configs_path , ' med_config.json ' )
configs_path = os . path . join ( os . path . dirname ( blip_path ) , " configs " )
med_config = os . path . join ( configs_path , " med_config.json " )
blip_model = blip_decoder (
pretrained = BLIP_MODELS [ config . blip_model_type ] ,
image_size = config . blip_image_eval_size ,
vit = config . blip_model_type ,
med_config = med_config
med_config = med_config ,
)
blip_model . eval ( )
blip_model = blip_model . to ( config . device )
@ -90,12 +101,13 @@ class Interrogator():
else :
self . blip_model = config . blip_model
# Load CLIP (to CPU)
self . load_clip_model ( )
def download_cache ( self , clip_model_name : str ) :
if clip_model_name == ' ViT-L-14/openai ' :
if clip_model_name == " ViT-L-14/openai " :
cache_urls = CACHE_URLS_VITL
elif clip_model_name == ' ViT-H-14/laion2b_s32b_b79k ' :
elif clip_model_name == " ViT-H-14/laion2b_s32b_b79k " :
cache_urls = CACHE_URLS_VITH
else :
# text embeddings will be precomputed and cached locally
@ -103,7 +115,7 @@ class Interrogator():
os . makedirs ( self . config . cache_path , exist_ok = True )
for url in cache_urls :
filepath = os . path . join ( self . config . cache_path , url . split ( ' / ' ) [ - 1 ] )
filepath = os . path . join ( self . config . cache_path , url . split ( " / " ) [ - 1 ] )
if not os . path . exists ( filepath ) :
_download_file ( url , filepath , quiet = self . config . quiet )
@ -115,40 +127,93 @@ class Interrogator():
if not config . quiet :
print ( " Loading CLIP model... " )
clip_model_name , clip_model_pretrained_name = config . clip_model_name . split ( ' / ' , 2 )
self . clip_model , _ , self . clip_preprocess = open_clip . create_model_and_transforms (
clip_model_name , clip_model_pretrained_name = config . clip_model_name . split (
" / " , 2
)
(
self . clip_model ,
_ ,
self . clip_preprocess ,
) = open_clip . create_model_and_transforms (
clip_model_name ,
pretrained = clip_model_pretrained_name ,
precision = ' fp16 ' if config . device == ' cuda ' else ' fp32 ' ,
device = config . device ,
precision = " fp16 " if config . device == " cuda " else " fp32 " ,
device = " cpu " ,
jit = False ,
cache_dir = config . clip_model_path
cache_dir = config . clip_model_path ,
)
self . clip_model . to ( config . device ) . eval ( )
self . clip_model . eval ( )
else :
self . clip_model = config . clip_model
self . clip_preprocess = config . clip_preprocess
self . tokenize = open_clip . get_tokenizer ( clip_model_name )
sites = [ ' Artstation ' , ' behance ' , ' cg society ' , ' cgsociety ' , ' deviantart ' , ' dribble ' , ' flickr ' , ' instagram ' , ' pexels ' , ' pinterest ' , ' pixabay ' , ' pixiv ' , ' polycount ' , ' reddit ' , ' shutterstock ' , ' tumblr ' , ' unsplash ' , ' zbrush central ' ]
sites = [
" Artstation " ,
" behance " ,
" cg society " ,
" cgsociety " ,
" deviantart " ,
" dribble " ,
" flickr " ,
" instagram " ,
" pexels " ,
" pinterest " ,
" pixabay " ,
" pixiv " ,
" polycount " ,
" reddit " ,
" shutterstock " ,
" tumblr " ,
" unsplash " ,
" zbrush central " ,
]
trending_list = [ site for site in sites ]
trending_list . extend ( [ " trending on " + site for site in sites ] )
trending_list . extend ( [ " featured on " + site for site in sites ] )
trending_list . extend ( [ site + " contest winner " for site in sites ] )
trending_list . extend ( [ " trending on " + site for site in sites ] )
trending_list . extend ( [ " featured on " + site for site in sites ] )
trending_list . extend ( [ site + " contest winner " for site in sites ] )
raw_artists = _load_list ( config . data_path , ' artists.txt ' )
raw_artists = _load_list ( config . data_path , " artists.txt " )
artists = [ f " by { a } " for a in raw_artists ]
artists . extend ( [ f " inspired by { a } " for a in raw_artists ] )
if config . download_cache :
self . download_cache ( config . clip_model_name )
self . artists = LabelTable ( artists , " artists " , self . clip_model , self . tokenize , config )
self . flavors = LabelTable ( _load_list ( config . data_path , ' flavors.txt ' ) , " flavors " , self . clip_model , self . tokenize , config )
self . mediums = LabelTable ( _load_list ( config . data_path , ' mediums.txt ' ) , " mediums " , self . clip_model , self . tokenize , config )
self . movements = LabelTable ( _load_list ( config . data_path , ' movements.txt ' ) , " movements " , self . clip_model , self . tokenize , config )
self . trendings = LabelTable ( trending_list , " trendings " , self . clip_model , self . tokenize , config )
self . negative = LabelTable ( _load_list ( config . data_path , ' negative.txt ' ) , " negative " , self . clip_model , self . tokenize , config )
self . artists = LabelTable (
artists , " artists " , self . clip_model , self . tokenize , config
)
self . flavors = LabelTable (
_load_list ( config . data_path , " flavors.txt " ) ,
" flavors " ,
self . clip_model ,
self . tokenize ,
config ,
)
self . mediums = LabelTable (
_load_list ( config . data_path , " mediums.txt " ) ,
" mediums " ,
self . clip_model ,
self . tokenize ,
config ,
)
self . movements = LabelTable (
_load_list ( config . data_path , " movements.txt " ) ,
" movements " ,
self . clip_model ,
self . tokenize ,
config ,
)
self . trendings = LabelTable (
trending_list , " trendings " , self . clip_model , self . tokenize , config
)
self . negative = LabelTable (
_load_list ( config . data_path , " negative.txt " ) ,
" negative " ,
self . clip_model ,
self . tokenize ,
config ,
)
end_time = time . time ( )
if not config . quiet :
@ -158,16 +223,18 @@ class Interrogator():
self ,
image_features : torch . Tensor ,
phrases : List [ str ] ,
best_prompt : str = " " ,
best_sim : float = 0 ,
min_count : int = 8 ,
max_count : int = 32 ,
best_prompt : str = " " ,
best_sim : float = 0 ,
min_count : int = 8 ,
max_count : int = 32 ,
desc = " Chaining " ,
reverse : bool = False
reverse : bool = False ,
) - > str :
phrases = set ( phrases )
if not best_prompt :
best_prompt = self . rank_top ( image_features , [ f for f in phrases ] , reverse = reverse )
best_prompt = self . rank_top (
image_features , [ f for f in phrases ] , reverse = reverse
)
best_sim = self . similarity ( image_features , best_prompt )
phrases . remove ( best_prompt )
curr_prompt , curr_sim = best_prompt , best_sim
@ -187,8 +254,12 @@ class Interrogator():
return False
for idx in tqdm ( range ( max_count ) , desc = desc , disable = self . config . quiet ) :
best = self . rank_top ( image_features , [ f " { curr_prompt } , { f } " for f in phrases ] , reverse = reverse )
flave = best [ len ( curr_prompt ) + 2 : ]
best = self . rank_top (
image_features ,
[ f " { curr_prompt } , { f } " for f in phrases ] ,
reverse = reverse ,
)
flave = best [ len ( curr_prompt ) + 2 : ]
if not check ( flave , idx ) :
break
if _prompt_at_max_len ( curr_prompt , self . tokenize ) :
@ -201,11 +272,22 @@ class Interrogator():
if self . config . blip_offload :
self . blip_model = self . blip_model . to ( self . device )
size = self . config . blip_image_eval_size
gpu_image = transforms . Compose ( [
transforms . Resize ( ( size , size ) , interpolation = InterpolationMode . BICUBIC ) ,
transforms . ToTensor ( ) ,
transforms . Normalize ( ( 0.48145466 , 0.4578275 , 0.40821073 ) , ( 0.26862954 , 0.26130258 , 0.27577711 ) )
] ) ( pil_image ) . unsqueeze ( 0 ) . to ( self . device )
gpu_image = (
transforms . Compose (
[
transforms . Resize (
( size , size ) , interpolation = InterpolationMode . BICUBIC
) ,
transforms . ToTensor ( ) ,
transforms . Normalize (
( 0.48145466 , 0.4578275 , 0.40821073 ) ,
( 0.26862954 , 0.26130258 , 0.27577711 ) ,
) ,
]
) ( pil_image )
. unsqueeze ( 0 )
. to ( self . device )
)
with torch . no_grad ( ) :
caption = self . blip_model . generate (
@ -213,7 +295,7 @@ class Interrogator():
sample = False ,
num_beams = self . config . blip_num_beams ,
max_length = self . config . blip_max_length ,
min_length = 5
min_length = 5 ,
)
if self . config . blip_offload :
self . blip_model = self . blip_model . to ( " cpu " )
@ -226,17 +308,65 @@ class Interrogator():
image_features / = image_features . norm ( dim = - 1 , keepdim = True )
return image_features
def interrogate_classic ( self , image : Image , max_flavors : int = 3 ) - > str :
""" Classic mode creates a prompt in a standard format first describing the image,
then listing the artist , trending , movement , and flavor text modifiers . """
caption = self . generate_caption ( image )
image_features = self . image_to_features ( image )
def _first_bit ( self , image : Image ) - > ( str , torch . Tensor ) :
if self . blip_loaded :
caption = self . generate_caption ( image )
# Move BLIP to RAM
self . blip_model . to ( " cpu " )
# Move CLIP to intended device
self . clip_model . to ( self . device )
image_features = self . image_to_features ( image )
else : # CLIP is loaded
image_features = self . image_to_features ( image )
# Move CLIP to RAM
self . clip_model . to ( " cpu " )
# Move BLIP to intended device
self . blip_model . to ( self . device )
caption = self . generate_caption ( image )
# Toggle `blip_loaded`
self . blip_loaded ^ = True
return caption , image_features
def _first_bit_batch ( self , images : list [ Image ] ) - > ( list [ str ] , list [ torch . Tensor ] ) :
image_features : list [ torch . Tensor ] = [ ]
if self . blip_loaded :
captions = [ self . generate_caption ( img ) for img in images ]
# Move BLIP to RAM
self . blip_model . to ( " cpu " )
# Move CLIP to intended device
self . clip_model . to ( self . device )
image_features = [ self . image_to_features ( img ) for img in images ]
else : # CLIP is loaded
image_features = [ self . image_to_features ( img ) for img in images ]
# Move CLIP to RAM
self . clip_model . to ( " cpu " )
# Move BLIP to intended device
self . blip_model . to ( self . device )
captions = [ self . generate_caption ( img ) for img in images ]
# Toggle `blip_loaded`
self . blip_loaded ^ = True
return captions , image_features
def _interrogate_classic (
self , caption : str , image_features : torch . Tensor , max_flavours : int = 3
) - > str :
medium = self . mediums . rank ( image_features , 1 ) [ 0 ]
artist = self . artists . rank ( image_features , 1 ) [ 0 ]
trending = self . trendings . rank ( image_features , 1 ) [ 0 ]
movement = self . movements . rank ( image_features , 1 ) [ 0 ]
flaves = " , " . join ( self . flavors . rank ( image_features , max_flavors ) )
flaves = " , " . join ( self . flavors . rank ( image_features , max_flavou rs ) )
if caption . startswith ( medium ) :
prompt = f " { caption } { artist } , { trending } , { movement } , { flaves } "
@ -245,41 +375,138 @@ class Interrogator():
return _truncate_to_fit ( prompt , self . tokenize )
def interrogate_classic ( self , image : Image , max_flavors : int = 3 ) - > str :
""" Classic mode creates a prompt in a standard format first describing the image,
then listing the artist , trending , movement , and flavor text modifiers . """
caption , image_features = self . _first_bit ( image )
return self . _interrogate_classic ( caption , image_features , max_flavors )
def interrogate_classic_batch (
self , images : list [ Image ] , max_flavors : int = 3
) - > list [ str ] :
""" Classic mode creates a prompt in a standard format first describing the image,
then listing the artist , trending , movement , and flavor text modifiers .
This function interrogates a batch of images ( more efficient than doing
it individually ) . """
captions , image_features = self . _first_bit_batch ( images )
returns : list [ str ] = [
self . _interrogate_classic ( caption , feature , max_flavors )
for caption , feature in zip ( captions , image_features )
]
return returns
def _interrogate_fast (
self , caption : str , image_features : torch . Tensor , max_flavours : int = 32
) - > str :
merged = _merge_tables (
[ self . artists , self . flavors , self . mediums , self . movements , self . trendings ] ,
self . config ,
)
tops = merged . rank ( image_features , max_flavours )
return _truncate_to_fit ( caption + " , " + " , " . join ( tops ) , self . tokenize )
def interrogate_fast ( self , image : Image , max_flavors : int = 32 ) - > str :
""" Fast mode simply adds the top ranked terms after a caption. It generally results in
better similarity between generated prompt and image than classic mode , but the prompts
are less readable . """
caption = self . generate_caption ( image )
image_features = self . image_to_features ( image )
merged = _merge_tables ( [ self . artists , self . flavors , self . mediums , self . movements , self . trendings ] , self . config )
tops = merged . rank ( image_features , max_flavors )
return _truncate_to_fit ( caption + " , " + " , " . join ( tops ) , self . tokenize )
caption , image_features = self . _first_bit ( image )
return self . _interrogate_fast ( caption , image_features , max_flavors )
def interrogate_fast_batch ( self , images : list [ Image ] , max_flavors : int = 32 ) - > str :
""" Fast mode simply adds the top ranked terms after a caption. It generally results in
better similarity between generated prompt and image than classic mode , but the prompts
are less readable .
This function interrogates a batch of images ( more efficient than doing
it individually ) . """
captions , image_features = self . _first_bit_batch ( images )
returns : list [ str ] = [
self . _interrogate_fast ( caption , feature , max_flavors )
for caption , feature in zip ( captions , image_features )
]
return returns
def interrogate_negative ( self , image : Image , max_flavors : int = 32 ) - > str :
""" Negative mode chains together the most dissimilar terms to the image. It can be used
to help build a negative prompt to pair with the regular positive prompt and often
improve the results of generated images particularly with Stable Diffusion 2. """
image_features = self . image_to_features ( image )
flaves = self . flavors . rank ( image_features , self . config . flavor_intermediate_count , reverse = True )
flaves = flaves + self . negative . labels
return self . chain ( image_features , flaves , max_count = max_flavors , reverse = True , desc = " Negative chain " )
if self . blip_loaded : # Move CLIP to intended device
self . blip_model . to ( " cpu " )
self . cli_model . to ( self . device )
def interrogate ( self , image : Image , min_flavors : int = 8 , max_flavors : int = 32 ) - > str :
caption = self . generate_caption ( image )
image_features = self . image_to_features ( image )
merged = _merge_tables ( [ self . artists , self . flavors , self . mediums , self . movements , self . trendings ] , self . config )
flaves = self . flavors . rank (
image_features , self . config . flavor_intermediate_count , reverse = True
)
flaves = flaves + self . negative . labels
return self . chain (
image_features ,
flaves ,
max_count = max_flavors ,
reverse = True ,
desc = " Negative chain " ,
)
def _interrogate (
self ,
caption : str ,
image_features : torch . Tensor ,
min_flavours : int = 8 ,
max_flavours : int = 32 ,
) - > str :
merged = _merge_tables (
[ self . artists , self . flavors , self . mediums , self . movements , self . trendings ] ,
self . config ,
)
flaves = merged . rank ( image_features , self . config . flavor_intermediate_count )
best_prompt , best_sim = caption , self . similarity ( image_features , caption )
best_prompt = self . chain ( image_features , flaves , best_prompt , best_sim , min_count = min_flavors , max_count = max_flavors , desc = " Flavor chain " )
fast_prompt = self . interrogate_fast ( image , max_flavors )
classic_prompt = self . interrogate_classic ( image , max_flavors )
best_prompt = self . chain (
image_features ,
flaves ,
best_prompt ,
best_sim ,
min_count = min_flavours ,
max_count = max_flavours ,
desc = " Flavor chain " ,
)
fast_prompt = self . _interrogate_fast ( caption , image_features , max_flavours )
classic_prompt = self . interrogate_classic ( caption , image_features , max_flavours )
candidates = [ caption , classic_prompt , fast_prompt , best_prompt ]
return candidates [ np . argmax ( self . similarities ( image_features , candidates ) ) ]
def rank_top ( self , image_features : torch . Tensor , text_array : List [ str ] , reverse : bool = False ) - > str :
def interrogate (
self , image : Image , min_flavors : int = 8 , max_flavors : int = 32
) - > str :
caption , image_features = self . _first_bit ( image )
return self . _interrogate ( caption , image_features , min_flavors , max_flavors )
def interrogate_batch (
self , images : list [ Image ] , min_flavors : int = 8 , max_flavors : int = 32
) - > list [ str ] :
""" This function interrogates a batch of images (more efficient than doing
it individually ) . """
captions , image_features = self . _first_bit_batch ( images )
returns : list [ str ] = [
self . _interrogate ( caption , features , min_flavors , max_flavors )
for caption , features in zip ( captions , image_features )
]
return returns
def rank_top (
self , image_features : torch . Tensor , text_array : List [ str ] , reverse : bool = False
) - > str :
text_tokens = self . tokenize ( [ text for text in text_array ] ) . to ( self . device )
with torch . no_grad ( ) , torch . cuda . amp . autocast ( ) :
text_features = self . clip_model . encode_text ( text_tokens )
@ -297,7 +524,9 @@ class Interrogator():
similarity = text_features @ image_features . T
return similarity [ 0 ] [ 0 ] . item ( )
def similarities ( self , image_features : torch . Tensor , text_array : List [ str ] ) - > List [ float ] :
def similarities (
self , image_features : torch . Tensor , text_array : List [ str ]
) - > List [ float ] :
text_tokens = self . tokenize ( [ text for text in text_array ] ) . to ( self . device )
with torch . no_grad ( ) , torch . cuda . amp . autocast ( ) :
text_features = self . clip_model . encode_text ( text_tokens )
@ -306,8 +535,10 @@ class Interrogator():
return similarity . T [ 0 ] . tolist ( )
class LabelTable ( ) :
def __init__ ( self , labels : List [ str ] , desc : str , clip_model , tokenize , config : Config ) :
class LabelTable :
def __init__ (
self , labels : List [ str ] , desc : str , clip_model , tokenize , config : Config
) :
self . chunk_size = config . chunk_size
self . config = config
self . device = config . device
@ -320,22 +551,30 @@ class LabelTable():
cache_filepath = None
if config . cache_path is not None and desc is not None :
os . makedirs ( config . cache_path , exist_ok = True )
sanitized_name = config . clip_model_name . replace ( ' / ' , ' _ ' ) . replace ( ' @ ' , ' _ ' )
cache_filepath = os . path . join ( config . cache_path , f " { sanitized_name } _ { desc } .pkl " )
sanitized_name = config . clip_model_name . replace ( " / " , " _ " ) . replace ( " @ " , " _ " )
cache_filepath = os . path . join (
config . cache_path , f " { sanitized_name } _ { desc } .pkl "
)
if desc is not None and os . path . exists ( cache_filepath ) :
with open ( cache_filepath , ' rb ' ) as f :
with open ( cache_filepath , " rb " ) as f :
try :
data = pickle . load ( f )
if data . get ( ' hash ' ) == hash :
self . labels = data [ ' labels ' ]
self . embeds = data [ ' embeds ' ]
if data . get ( " hash " ) == hash :
self . labels = data [ " labels " ]
self . embeds = data [ " embeds " ]
except Exception as e :
print ( f " Error loading cached table { desc } : { e } " )
if len ( self . labels ) != len ( self . embeds ) :
self . embeds = [ ]
chunks = np . array_split ( self . labels , max ( 1 , len ( self . labels ) / config . chunk_size ) )
for chunk in tqdm ( chunks , desc = f " Preprocessing { desc } " if desc else None , disable = self . config . quiet ) :
chunks = np . array_split (
self . labels , max ( 1 , len ( self . labels ) / config . chunk_size )
)
for chunk in tqdm (
chunks ,
desc = f " Preprocessing { desc } " if desc else None ,
disable = self . config . quiet ,
) :
text_tokens = self . tokenize ( chunk ) . to ( self . device )
with torch . no_grad ( ) , torch . cuda . amp . autocast ( ) :
text_features = clip_model . encode_text ( text_tokens )
@ -345,20 +584,31 @@ class LabelTable():
self . embeds . append ( text_features [ i ] )
if cache_filepath is not None :
with open ( cache_filepath , ' wb ' ) as f :
pickle . dump ( {
" labels " : self . labels ,
" embeds " : self . embeds ,
" hash " : hash ,
" model " : config . clip_model_name
} , f )
if self . device == ' cpu ' or self . device == torch . device ( ' cpu ' ) :
with open ( cache_filepath , " wb " ) as f :
pickle . dump (
{
" labels " : self . labels ,
" embeds " : self . embeds ,
" hash " : hash ,
" model " : config . clip_model_name ,
} ,
f ,
)
if self . device == " cpu " or self . device == torch . device ( " cpu " ) :
self . embeds = [ e . astype ( np . float32 ) for e in self . embeds ]
def _rank ( self , image_features : torch . Tensor , text_embeds : torch . Tensor , top_count : int = 1 , reverse : bool = False ) - > str :
def _rank (
self ,
image_features : torch . Tensor ,
text_embeds : torch . Tensor ,
top_count : int = 1 ,
reverse : bool = False ,
) - > str :
top_count = min ( top_count , len ( text_embeds ) )
text_embeds = torch . stack ( [ torch . from_numpy ( t ) for t in text_embeds ] ) . to ( self . device )
text_embeds = torch . stack ( [ torch . from_numpy ( t ) for t in text_embeds ] ) . to (
self . device
)
with torch . cuda . amp . autocast ( ) :
similarity = image_features @ text_embeds . T
if reverse :
@ -366,31 +616,44 @@ class LabelTable():
_ , top_labels = similarity . float ( ) . cpu ( ) . topk ( top_count , dim = - 1 )
return [ top_labels [ 0 ] [ i ] . numpy ( ) for i in range ( top_count ) ]
def rank ( self , image_features : torch . Tensor , top_count : int = 1 , reverse : bool = False ) - > List [ str ] :
def rank (
self , image_features : torch . Tensor , top_count : int = 1 , reverse : bool = False
) - > List [ str ] :
if len ( self . labels ) < = self . chunk_size :
tops = self . _rank ( image_features , self . embeds , top_count = top_count , reverse = reverse )
tops = self . _rank (
image_features , self . embeds , top_count = top_count , reverse = reverse
)
return [ self . labels [ i ] for i in tops ]
num_chunks = int ( math . ceil ( len ( self . labels ) / self . chunk_size ) )
num_chunks = int ( math . ceil ( len ( self . labels ) / self . chunk_size ) )
keep_per_chunk = int ( self . chunk_size / num_chunks )
top_labels , top_embeds = [ ] , [ ]
for chunk_idx in tqdm ( range ( num_chunks ) , disable = self . config . quiet ) :
start = chunk_idx * self . chunk_size
stop = min ( start + self . chunk_size , len ( self . embeds ) )
tops = self . _rank ( image_features , self . embeds [ start : stop ] , top_count = keep_per_chunk , reverse = reverse )
top_labels . extend ( [ self . labels [ start + i ] for i in tops ] )
top_embeds . extend ( [ self . embeds [ start + i ] for i in tops ] )
start = chunk_idx * self . chunk_size
stop = min ( start + self . chunk_size , len ( self . embeds ) )
tops = self . _rank (
image_features ,
self . embeds [ start : stop ] ,
top_count = keep_per_chunk ,
reverse = reverse ,
)
top_labels . extend ( [ self . labels [ start + i ] for i in tops ] )
top_embeds . extend ( [ self . embeds [ start + i ] for i in tops ] )
tops = self . _rank ( image_features , top_embeds , top_count = top_count )
return [ top_labels [ i ] for i in tops ]
def _download_file ( url : str , filepath : str , chunk_size : int = 64 * 1024 , quiet : bool = False ) :
def _download_file (
url : str , filepath : str , chunk_size : int = 64 * 1024 , quiet : bool = False
) :
r = requests . get ( url , stream = True )
file_size = int ( r . headers . get ( " Content-Length " , 0 ) )
filename = url . split ( " / " ) [ - 1 ]
progress = tqdm ( total = file_size , unit = " B " , unit_scale = True , desc = filename , disable = quiet )
progress = tqdm (
total = file_size , unit = " B " , unit_scale = True , desc = filename , disable = quiet
)
with open ( filepath , " wb " ) as f :
for chunk in r . iter_content ( chunk_size = chunk_size ) :
if chunk :
@ -398,11 +661,15 @@ def _download_file(url: str, filepath: str, chunk_size: int = 64*1024, quiet: bo
progress . update ( len ( chunk ) )
progress . close ( )
def _load_list ( data_path : str , filename : str ) - > List [ str ] :
with open ( os . path . join ( data_path , filename ) , ' r ' , encoding = ' utf-8 ' , errors = ' replace ' ) as f :
with open (
os . path . join ( data_path , filename ) , " r " , encoding = " utf-8 " , errors = " replace "
) as f :
items = [ line . strip ( ) for line in f . readlines ( ) ]
return items
def _merge_tables ( tables : List [ LabelTable ] , config : Config ) - > LabelTable :
m = LabelTable ( [ ] , None , None , None , config )
for table in tables :
@ -410,15 +677,17 @@ def _merge_tables(tables: List[LabelTable], config: Config) -> LabelTable:
m . embeds . extend ( table . embeds )
return m
def _prompt_at_max_len ( text : str , tokenize ) - > bool :
tokens = tokenize ( [ text ] )
return tokens [ 0 ] [ - 1 ] != 0
def _truncate_to_fit ( text : str , tokenize ) - > str :
parts = text . split ( ' , ' )
parts = text . split ( " , " )
new_text = parts [ 0 ]
for part in parts [ 1 : ] :
if _prompt_at_max_len ( new_text + part , tokenize ) :
break
new_text + = ' , ' + part
new_text + = " , " + part
return new_text