@@ -5,6 +5,7 @@ import numpy as np
 import open_clip
 import os
 import pickle
+import requests
 import time
 import torch
 
@@ -21,6 +22,23 @@ BLIP_MODELS = {
     'large': 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
 }
+
+CACHE_URLS_VITL = [
+    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_artists.pkl',
+    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_flavors.pkl',
+    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_mediums.pkl',
+    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_movements.pkl',
+    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_trendings.pkl',
+]
+
+CACHE_URLS_VITH = [
+    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_artists.pkl',
+    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_flavors.pkl',
+    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_mediums.pkl',
+    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_movements.pkl',
+    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_trendings.pkl',
+]
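+
+# Note: the cached filenames follow <clip model>_<pretrained tag>_<category>.pkl; they are
+# expected to match the names LabelTable derives for its locally computed cache, so a
+# downloaded file is picked up exactly as if it had been built on this machine.
+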
 @dataclass
 class Config:
     # models can optionally be passed in directly
@@ -40,13 +58,15 @@ class Config:
     clip_model_path: str = None
 
     # interrogator settings
-    cache_path: str = 'cache'
-    chunk_size: int = 2048
+    cache_path: str = 'cache'   # path to store cached text embeddings
+    download_cache: bool = True # when true, cached embeds are downloaded from huggingface
+    chunk_size: int = 2048      # batch size for CLIP, use smaller for lower VRAM
     data_path: str = os.path.join(os.path.dirname(__file__), 'data')
     device: str = ("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
     flavor_intermediate_count: int = 2048
     quiet: bool = False # when true, progress bars are not shown
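+
+    # Minimal usage sketch (hypothetical values; the fields above are the real knobs):
+    #   config = Config(clip_model_name='ViT-L-14/openai', chunk_size=1024)  # smaller chunks for low VRAM
+    #   ci = Interrogator(config)
+    #   print(ci.interrogate(Image.open('example.jpg').convert('RGB')))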
 
 class Interrogator():
     def __init__(self, config: Config):
         self.config = config
@@ -72,6 +92,21 @@ class Interrogator():
         self.load_clip_model()
 
+    def download_cache(self, clip_model_name: str):
+        if clip_model_name == 'ViT-L-14/openai':
+            cache_urls = CACHE_URLS_VITL
+        elif clip_model_name == 'ViT-H-14/laion2b_s32b_b79k':
+            cache_urls = CACHE_URLS_VITH
+        else:
+            # no prebuilt cache for this model; text embeddings will be computed and cached locally
+            return
+        os.makedirs(self.config.cache_path, exist_ok=True)
+        for url in cache_urls:
+            filepath = os.path.join(self.config.cache_path, url.split('/')[-1])
+            if not os.path.exists(filepath):
+                _download_file(url, filepath, quiet=self.config.quiet)
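+
+    # Behavior sketch: only the two model/pretrain combos above have prebuilt caches on the
+    # pharma/ci-preprocess HuggingFace repo. A hypothetical Config(clip_model_name='RN50/openai')
+    # would fall through to the early return, and LabelTable would compute embeddings on first use.
+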
     def load_clip_model(self):
         start_time = time.time()
         config = self.config
@@ -105,16 +140,58 @@ class Interrogator():
         artists = [f"by {a}" for a in raw_artists]
         artists.extend([f"inspired by {a}" for a in raw_artists])
 
+        self.download_cache(config.clip_model_name)
+
         self.artists = LabelTable(artists, "artists", self.clip_model, self.tokenize, config)
         self.flavors = LabelTable(_load_list(config.data_path, 'flavors.txt'), "flavors", self.clip_model, self.tokenize, config)
         self.mediums = LabelTable(_load_list(config.data_path, 'mediums.txt'), "mediums", self.clip_model, self.tokenize, config)
         self.movements = LabelTable(_load_list(config.data_path, 'movements.txt'), "movements", self.clip_model, self.tokenize, config)
         self.trendings = LabelTable(trending_list, "trendings", self.clip_model, self.tokenize, config)
+        self.negative = LabelTable(_load_list(config.data_path, 'negative.txt'), "negative", self.clip_model, self.tokenize, config)
 
         end_time = time.time()
         if not config.quiet:
             print(f"Loaded CLIP model and data in {end_time-start_time:.2f} seconds.")
 
+    def chain(
+        self,
+        image_features: torch.Tensor,
+        phrases: List[str],
+        best_prompt: str="",
+        best_sim: float=0,
+        max_count: int=32,
+        desc="Chaining",
+        reverse: bool=False
+    ) -> str:
+        phrases = set(phrases)
+        if not best_prompt:
+            best_prompt = self.rank_top(image_features, [f for f in phrases], reverse=reverse)
+            best_sim = self.similarity(image_features, best_prompt)
+            phrases.remove(best_prompt)
+
+        def check(addition: str) -> bool:
+            nonlocal best_prompt, best_sim
+            prompt = best_prompt + ", " + addition
+            sim = self.similarity(image_features, prompt)
+            if reverse:
+                sim = -sim
+            if sim > best_sim:
+                best_sim = sim
+                best_prompt = prompt
+                return True
+            return False
+
+        for _ in tqdm(range(max_count), desc=desc, disable=self.config.quiet):
+            best = self.rank_top(image_features, [f"{best_prompt}, {f}" for f in phrases], reverse=reverse)
+            flave = best[len(best_prompt)+2:]
+            if not check(flave):
+                break
+            if _prompt_at_max_len(best_prompt, self.tokenize):
+                break
+            phrases.remove(flave)
+
+        return best_prompt
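+
+    # chain() is a greedy search: each step appends the single candidate phrase that most improves
+    # the prompt's CLIP similarity (or most lowers it when reverse=True), stopping as soon as no
+    # candidate helps or the prompt reaches the token limit. Hedged sketch (names illustrative):
+    #   prompt = ci.chain(feats, ci.flavors.labels, best_prompt=caption, best_sim=sim, max_count=16)
+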
     def generate_caption(self, pil_image: Image) -> str:
         if self.config.blip_offload:
             self.blip_model = self.blip_model.to(self.device)
@@ -145,6 +222,8 @@ class Interrogator():
         return image_features
 
     def interrogate_classic(self, image: Image, max_flavors: int=3) -> str:
+        """Classic mode creates a prompt in a standard format first describing the image,
+        then listing the artist, trending, movement, and flavor text modifiers."""
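+        # illustrative output shape (made-up values): "a photo of a cat on a couch, by <artist>,
+        # trending on artstation, impressionism, warm lighting, cozy"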
         caption = self.generate_caption(image)
         image_features = self.image_to_features(image)
@@ -162,69 +241,43 @@ class Interrogator():
         return _truncate_to_fit(prompt, self.tokenize)
 
     def interrogate_fast(self, image: Image, max_flavors: int=32) -> str:
+        """Fast mode simply adds the top ranked terms after a caption. It generally results in
+        better similarity between generated prompt and image than classic mode, but the prompts
+        are less readable."""
         caption = self.generate_caption(image)
         image_features = self.image_to_features(image)
         merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config)
         tops = merged.rank(image_features, max_flavors)
         return _truncate_to_fit(caption + ", " + ", ".join(tops), self.tokenize)
 
+    def interrogate_negative(self, image: Image, max_flavors: int=32) -> str:
+        """Negative mode chains together the terms most dissimilar to the image. It can be used
+        to build a negative prompt to pair with the regular positive prompt, which often
+        improves the results of generated images, particularly with Stable Diffusion 2."""
+        image_features = self.image_to_features(image)
+        flaves = self.flavors.rank(image_features, self.config.flavor_intermediate_count, reverse=True)
+        flaves = flaves + self.negative.labels
+        return self.chain(image_features, flaves, max_count=max_flavors, reverse=True, desc="Negative chain")
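+
+    # Hedged pairing sketch (variable names illustrative): feed both prompts to a diffusion
+    # pipeline that accepts a negative prompt:
+    #   prompt = ci.interrogate(img)
+    #   negative = ci.interrogate_negative(img)
+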
     def interrogate(self, image: Image, max_flavors: int=32) -> str:
         caption = self.generate_caption(image)
         image_features = self.image_to_features(image)
 
-        flaves = self.flavors.rank(image_features, self.config.flavor_intermediate_count)
-        best_medium = self.mediums.rank(image_features, 1)[0]
-        best_artist = self.artists.rank(image_features, 1)[0]
-        best_trending = self.trendings.rank(image_features, 1)[0]
-        best_movement = self.movements.rank(image_features, 1)[0]
+        merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config)
+        flaves = merged.rank(image_features, self.config.flavor_intermediate_count)
 
         best_prompt = caption
         best_sim = self.similarity(image_features, best_prompt)
 
-        def check(addition: str) -> bool:
-            nonlocal best_prompt, best_sim
-            prompt = best_prompt + ", " + addition
-            sim = self.similarity(image_features, prompt)
-            if sim > best_sim:
-                best_sim = sim
-                best_prompt = prompt
-                return True
-            return False
-
-        def check_multi_batch(opts: List[str]):
-            nonlocal best_prompt, best_sim
-            prompts = []
-            for i in range(2**len(opts)):
-                prompt = best_prompt
-                for bit in range(len(opts)):
-                    if i & (1 << bit):
-                        prompt += ", " + opts[bit]
-                prompts.append(prompt)
-
-            t = LabelTable(prompts, None, self.clip_model, self.tokenize, self.config)
-            best_prompt = t.rank(image_features, 1)[0]
-            best_sim = self.similarity(image_features, best_prompt)
-
-        check_multi_batch([best_medium, best_artist, best_trending, best_movement])
-
-        extended_flavors = set(flaves)
-        for _ in tqdm(range(max_flavors), desc="Flavor chain", disable=self.config.quiet):
-            best = self.rank_top(image_features, [f"{best_prompt}, {f}" for f in extended_flavors])
-            flave = best[len(best_prompt)+2:]
-            if not check(flave):
-                break
-            if _prompt_at_max_len(best_prompt, self.tokenize):
-                break
-            extended_flavors.remove(flave)
-
-        return best_prompt
+        return self.chain(image_features, flaves, best_prompt, best_sim, max_count=max_flavors, desc="Flavor chain")
 
-    def rank_top(self, image_features: torch.Tensor, text_array: List[str]) -> str:
+    def rank_top(self, image_features: torch.Tensor, text_array: List[str], reverse: bool=False) -> str:
         text_tokens = self.tokenize([text for text in text_array]).to(self.device)
         with torch.no_grad(), torch.cuda.amp.autocast():
             text_features = self.clip_model.encode_text(text_tokens)
            text_features /= text_features.norm(dim=-1, keepdim=True)
             similarity = text_features @ image_features.T
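+            # negating the similarities turns argmax into "pick the most dissimilar text",
+            # which is what reverse mode and interrogate_negative rely on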
+            if reverse:
+                similarity = -similarity
         return text_array[similarity.argmax().item()]
 
     def similarity(self, image_features: torch.Tensor, text: str) -> float:
@@ -235,6 +288,14 @@ class Interrogator():
             similarity = text_features @ image_features.T
         return similarity[0][0].item()
 
+    def similarities(self, image_features: torch.Tensor, text_array: List[str]) -> List[float]:
+        text_tokens = self.tokenize([text for text in text_array]).to(self.device)
+        with torch.no_grad(), torch.cuda.amp.autocast():
+            text_features = self.clip_model.encode_text(text_tokens)
+            text_features /= text_features.norm(dim=-1, keepdim=True)
+            similarity = text_features @ image_features.T
+        return similarity.T[0].tolist()
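+
+    # Batched counterpart of similarity(): returns one score per entry of text_array, in order.
+    # Hedged sketch (names illustrative):
+    #   scores = ci.similarities(feats, ["a photo", "a painting", "a sketch"])
+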
 class LabelTable():
     def __init__(self, labels: List[str], desc: str, clip_model, tokenize, config: Config):
@@ -286,17 +347,19 @@ class LabelTable():
         if self.device == 'cpu' or self.device == torch.device('cpu'):
             self.embeds = [e.astype(np.float32) for e in self.embeds]
 
-    def _rank(self, image_features: torch.Tensor, text_embeds: torch.Tensor, top_count: int=1) -> str:
+    def _rank(self, image_features: torch.Tensor, text_embeds: torch.Tensor, top_count: int=1, reverse: bool=False) -> str:
         top_count = min(top_count, len(text_embeds))
         text_embeds = torch.stack([torch.from_numpy(t) for t in text_embeds]).to(self.device)
         with torch.cuda.amp.autocast():
             similarity = image_features @ text_embeds.T
+            if reverse:
+                similarity = -similarity
         _, top_labels = similarity.float().cpu().topk(top_count, dim=-1)
         return [top_labels[0][i].numpy() for i in range(top_count)]
 
-    def rank(self, image_features: torch.Tensor, top_count: int=1) -> List[str]:
+    def rank(self, image_features: torch.Tensor, top_count: int=1, reverse: bool=False) -> List[str]:
         if len(self.labels) <= self.chunk_size:
-            tops = self._rank(image_features, self.embeds, top_count=top_count)
+            tops = self._rank(image_features, self.embeds, top_count=top_count, reverse=reverse)
             return [self.labels[i] for i in tops]
 
         num_chunks = int(math.ceil(len(self.labels)/self.chunk_size))
@@ -306,7 +369,7 @@ class LabelTable():
         for chunk_idx in tqdm(range(num_chunks), disable=self.config.quiet):
             start = chunk_idx * self.chunk_size
             stop = min(start + self.chunk_size, len(self.embeds))
-            tops = self._rank(image_features, self.embeds[start:stop], top_count=keep_per_chunk)
+            tops = self._rank(image_features, self.embeds[start:stop], top_count=keep_per_chunk, reverse=reverse)
             top_labels.extend([self.labels[start+i] for i in tops])
             top_embeds.extend([self.embeds[start+i] for i in tops])
@@ -314,6 +377,18 @@ class LabelTable():
         return [top_labels[i] for i in tops]
 
+
+def _download_file(url: str, filepath: str, chunk_size: int = 64*1024, quiet: bool = False):
+    r = requests.get(url, stream=True)
+    file_size = int(r.headers.get("Content-Length", 0))
+    filename = url.split("/")[-1]
+    progress = tqdm(total=file_size, unit="B", unit_scale=True, desc=filename, disable=quiet)
+    with open(filepath, "wb") as f:
+        for chunk in r.iter_content(chunk_size=chunk_size):
+            if chunk:
+                f.write(chunk)
+                progress.update(len(chunk))
+    progress.close()
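+
+# Note on the helper above: it streams to disk in 64 KiB chunks and uses Content-Length only for
+# the progress bar total (0 when the header is missing). There is no retry or resume, and since
+# download_cache() skips files that already exist, an interrupted partial file would be reused
+# as-is; deleting it before re-running is the safe recovery.
+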
 def _load_list(data_path: str, filename: str) -> List[str]:
     with open(os.path.join(data_path, filename), 'r', encoding='utf-8', errors='replace') as f:
         items = [line.strip() for line in f.readlines()]