@ -37,7 +37,8 @@ class Config:
chunk_size : int = 2048
chunk_size : int = 2048
data_path : str = os . path . join ( os . path . dirname ( __file__ ) , ' data ' )
data_path : str = os . path . join ( os . path . dirname ( __file__ ) , ' data ' )
device : str = ' cuda ' if torch . cuda . is_available ( ) else ' cpu '
device : str = ' cuda ' if torch . cuda . is_available ( ) else ' cpu '
flavor_intermediate_count : int = 2048
flavor_intermediate_count : int = 2048
quiet : bool = False # when quiet progress bars are not shown
class Interrogator ( ) :
class Interrogator ( ) :
@ -46,7 +47,8 @@ class Interrogator():
self . device = config . device
self . device = config . device
if config . blip_model is None :
if config . blip_model is None :
print ( " Loading BLIP model... " )
if not config . quiet :
print ( " Loading BLIP model... " )
blip_path = os . path . dirname ( inspect . getfile ( blip_decoder ) )
blip_path = os . path . dirname ( inspect . getfile ( blip_decoder ) )
configs_path = os . path . join ( os . path . dirname ( blip_path ) , ' configs ' )
configs_path = os . path . join ( os . path . dirname ( blip_path ) , ' configs ' )
med_config = os . path . join ( configs_path , ' med_config.json ' )
med_config = os . path . join ( configs_path , ' med_config.json ' )
@ -63,7 +65,8 @@ class Interrogator():
self . blip_model = config . blip_model
self . blip_model = config . blip_model
if config . clip_model is None :
if config . clip_model is None :
print ( " Loading CLIP model... " )
if not config . quiet :
print ( " Loading CLIP model... " )
self . clip_model , self . clip_preprocess = clip . load ( config . clip_model_name , device = config . device )
self . clip_model , self . clip_preprocess = clip . load ( config . clip_model_name , device = config . device )
self . clip_model . to ( config . device ) . eval ( )
self . clip_model . to ( config . device ) . eval ( )
else :
else :
@ -111,7 +114,7 @@ class Interrogator():
image_features / = image_features . norm ( dim = - 1 , keepdim = True )
image_features / = image_features . norm ( dim = - 1 , keepdim = True )
return image_features
return image_features
def interrogate_classic ( self , image : Image , max_flave s : int = 3 ) - > str :
def interrogate_classic ( self , image : Image , max_flavor s : int = 3 ) - > str :
caption = self . generate_caption ( image )
caption = self . generate_caption ( image )
image_features = self . image_to_features ( image )
image_features = self . image_to_features ( image )
@ -119,7 +122,7 @@ class Interrogator():
artist = self . artists . rank ( image_features , 1 ) [ 0 ]
artist = self . artists . rank ( image_features , 1 ) [ 0 ]
trending = self . trendings . rank ( image_features , 1 ) [ 0 ]
trending = self . trendings . rank ( image_features , 1 ) [ 0 ]
movement = self . movements . rank ( image_features , 1 ) [ 0 ]
movement = self . movements . rank ( image_features , 1 ) [ 0 ]
flaves = " , " . join ( self . flavors . rank ( image_features , max_flave s ) )
flaves = " , " . join ( self . flavors . rank ( image_features , max_flavor s ) )
if caption . startswith ( medium ) :
if caption . startswith ( medium ) :
prompt = f " { caption } { artist } , { trending } , { movement } , { flaves } "
prompt = f " { caption } { artist } , { trending } , { movement } , { flaves } "
@ -128,14 +131,14 @@ class Interrogator():
return _truncate_to_fit ( prompt )
return _truncate_to_fit ( prompt )
def interrogate_fast ( self , image : Image ) - > str :
def interrogate_fast ( self , image : Image , max_flavors : int = 32 ) - > str :
caption = self . generate_caption ( image )
caption = self . generate_caption ( image )
image_features = self . image_to_features ( image )
image_features = self . image_to_features ( image )
merged = _merge_tables ( [ self . artists , self . flavors , self . mediums , self . movements , self . trendings ] , self . config )
merged = _merge_tables ( [ self . artists , self . flavors , self . mediums , self . movements , self . trendings ] , self . config )
tops = merged . rank ( image_features , 32 )
tops = merged . rank ( image_features , max_flavors )
return _truncate_to_fit ( caption + " , " + " , " . join ( tops ) )
return _truncate_to_fit ( caption + " , " + " , " . join ( tops ) )
def interrogate ( self , image : Image ) - > str :
def interrogate ( self , image : Image , max_flavors : int = 32 ) - > str :
caption = self . generate_caption ( image )
caption = self . generate_caption ( image )
image_features = self . image_to_features ( image )
image_features = self . image_to_features ( image )
@ -175,7 +178,7 @@ class Interrogator():
check_multi_batch ( [ best_medium , best_artist , best_trending , best_movement ] )
check_multi_batch ( [ best_medium , best_artist , best_trending , best_movement ] )
extended_flavors = set ( flaves )
extended_flavors = set ( flaves )
for _ in tqdm ( range ( 25 ) , desc = " Flavor chain " ) :
for _ in tqdm ( range ( max_flavors ) , desc = " Flavor chain " , disable = self . config . quiet ) :
try :
try :
best = self . rank_top ( image_features , [ f " { best_prompt } , { f } " for f in extended_flavors ] )
best = self . rank_top ( image_features , [ f " { best_prompt } , { f } " for f in extended_flavors ] )
flave = best [ len ( best_prompt ) + 2 : ]
flave = best [ len ( best_prompt ) + 2 : ]
@ -213,9 +216,10 @@ class Interrogator():
class LabelTable ( ) :
class LabelTable ( ) :
def __init__ ( self , labels : List [ str ] , desc : str , clip_model , config : Config ) :
def __init__ ( self , labels : List [ str ] , desc : str , clip_model , config : Config ) :
self . chunk_size = config . chunk_size
self . chunk_size = config . chunk_size
self . config = config
self . device = config . device
self . device = config . device
self . labels = labels
self . embeds = [ ]
self . embeds = [ ]
self . labels = labels
hash = hashlib . sha256 ( " , " . join ( labels ) . encode ( ) ) . hexdigest ( )
hash = hashlib . sha256 ( " , " . join ( labels ) . encode ( ) ) . hexdigest ( )
@ -234,7 +238,7 @@ class LabelTable():
if len ( self . labels ) != len ( self . embeds ) :
if len ( self . labels ) != len ( self . embeds ) :
self . embeds = [ ]
self . embeds = [ ]
chunks = np . array_split ( self . labels , max ( 1 , len ( self . labels ) / config . chunk_size ) )
chunks = np . array_split ( self . labels , max ( 1 , len ( self . labels ) / config . chunk_size ) )
for chunk in tqdm ( chunks , desc = f " Preprocessing { desc } " if desc else None ) :
for chunk in tqdm ( chunks , desc = f " Preprocessing { desc } " if desc else None , disable = self . config . quiet ) :
text_tokens = clip . tokenize ( chunk ) . to ( self . device )
text_tokens = clip . tokenize ( chunk ) . to ( self . device )
with torch . no_grad ( ) :
with torch . no_grad ( ) :
text_features = clip_model . encode_text ( text_tokens ) . float ( )
text_features = clip_model . encode_text ( text_tokens ) . float ( )
@ -270,7 +274,7 @@ class LabelTable():
keep_per_chunk = int ( self . chunk_size / num_chunks )
keep_per_chunk = int ( self . chunk_size / num_chunks )
top_labels , top_embeds = [ ] , [ ]
top_labels , top_embeds = [ ] , [ ]
for chunk_idx in tqdm ( range ( num_chunks ) ) :
for chunk_idx in tqdm ( range ( num_chunks ) , disable = self . config . quiet ) :
start = chunk_idx * self . chunk_size
start = chunk_idx * self . chunk_size
stop = min ( start + self . chunk_size , len ( self . embeds ) )
stop = min ( start + self . chunk_size , len ( self . embeds ) )
tops = self . _rank ( image_features , self . embeds [ start : stop ] , top_count = keep_per_chunk )
tops = self . _rank ( image_features , self . embeds [ start : stop ] , top_count = keep_per_chunk )