@@ -1,10 +1,11 @@
import hashlib
import inspect
import math
import numpy as np
import open_clip
import os
import pickle
import time
import torch

from dataclasses import dataclass
@@ -28,9 +29,11 @@ class Config:
    blip_max_length: int = 32
    blip_model_url: str = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
    blip_num_beams: int = 8
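    # offload the BLIP model back to the CPU between captions to free GPU memory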
    blip_offload: bool = False

    # clip settings
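    # open_clip model architecture and pretrained tag, joined by '/'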
    clip_model_name: str = 'ViT-H-14/laion2b_s32b_b79k'
    clip_model_path: str = None

    # interrogator settings
    cache_path: str = 'cache'
@@ -64,14 +67,30 @@ class Interrogator():
        else:
            self.blip_model = config.blip_model

        self.load_clip_model()
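
    # load CLIP via open_clip and prepare the matching tokenizer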
    def load_clip_model(self):
        start_time = time.time()
        config = self.config

        # split the model name outside the branch so the tokenizer lookup
        # below works even when a preloaded clip_model is passed in
        clip_model_name, clip_model_pretrained_name = config.clip_model_name.split('/', 2)
        if config.clip_model is None:
            if not config.quiet:
                print("Loading CLIP model...")

            self.clip_model, _, self.clip_preprocess = open_clip.create_model_and_transforms(
                clip_model_name,
                pretrained=clip_model_pretrained_name,
                precision='fp16',
                device=config.device,
                jit=False,
                cache_dir=config.clip_model_path
            )
            self.clip_model.half().to(config.device).eval()
        else:
            self.clip_model = config.clip_model
            self.clip_preprocess = config.clip_preprocess
        self.tokenize = open_clip.get_tokenizer(clip_model_name)

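        # art sites used to build the "trending on ..." label table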
        sites = ['Artstation', 'behance', 'cg society', 'cgsociety', 'deviantart', 'dribble', 'flickr', 'instagram', 'pexels', 'pinterest', 'pixabay', 'pixiv', 'polycount', 'reddit', 'shutterstock', 'tumblr', 'unsplash', 'zbrush central']
        trending_list = [site for site in sites]
@@ -83,13 +102,19 @@ class Interrogator():
        artists = [f"by {a}" for a in raw_artists]
        artists.extend([f"inspired by {a}" for a in raw_artists])

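        # embed every label list once up front (or load the embeddings from cache)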
        self.artists = LabelTable(artists, "artists", self.clip_model, self.tokenize, config)
        self.flavors = LabelTable(_load_list(config.data_path, 'flavors.txt'), "flavors", self.clip_model, self.tokenize, config)
        self.mediums = LabelTable(_load_list(config.data_path, 'mediums.txt'), "mediums", self.clip_model, self.tokenize, config)
        self.movements = LabelTable(_load_list(config.data_path, 'movements.txt'), "movements", self.clip_model, self.tokenize, config)
        self.trendings = LabelTable(trending_list, "trendings", self.clip_model, self.tokenize, config)

        end_time = time.time()
        if not config.quiet:
            print(f"Loaded CLIP model and data in {end_time-start_time:.2f} seconds.")

    def generate_caption(self, pil_image: Image) -> str:
        if self.config.blip_offload:
            self.blip_model = self.blip_model.to(self.device)
        size = self.config.blip_image_eval_size
        gpu_image = transforms.Compose([
            transforms.Resize((size, size), interpolation=InterpolationMode.BICUBIC),
@@ -105,13 +130,15 @@ class Interrogator():
            max_length=self.config.blip_max_length,
            min_length=5
        )
        if self.config.blip_offload:
            self.blip_model = self.blip_model.to("cpu")
        return caption[0]
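
    # encode the image with CLIP and L2-normalize the embedding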
    def image_to_features(self, image: Image) -> torch.Tensor:
        images = self.clip_preprocess(image).unsqueeze(0).to(self.device)
        with torch.no_grad(), torch.cuda.amp.autocast():
            image_features = self.clip_model.encode_image(images)
            image_features /= image_features.norm(dim=-1, keepdim=True)
        return image_features

    def interrogate_classic(self, image: Image, max_flavors: int = 3) -> str:
@@ -129,14 +156,14 @@ class Interrogator():
        else:
            prompt = f"{caption}, {medium} {artist}, {trending}, {movement}, {flaves}"

        return _truncate_to_fit(prompt, self.tokenize)
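
    # fast variant: rank all label tables in one merged pass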
    def interrogate_fast(self, image: Image, max_flavors: int = 32) -> str:
        caption = self.generate_caption(image)
        image_features = self.image_to_features(image)
        merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config)
        tops = merged.rank(image_features, max_flavors)
        return _truncate_to_fit(caption + ", " + ", ".join(tops), self.tokenize)

    def interrogate(self, image: Image, max_flavors: int = 32) -> str:
        caption = self.generate_caption(image)
@@ -171,7 +198,7 @@ class Interrogator():
                    prompt += ", " + opts[bit]
            prompts.append(prompt)

        t = LabelTable(prompts, None, self.clip_model, self.tokenize, self.config)
        best_prompt = t.rank(image_features, 1)[0]
        best_sim = self.similarity(image_features, best_prompt)
@@ -179,47 +206,41 @@ class Interrogator():
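        # greedily chain flavors onto the best prompt found so far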
        extended_flavors = set(flaves)
        for _ in tqdm(range(max_flavors), desc="Flavor chain", disable=self.config.quiet):
            best = self.rank_top(image_features, [f"{best_prompt}, {f}" for f in extended_flavors])
            flave = best[len(best_prompt)+2:]
            if not check(flave):
                break
            if _prompt_at_max_len(best_prompt, self.tokenize):
                break
            extended_flavors.remove(flave)

        return best_prompt
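
    # return the candidate text whose embedding best matches the image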
    def rank_top(self, image_features: torch.Tensor, text_array: List[str]) -> str:
        text_tokens = self.tokenize([text for text in text_array]).to(self.device)
        with torch.no_grad(), torch.cuda.amp.autocast():
            text_features = self.clip_model.encode_text(text_tokens)
            text_features /= text_features.norm(dim=-1, keepdim=True)
            similarity = text_features @ image_features.T
        return text_array[similarity.argmax().item()]

    # cosine similarity between the normalized image and text embeddings
    def similarity(self, image_features: torch.Tensor, text: str) -> float:
        text_tokens = self.tokenize([text]).to(self.device)
        with torch.no_grad(), torch.cuda.amp.autocast():
            text_features = self.clip_model.encode_text(text_tokens)
            text_features /= text_features.norm(dim=-1, keepdim=True)
            similarity = text_features @ image_features.T
        return similarity[0][0].item()
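

# a LabelTable caches normalized CLIP text embeddings for a list of labels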
class LabelTable():
    def __init__(self, labels: List[str], desc: str, clip_model, tokenize, config: Config):
        self.chunk_size = config.chunk_size
        self.config = config
        self.device = config.device
        self.embeds = []
        self.labels = labels
        self.tokenize = tokenize

        hash = hashlib.sha256(",".join(labels).encode()).hexdigest()
@@ -239,11 +260,11 @@ class LabelTable():
            self.embeds = []
            chunks = np.array_split(self.labels, max(1, len(self.labels)/config.chunk_size))
            for chunk in tqdm(chunks, desc=f"Preprocessing {desc}" if desc else None, disable=self.config.quiet):
                text_tokens = self.tokenize(chunk).to(self.device)
                with torch.no_grad(), torch.cuda.amp.autocast():
                    text_features = clip_model.encode_text(text_tokens)
                    text_features /= text_features.norm(dim=-1, keepdim=True)
                    # store embeddings as fp16 numpy arrays to keep the cache small
                    text_features = text_features.half().cpu().numpy()
                for i in range(text_features.shape[0]):
                    self.embeds.append(text_features[i])
@@ -256,16 +277,15 @@ class LabelTable():
" model " : config . clip_model_name
" model " : config . clip_model_name
} , f )
} , f )
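
    # rank one chunk of cached label embeddings against the image embedding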
    def _rank(self, image_features: torch.Tensor, text_embeds: List[np.ndarray], top_count: int = 1) -> List[int]:
        top_count = min(top_count, len(text_embeds))
        text_embeds = torch.stack([torch.from_numpy(t) for t in text_embeds]).to(self.device)
        with torch.cuda.amp.autocast():
            similarity = image_features @ text_embeds.T
        _, top_labels = similarity.float().cpu().topk(top_count, dim=-1)
        return [top_labels[0][i].numpy() for i in range(top_count)]

    def rank(self, image_features: torch.Tensor, top_count: int = 1) -> List[str]:
        if len(self.labels) <= self.chunk_size:
            tops = self._rank(image_features, self.embeds, top_count=top_count)
            return [self.labels[i] for i in tops]
@@ -285,23 +305,27 @@ class LabelTable():
        return [top_labels[i] for i in tops]


def _load_list(data_path: str, filename: str) -> List[str]:
    with open(os.path.join(data_path, filename), 'r', encoding='utf-8', errors='replace') as f:
        items = [line.strip() for line in f.readlines()]
    return items

def _merge_tables(tables: List[LabelTable], config: Config) -> LabelTable:
    m = LabelTable([], None, None, None, config)
    for table in tables:
        m.labels.extend(table.labels)
        m.embeds.extend(table.embeds)
    return m
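
# open_clip pads tokenized prompts with zeros, so a nonzero final token
# means the prompt already fills the model's context window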
def _prompt_at_max_len(text: str, tokenize) -> bool:
    tokens = tokenize([text])
    return tokens[0][-1] != 0

def _truncate_to_fit(text: str, tokenize) -> str:
    parts = text.split(', ')
    new_text = parts[0]
    for part in parts[1:]:
        if _prompt_at_max_len(new_text + part, tokenize):
            break
        new_text += ', ' + part
    return new_text
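
# minimal usage sketch (hypothetical image path; assumes the surrounding
# clip_interrogator module and its data files are available):
#
#   from PIL import Image
#   ci = Interrogator(Config())
#   image = Image.open('example.jpg').convert('RGB')
#   print(ci.interrogate_fast(image))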