@ -29,20 +29,20 @@ CACHE_URL_BASE = 'https://huggingface.co/pharma/ci-preprocess/resolve/main/'
@dataclass
class Config :
# models can optionally be passed in directly
blip_model : BLIP_Decoder = None
blip_model : Optional [ BLIP_Decoder ] = None
clip_model = None
clip_preprocess = None
# blip settings
blip_image_eval_size : int = 384
blip_max_length : int = 32
blip_model_type : str = ' large ' # choose between 'base' or 'large'
blip_model_type : Optional [ str ] = ' large ' # use 'base', 'large' or None
blip_num_beams : int = 8
blip_offload : bool = False
# clip settings
clip_model_name : str = ' ViT-L-14/openai '
clip_model_path : str = None
clip_model_path : Optional [ str ] = None
clip_offload : bool = False
# interrogator settings
@ -68,7 +68,7 @@ class Interrogator():
self . blip_offloaded = True
self . clip_offloaded = True
if config . blip_model is None :
if config . blip_model is None and config . blip_model_type :
if not config . quiet :
print ( " Loading BLIP model... " )
blip_path = os . path . dirname ( inspect . getfile ( blip_decoder ) )
@ -121,17 +121,17 @@ class Interrogator():
trending_list . extend ( [ " featured on " + site for site in sites ] )
trending_list . extend ( [ site + " contest winner " for site in sites ] )
raw_artists = _ load_list( config . data_path , ' artists.txt ' )
raw_artists = load_list ( config . data_path , ' artists.txt ' )
artists = [ f " by { a } " for a in raw_artists ]
artists . extend ( [ f " inspired by { a } " for a in raw_artists ] )
self . _prepare_clip ( )
self . artists = LabelTable ( artists , " artists " , self . clip_model , self . tokenize , config )
self . flavors = LabelTable ( _ load_list( config . data_path , ' flavors.txt ' ) , " flavors " , self . clip_model , self . tokenize , config )
self . mediums = LabelTable ( _ load_list( config . data_path , ' mediums.txt ' ) , " mediums " , self . clip_model , self . tokenize , config )
self . movements = LabelTable ( _ load_list( config . data_path , ' movements.txt ' ) , " movements " , self . clip_model , self . tokenize , config )
self . trendings = LabelTable ( trending_list , " trendings " , self . clip_model , self . tokenize , config )
self . negative = LabelTable ( _ load_list( config . data_path , ' negative.txt ' ) , " negative " , self . clip_model , self . tokenize , config )
self . artists = LabelTable ( artists , " artists " , self )
self . flavors = LabelTable ( load_list ( config . data_path , ' flavors.txt ' ) , " flavors " , self )
self . mediums = LabelTable ( load_list ( config . data_path , ' mediums.txt ' ) , " mediums " , self )
self . movements = LabelTable ( load_list ( config . data_path , ' movements.txt ' ) , " movements " , self )
self . trendings = LabelTable ( trending_list , " trendings " , self )
self . negative = LabelTable ( load_list ( config . data_path , ' negative.txt ' ) , " negative " , self )
end_time = time . time ( )
if not config . quiet :
@ -183,6 +183,7 @@ class Interrogator():
return best_prompt
def generate_caption ( self , pil_image : Image ) - > str :
assert self . blip_model is not None , " No BLIP model loaded. "
self . _prepare_blip ( )
size = self . config . blip_image_eval_size
@ -310,13 +311,14 @@ class Interrogator():
class LabelTable ( ) :
def __init__ ( self , labels : List [ str ] , desc : str , clip_model , tokenize , config : Config ) :
def __init__ ( self , labels : List [ str ] , desc : str , ci : Interrogator ) :
clip_model , config = ci . clip_model , ci . config
self . chunk_size = config . chunk_size
self . config = config
self . device = config . device
self . embeds = [ ]
self . labels = labels
self . tokenize = tokenize
self . tokenize = ci . tokenize
hash = hashlib . sha256 ( " , " . join ( labels ) . encode ( ) ) . hexdigest ( )
sanitized_name = self . config . clip_model_name . replace ( ' / ' , ' _ ' ) . replace ( ' @ ' , ' _ ' )
@ -423,11 +425,6 @@ def _download_file(url: str, filepath: str, chunk_size: int = 4*1024*1024, quiet
progress . update ( len ( chunk ) )
progress . close ( )
def _load_list ( data_path : str , filename : str ) - > List [ str ] :
with open ( os . path . join ( data_path , filename ) , ' r ' , encoding = ' utf-8 ' , errors = ' replace ' ) as f :
items = [ line . strip ( ) for line in f . readlines ( ) ]
return items
def _merge_tables ( tables : List [ LabelTable ] , config : Config ) - > LabelTable :
m = LabelTable ( [ ] , None , None , None , config )
for table in tables :
@ -447,3 +444,11 @@ def _truncate_to_fit(text: str, tokenize) -> str:
break
new_text + = ' , ' + part
return new_text
def load_list ( data_path : str , filename : Optional [ str ] = None ) - > List [ str ] :
""" Load a list of strings from a file. """
if filename is not None :
data_path = os . path . join ( data_path , filename )
with open ( data_path , ' r ' , encoding = ' utf-8 ' , errors = ' replace ' ) as f :
items = [ line . strip ( ) for line in f . readlines ( ) ]
return items