@@ -1,8 +1,8 @@
-import clip
 import hashlib
 import inspect
 import math
 import numpy as np
+import open_clip
 import os
 import pickle
 import torch
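A note on the unpacking change further down: clip.load returns a (model, preprocess) pair, while open_clip.create_model_and_transforms returns a three-tuple (model, preprocess_train, preprocess_val), which is why the new code discards the middle value. A minimal sketch of the open_clip API this diff migrates to, using the same model and weights tag as the new default below:

    import open_clip
    import torch

    # returns (model, train_transform, eval_transform); only the eval
    # transform is needed for interrogation
    model, _, preprocess = open_clip.create_model_and_transforms(
        'ViT-H-14', pretrained='laion2b_s32b_b79k')
    tokenizer = open_clip.get_tokenizer('ViT-H-14')

    tokens = tokenizer(['a photo of a cat'])  # LongTensor of shape (1, 77)
    with torch.no_grad():
        text_features = model.encode_text(tokens)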
@@ -30,7 +30,7 @@ class Config:
     blip_num_beams: int = 8
 
     # clip settings
-    clip_model_name: str = 'ViT-L/14'
+    clip_model_name: str = 'ViT-H-14/laion2b_s32b_b79k'
 
     # interrogator settings
     cache_path: str = 'cache'
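The new default packs the model architecture and the pretrained-weights tag into a single '<architecture>/<pretrained_tag>' string, which the loader below splits apart. open_clip publishes the valid combinations; a quick way to list them (assuming a recent open_clip release):

    import open_clip

    # each entry is an (architecture, pretrained_tag) pair,
    # e.g. ('ViT-H-14', 'laion2b_s32b_b79k')
    for arch, tag in open_clip.list_pretrained():
        if arch == 'ViT-H-14':
            print(f'{arch}/{tag}')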
@ -67,11 +67,14 @@ class Interrogator():
if config . clip_model is None :
if not config . quiet :
print ( " Loading CLIP model... " )
self . clip_model , self . clip_preprocess = clip . load ( config . clip_model_name , device = config . device )
clip_model_name , clip_model_pretrained_name = config . clip_model_name . split ( ' / ' , 2 )
self . clip_model , _ , self . clip_preprocess = open_clip . create_model_and_transforms ( clip_model_name , pretrained = clip_model_pretrained_name )
self . clip_model . to ( config . device ) . eval ( )
else :
self . clip_model = config . clip_model
self . clip_preprocess = config . clip_preprocess
self . tokenize = open_clip . get_tokenizer ( clip_model_name )
sites = [ ' Artstation ' , ' behance ' , ' cg society ' , ' cgsociety ' , ' deviantart ' , ' dribble ' , ' flickr ' , ' instagram ' , ' pexels ' , ' pinterest ' , ' pixabay ' , ' pixiv ' , ' polycount ' , ' reddit ' , ' shutterstock ' , ' tumblr ' , ' unsplash ' , ' zbrush central ' ]
trending_list = [ site for site in sites ]
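Two details worth noting here. First, split('/', 2) caps the split at two separators, so 'ViT-H-14/laion2b_s32b_b79k' cleanly yields the architecture and the weights tag. Second, when a caller injects a preloaded model via config.clip_model, the else branch never assigns clip_model_name, so the get_tokenizer call after the conditional would raise a NameError. A defensive variant (my suggestion, not part of this diff) derives the name unconditionally:

    # inside Interrogator.__init__, a hypothetical hardening of the block above
    clip_model_name, clip_model_pretrained_name = config.clip_model_name.split('/', 2)
    if config.clip_model is None:
        self.clip_model, _, self.clip_preprocess = open_clip.create_model_and_transforms(
            clip_model_name, pretrained=clip_model_pretrained_name)
        self.clip_model.to(config.device).eval()
    else:
        self.clip_model = config.clip_model
        self.clip_preprocess = config.clip_preprocess
    self.tokenize = open_clip.get_tokenizer(clip_model_name)  # name is always bound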
@@ -83,11 +86,11 @@ class Interrogator():
         artists = [f"by {a}" for a in raw_artists]
         artists.extend([f"inspired by {a}" for a in raw_artists])
 
-        self.artists = LabelTable(artists, "artists", self.clip_model, config)
-        self.flavors = LabelTable(_load_list(config.data_path, 'flavors.txt'), "flavors", self.clip_model, config)
-        self.mediums = LabelTable(_load_list(config.data_path, 'mediums.txt'), "mediums", self.clip_model, config)
-        self.movements = LabelTable(_load_list(config.data_path, 'movements.txt'), "movements", self.clip_model, config)
-        self.trendings = LabelTable(trending_list, "trendings", self.clip_model, config)
+        self.artists = LabelTable(artists, "artists", self.clip_model, self.tokenize, config)
+        self.flavors = LabelTable(_load_list(config.data_path, 'flavors.txt'), "flavors", self.clip_model, self.tokenize, config)
+        self.mediums = LabelTable(_load_list(config.data_path, 'mediums.txt'), "mediums", self.clip_model, self.tokenize, config)
+        self.movements = LabelTable(_load_list(config.data_path, 'movements.txt'), "movements", self.clip_model, self.tokenize, config)
+        self.trendings = LabelTable(trending_list, "trendings", self.clip_model, self.tokenize, config)
 
     def generate_caption(self, pil_image: Image) -> str:
         size = self.config.blip_image_eval_size
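Every LabelTable call site now threads the tokenizer through as a fourth positional argument, ahead of config, matching the new __init__ signature further down. A hypothetical standalone construction mirroring the calls above, assuming clip_model and config are already in scope:

    tokenize = open_clip.get_tokenizer('ViT-H-14')
    # argument order: labels, desc, clip_model, tokenize, config
    mediums = LabelTable(['oil painting', 'watercolor', '3d render'],
                         'mediums', clip_model, tokenize, config)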
@@ -129,14 +132,14 @@ class Interrogator():
         else:
             prompt = f"{caption}, {medium} {artist}, {trending}, {movement}, {flaves}"
 
-        return _truncate_to_fit(prompt)
+        return _truncate_to_fit(prompt, self.tokenize)
 
     def interrogate_fast(self, image: Image, max_flavors: int=32) -> str:
         caption = self.generate_caption(image)
         image_features = self.image_to_features(image)
         merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config)
         tops = merged.rank(image_features, max_flavors)
-        return _truncate_to_fit(caption + ", " + ", ".join(tops))
+        return _truncate_to_fit(caption + ", " + ", ".join(tops), self.tokenize)
 
     def interrogate(self, image: Image, max_flavors: int=32) -> str:
         caption = self.generate_caption(image)
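Since the Interrogator now owns the tokenizer, both return paths pass it into _truncate_to_fit explicitly. End to end, the fast path reads roughly like this (illustrative usage of this module's Config and Interrogator; the image path is a placeholder):

    from PIL import Image

    ci = Interrogator(Config())
    image = Image.open('example.jpg').convert('RGB')
    print(ci.interrogate_fast(image))  # BLIP caption + top-ranked CLIP terms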
@@ -171,7 +174,7 @@ class Interrogator():
                         prompt += ", " + opts[bit]
                 prompts.append(prompt)
 
-            t = LabelTable(prompts, None, self.clip_model, self.config)
+            t = LabelTable(prompts, None, self.clip_model, self.tokenize, self.config)
             best_prompt = t.rank(image_features, 1)[0]
             best_sim = self.similarity(image_features, best_prompt)
@@ -192,7 +195,7 @@ class Interrogator():
         return best_prompt
 
     def rank_top(self, image_features, text_array: List[str]) -> str:
-        text_tokens = clip.tokenize([text for text in text_array]).to(self.device)
+        text_tokens = self.tokenize([text for text in text_array]).to(self.device)
         with torch.no_grad():
             text_features = self.clip_model.encode_text(text_tokens).float()
             text_features /= text_features.norm(dim=-1, keepdim=True)
@@ -205,7 +208,7 @@ class Interrogator():
         return text_array[top_labels[0][0].numpy()]
 
     def similarity(self, image_features, text) -> np.float32:
-        text_tokens = clip.tokenize([text]).to(self.device)
+        text_tokens = self.tokenize([text]).to(self.device)
         with torch.no_grad():
             text_features = self.clip_model.encode_text(text_tokens).float()
             text_features /= text_features.norm(dim=-1, keepdim=True)
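rank_top and similarity share one pattern: tokenize, encode, L2-normalize, then dot the result with the already-normalized image features, which is exactly cosine similarity. As a self-contained sketch (the function name and shapes are illustrative, not from this diff):

    import torch

    def text_image_similarity(image_features, text, clip_model, tokenize, device):
        # image_features: (1, D) tensor, already L2-normalized
        text_tokens = tokenize([text]).to(device)  # (1, 77)
        with torch.no_grad():
            text_features = clip_model.encode_text(text_tokens).float()
        text_features /= text_features.norm(dim=-1, keepdim=True)
        return (image_features @ text_features.T).item()  # cosine similarity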
@@ -214,12 +217,13 @@
 
 class LabelTable():
-    def __init__(self, labels: List[str], desc: str, clip_model, config: Config):
+    def __init__(self, labels: List[str], desc: str, clip_model, tokenize, config: Config):
         self.chunk_size = config.chunk_size
         self.config = config
         self.device = config.device
         self.embeds = []
         self.labels = labels
+        self.tokenize = tokenize
 
         hash = hashlib.sha256(",".join(labels).encode()).hexdigest()
@@ -239,7 +243,7 @@ class LabelTable():
             self.embeds = []
             chunks = np.array_split(self.labels, max(1, len(self.labels)/config.chunk_size))
             for chunk in tqdm(chunks, desc=f"Preprocessing {desc}" if desc else None, disable=self.config.quiet):
-                text_tokens = clip.tokenize(chunk).to(self.device)
+                text_tokens = self.tokenize(chunk).to(self.device)
                 with torch.no_grad():
                     text_features = clip_model.encode_text(text_tokens).float()
                     text_features /= text_features.norm(dim=-1, keepdim=True)
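The labels are embedded in chunks so tokenization and encode_text see bounded batches. If I read numpy right, array_split accepts the float section count here by truncating it to an int, then balances the remainder evenly across sections, so individual chunks can exceed chunk_size by a fair margin. For example:

    import numpy as np

    labels = [f'label {i}' for i in range(5000)]
    chunk_size = 2048
    # 5000 / 2048 -> 2.44, truncated to 2 sections by array_split
    chunks = np.array_split(labels, max(1, len(labels) / chunk_size))
    print([len(c) for c in chunks])  # [2500, 2500]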
@@ -291,16 +295,16 @@ def _load_list(data_path, filename) -> List[str]:
     return items
 
 def _merge_tables(tables: List[LabelTable], config: Config) -> LabelTable:
-    m = LabelTable([], None, None, config)
+    m = LabelTable([], None, None, None, config)
     for table in tables:
         m.labels.extend(table.labels)
         m.embeds.extend(table.embeds)
     return m
 
-def _truncate_to_fit(text: str) -> str:
+def _truncate_to_fit(text: str, tokenize) -> str:
     while True:
         try:
-            _ = clip.tokenize([text])
+            _ = tokenize([text])
             return text
         except:
             text = ",".join(text.split(",")[:-1])
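One behavioral caveat with this last change: OpenAI's clip.tokenize raises a RuntimeError when the text exceeds the 77-token context window, which is what the try/except loop relies on, but the tokenizer returned by open_clip.get_tokenizer truncates silently to the context length, so the loop may now return over-long prompts unchanged on the first pass. A length-aware variant would be a safer fit for open_clip; a minimal sketch (the helper name and the pad-id assumption are mine, not from this diff):

    def _truncate_to_fit_open_clip(text: str, tokenize) -> str:
        # open_clip pads to 77 tokens and truncates silently instead of
        # raising; a filled last slot signals the context window overflowed
        # (assumes pad id 0, true for open_clip's default tokenizer)
        parts = text.split(', ')
        new_text = parts[0]
        for part in parts[1:]:
            candidate = new_text + ', ' + part
            if tokenize([candidate])[0][-1] != 0:
                break
            new_text = candidate
        return new_text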