@@ -107,11 +107,51 @@ class Interrogator():
         self.mediums = LabelTable(_load_list(config.data_path, 'mediums.txt'), "mediums", self.clip_model, self.tokenize, config)
         self.movements = LabelTable(_load_list(config.data_path, 'movements.txt'), "movements", self.clip_model, self.tokenize, config)
         self.trendings = LabelTable(trending_list, "trendings", self.clip_model, self.tokenize, config)
+        self.negative = LabelTable(_load_list(config.data_path, 'negative.txt'), "negative", self.clip_model, self.tokenize, config)
 
         end_time = time.time()
         if not config.quiet:
             print(f"Loaded CLIP model and data in {end_time-start_time:.2f} seconds.")
 
+    def chain(
+        self,
+        image_features: torch.Tensor,
+        phrases: List[str],
+        best_prompt: str = "",
+        best_sim: float = 0,
+        max_count: int = 32,
+        desc="Chaining",
+        reverse: bool = False
+    ) -> str:
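+        # greedily extend best_prompt one phrase at a time, keeping an addition
+        # only while it improves CLIP similarity (or worsens it when reverse=True)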
+        phrases = set(phrases)
+
+        if not best_prompt:
+            best_prompt = self.rank_top(image_features, [f for f in phrases], reverse=reverse)
+            best_sim = self.similarity(image_features, best_prompt)
+            if reverse:
+                best_sim = -best_sim  # keep the seed score on the same negated scale check() compares against
+            phrases.remove(best_prompt)
+
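+        # closure: accept the addition only if it beats the best score so far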
+        def check(addition: str) -> bool:
+            nonlocal best_prompt, best_sim
+            prompt = best_prompt + ", " + addition
+            sim = self.similarity(image_features, prompt)
+            if reverse:
+                sim = -sim
+            if sim > best_sim:
+                best_sim = sim
+                best_prompt = prompt
+                return True
+            return False
+
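+        # each pass batch-ranks every candidate "best_prompt, phrase" extension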
+        for _ in tqdm(range(max_count), desc=desc, disable=self.config.quiet):
+            best = self.rank_top(image_features, [f"{best_prompt}, {f}" for f in phrases], reverse=reverse)
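+            # strip the leading "best_prompt, " (len+2 skips the comma and space) to recover the bare phrase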
+            flave = best[len(best_prompt)+2:]
+            if not check(flave):
+                break
+            if _prompt_at_max_len(best_prompt, self.tokenize):
+                break
+            phrases.remove(flave)
+
+        return best_prompt
+
     def generate_caption(self, pil_image: Image) -> str:
         if self.config.blip_offload:
             self.blip_model = self.blip_model.to(self.device)
@@ -204,24 +244,16 @@ class Interrogator():
         check_multi_batch([best_medium, best_artist, best_trending, best_movement])
 
-        extended_flavors = set(flaves)
-        for _ in tqdm(range(max_flavors), desc="Flavor chain", disable=self.config.quiet):
-            best = self.rank_top(image_features, [f"{best_prompt}, {f}" for f in extended_flavors])
-            flave = best[len(best_prompt)+2:]
-            if not check(flave):
-                break
-            if _prompt_at_max_len(best_prompt, self.tokenize):
-                break
-            extended_flavors.remove(flave)
-        return best_prompt
+        return self.chain(image_features, flaves, best_prompt, best_sim, max_count=max_flavors, desc="Flavor chain")
 
-    def rank_top(self, image_features: torch.Tensor, text_array: List[str]) -> str:
+    def rank_top(self, image_features: torch.Tensor, text_array: List[str], reverse: bool = False) -> str:
         text_tokens = self.tokenize([text for text in text_array]).to(self.device)
         with torch.no_grad(), torch.cuda.amp.autocast():
             text_features = self.clip_model.encode_text(text_tokens)
             text_features /= text_features.norm(dim=-1, keepdim=True)
             similarity = text_features @ image_features.T
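+            # flipping the sign makes the argmax below pick the least similar text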
+            if reverse:
+                similarity = -similarity
         return text_array[similarity.argmax().item()]
 
     def similarity(self, image_features: torch.Tensor, text: str) -> float:
@@ -283,17 +315,19 @@ class LabelTable():
         if self.device == 'cpu' or self.device == torch.device('cpu'):
             self.embeds = [e.astype(np.float32) for e in self.embeds]
 
-    def _rank(self, image_features: torch.Tensor, text_embeds: torch.Tensor, top_count: int = 1) -> str:
+    def _rank(self, image_features: torch.Tensor, text_embeds: torch.Tensor, top_count: int = 1, reverse: bool = False) -> str:
         top_count = min(top_count, len(text_embeds))
         text_embeds = torch.stack([torch.from_numpy(t) for t in text_embeds]).to(self.device)
         with torch.cuda.amp.autocast():
             similarity = image_features @ text_embeds.T
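+            # negated scores make topk surface the least similar labels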
+            if reverse:
+                similarity = -similarity
         _, top_labels = similarity.float().cpu().topk(top_count, dim=-1)
         return [top_labels[0][i].numpy() for i in range(top_count)]
 
-    def rank(self, image_features: torch.Tensor, top_count: int = 1) -> List[str]:
+    def rank(self, image_features: torch.Tensor, top_count: int = 1, reverse: bool = False) -> List[str]:
         if len(self.labels) <= self.chunk_size:
-            tops = self._rank(image_features, self.embeds, top_count=top_count)
+            tops = self._rank(image_features, self.embeds, top_count=top_count, reverse=reverse)
             return [self.labels[i] for i in tops]
 
         num_chunks = int(math.ceil(len(self.labels) / self.chunk_size))
@@ -303,7 +337,7 @@ class LabelTable():
         for chunk_idx in tqdm(range(num_chunks), disable=self.config.quiet):
             start = chunk_idx * self.chunk_size
             stop = min(start + self.chunk_size, len(self.embeds))
-            tops = self._rank(image_features, self.embeds[start:stop], top_count=keep_per_chunk)
+            tops = self._rank(image_features, self.embeds[start:stop], top_count=keep_per_chunk, reverse=reverse)
             top_labels.extend([self.labels[start+i] for i in tops])
             top_embeds.extend([self.embeds[start+i] for i in tops])