@@ -4,7 +4,6 @@ import math
 import numpy as np
 import open_clip
 import os
-import pickle
 import requests
 import time
 import torch
@@ -15,7 +14,7 @@ from PIL import Image
 from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode
 from tqdm import tqdm
-from typing import List
+from typing import List, Optional

 from safetensors.numpy import load_file, save_file
@@ -24,23 +23,7 @@ BLIP_MODELS = {
     'large': 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
 }

-CACHE_URLS_VITL = [
-    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_artists.safetensors',
-    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_flavors.safetensors',
-    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_mediums.safetensors',
-    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_movements.safetensors',
-    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_negative.safetensors',
-    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_trendings.safetensors',
-]
-
-CACHE_URLS_VITH = [
-    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_artists.safetensors',
-    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_flavors.safetensors',
-    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_mediums.safetensors',
-    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_movements.safetensors',
-    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_negative.safetensors',
-    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_trendings.safetensors',
-]
+CACHE_URL_BASE = 'https://huggingface.co/pharma/ci-preprocess/resolve/main/'


 @dataclass
@@ -60,6 +43,7 @@ class Config:
     # clip settings
     clip_model_name: str = 'ViT-L-14/openai'
     clip_model_path: str = None
+    clip_offload: bool = False

     # interrogator settings
     cache_path: str = 'cache'   # path to store cached text embeddings
@@ -70,11 +54,19 @@ class Config:
     flavor_intermediate_count: int = 2048
     quiet: bool = False # when quiet progress bars are not shown

+    def apply_low_vram_defaults(self):
+        self.blip_model_type = 'base'
+        self.blip_offload = True
+        self.clip_offload = True
+        self.chunk_size = 1024
+        self.flavor_intermediate_count = 1024

 class Interrogator():
     def __init__(self, config: Config):
         self.config = config
         self.device = config.device
+        self.blip_offloaded = True
+        self.clip_offloaded = True

         if config.blip_model is None:
             if not config.quiet:
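`apply_low_vram_defaults` bundles the memory-saving settings (smaller BLIP model, both offload flags, smaller ranking batches) so callers opt in with one call. A caller-side sketch, not part of this diff:

    # Hypothetical usage: enable the low-VRAM profile before models load.
    config = Config(clip_model_name='ViT-L-14/openai')
    config.apply_low_vram_defaults()
    ci = Interrogator(config)

Both `*_offloaded` flags start as True because neither model has been moved to the GPU yet; the `_prepare_*` helpers added later in this diff flip them as the models migrate.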
@@ -97,21 +89,6 @@ class Interrogator():
         self.load_clip_model()

-    def download_cache(self, clip_model_name: str):
-        if clip_model_name == 'ViT-L-14/openai':
-            cache_urls = CACHE_URLS_VITL
-        elif clip_model_name == 'ViT-H-14/laion2b_s32b_b79k':
-            cache_urls = CACHE_URLS_VITH
-        else:
-            # text embeddings will be precomputed and cached locally
-            return
-
-        os.makedirs(self.config.cache_path, exist_ok=True)
-        for url in cache_urls:
-            filepath = os.path.join(self.config.cache_path, url.split('/')[-1])
-            if not os.path.exists(filepath):
-                _download_file(url, filepath, quiet=self.config.quiet)
-
     def load_clip_model(self):
         start_time = time.time()
         config = self.config
@@ -129,13 +106,15 @@ class Interrogator():
                 jit=False,
                 cache_dir=config.clip_model_path
             )
-            self.clip_model.to(config.device).eval()
+            self.clip_model.eval()
         else:
             self.clip_model = config.clip_model
             self.clip_preprocess = config.clip_preprocess
         self.tokenize = open_clip.get_tokenizer(clip_model_name)

-        sites = ['Artstation', 'behance', 'cg society', 'cgsociety', 'deviantart', 'dribble', 'flickr', 'instagram', 'pexels', 'pinterest', 'pixabay', 'pixiv', 'polycount', 'reddit', 'shutterstock', 'tumblr', 'unsplash', 'zbrush central']
+        sites = ['Artstation', 'behance', 'cg society', 'cgsociety', 'deviantart', 'dribble',
+                 'flickr', 'instagram', 'pexels', 'pinterest', 'pixabay', 'pixiv', 'polycount',
+                 'reddit', 'shutterstock', 'tumblr', 'unsplash', 'zbrush central']
         trending_list = [site for site in sites]
         trending_list.extend(["trending on "+site for site in sites])
         trending_list.extend(["featured on "+site for site in sites])
@@ -145,9 +124,7 @@ class Interrogator():
         artists = [f"by {a}" for a in raw_artists]
         artists.extend([f"inspired by {a}" for a in raw_artists])

-        if config.download_cache:
-            self.download_cache(config.clip_model_name)
-
+        self._prepare_clip()
         self.artists = LabelTable(artists, "artists", self.clip_model, self.tokenize, config)
         self.flavors = LabelTable(_load_list(config.data_path, 'flavors.txt'), "flavors", self.clip_model, self.tokenize, config)
         self.mediums = LabelTable(_load_list(config.data_path, 'mediums.txt'), "mediums", self.clip_model, self.tokenize, config)
@@ -170,6 +147,8 @@ class Interrogator():
         desc="Chaining",
         reverse: bool = False
     ) -> str:
+        self._prepare_clip()
+
         phrases = set(phrases)
         if not best_prompt:
             best_prompt = self.rank_top(image_features, [f for f in phrases], reverse=reverse)
@@ -203,8 +182,8 @@ class Interrogator():
         return best_prompt

     def generate_caption(self, pil_image: Image) -> str:
-        if self.config.blip_offload:
-            self.blip_model = self.blip_model.to(self.device)
+        self._prepare_blip()
+
         size = self.config.blip_image_eval_size
         gpu_image = transforms.Compose([
             transforms.Resize((size, size), interpolation=InterpolationMode.BICUBIC),
@@ -220,21 +199,21 @@ class Interrogator():
             max_length=self.config.blip_max_length,
             min_length=5
         )
-        if self.config.blip_offload:
-            self.blip_model = self.blip_model.to("cpu")
         return caption[0]

     def image_to_features(self, image: Image) -> torch.Tensor:
+        self._prepare_clip()
         images = self.clip_preprocess(image).unsqueeze(0).to(self.device)
         with torch.no_grad(), torch.cuda.amp.autocast():
             image_features = self.clip_model.encode_image(images)
             image_features /= image_features.norm(dim=-1, keepdim=True)
         return image_features

-    def interrogate_classic(self, image: Image, max_flavors: int=3) -> str:
+    def interrogate_classic(self, image: Image, max_flavors: int=3, caption: Optional[str]=None) -> str:
         """Classic mode creates a prompt in a standard format first describing the image,
         then listing the artist, trending, movement, and flavor text modifiers."""
-        caption = self.generate_caption(image)
+        caption = caption or self.generate_caption(image)
         image_features = self.image_to_features(image)

         medium = self.mediums.rank(image_features, 1)[0]
@@ -250,11 +229,11 @@ class Interrogator():

         return _truncate_to_fit(prompt, self.tokenize)

-    def interrogate_fast(self, image: Image, max_flavors: int=32) -> str:
+    def interrogate_fast(self, image: Image, max_flavors: int=32, caption: Optional[str]=None) -> str:
         """Fast mode simply adds the top ranked terms after a caption. It generally results in
         better similarity between generated prompt and image than classic mode, but the prompts
         are less readable."""
-        caption = self.generate_caption(image)
+        caption = caption or self.generate_caption(image)
         image_features = self.image_to_features(image)
         merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config)
         tops = merged.rank(image_features, max_flavors)
@@ -269,22 +248,22 @@ class Interrogator():
         flaves = flaves + self.negative.labels
         return self.chain(image_features, flaves, max_count=max_flavors, reverse=True, desc="Negative chain")

-    def interrogate(self, image: Image, min_flavors: int=8, max_flavors: int=32) -> str:
-        caption = self.generate_caption(image)
+    def interrogate(self, image: Image, min_flavors: int=8, max_flavors: int=32, caption: Optional[str]=None) -> str:
+        caption = caption or self.generate_caption(image)
         image_features = self.image_to_features(image)

         merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config)
         flaves = merged.rank(image_features, self.config.flavor_intermediate_count)
         best_prompt, best_sim = caption, self.similarity(image_features, caption)
         best_prompt = self.chain(image_features, flaves, best_prompt, best_sim, min_count=min_flavors, max_count=max_flavors, desc="Flavor chain")

-        fast_prompt = self.interrogate_fast(image, max_flavors)
-        classic_prompt = self.interrogate_classic(image, max_flavors)
+        fast_prompt = self.interrogate_fast(image, max_flavors, caption=caption)
+        classic_prompt = self.interrogate_classic(image, max_flavors, caption=caption)
         candidates = [caption, classic_prompt, fast_prompt, best_prompt]
         return candidates[np.argmax(self.similarities(image_features, candidates))]

     def rank_top(self, image_features: torch.Tensor, text_array: List[str], reverse: bool=False) -> str:
+        self._prepare_clip()
         text_tokens = self.tokenize([text for text in text_array]).to(self.device)
         with torch.no_grad(), torch.cuda.amp.autocast():
             text_features = self.clip_model.encode_text(text_tokens)
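Threading the optional `caption` through means `interrogate` runs BLIP once and reuses the caption for its fast and classic candidate prompts instead of captioning the image three times. A caller can do the same; a sketch assuming an `Interrogator` instance `ci`:

    # Sketch: share one BLIP caption across modes. With blip_offload enabled
    # this also avoids a CPU<->GPU round trip per extra call.
    caption = ci.generate_caption(image)
    fast = ci.interrogate_fast(image, caption=caption)
    classic = ci.interrogate_classic(image, caption=caption)
    best = ci.interrogate(image, caption=caption)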
@@ -295,6 +274,7 @@ class Interrogator():
         return text_array[similarity.argmax().item()]

     def similarity(self, image_features: torch.Tensor, text: str) -> float:
+        self._prepare_clip()
         text_tokens = self.tokenize([text]).to(self.device)
         with torch.no_grad(), torch.cuda.amp.autocast():
             text_features = self.clip_model.encode_text(text_tokens)
@@ -303,6 +283,7 @@ class Interrogator():
         return similarity[0][0].item()

     def similarities(self, image_features: torch.Tensor, text_array: List[str]) -> List[float]:
+        self._prepare_clip()
         text_tokens = self.tokenize([text for text in text_array]).to(self.device)
         with torch.no_grad(), torch.cuda.amp.autocast():
             text_features = self.clip_model.encode_text(text_tokens)
@@ -310,6 +291,22 @@ class Interrogator():
             similarity = text_features @ image_features.T
         return similarity.T[0].tolist()

+    def _prepare_blip(self):
+        if self.config.clip_offload and not self.clip_offloaded:
+            self.clip_model = self.clip_model.to('cpu')
+            self.clip_offloaded = True
+        if self.blip_offloaded:
+            self.blip_model = self.blip_model.to(self.device)
+            self.blip_offloaded = False
+
+    def _prepare_clip(self):
+        if self.config.blip_offload and not self.blip_offloaded:
+            self.blip_model = self.blip_model.to('cpu')
+            self.blip_offloaded = True
+        if self.clip_offloaded:
+            self.clip_model = self.clip_model.to(self.device)
+            self.clip_offloaded = False
+

 class LabelTable():
     def __init__(self, labels: List[str], desc: str, clip_model, tokenize, config: Config):
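These two helpers are the core of the offload scheme: before either model runs, the other is pushed to CPU (when its offload flag is set) and the needed one is pulled onto the device, so at most one of the two large models occupies VRAM at a time. An illustrative trace, assuming both offload flags are enabled:

    # Sketch: device hand-offs with blip_offload = clip_offload = True.
    ci.generate_caption(image)       # _prepare_blip: CLIP -> cpu, BLIP -> cuda
    ci.image_to_features(image)      # _prepare_clip: BLIP -> cpu, CLIP -> cuda
    ci.similarity(feats, 'a photo')  # CLIP already resident, no transfers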
@@ -352,23 +349,25 @@ class LabelTable():
         if self.config.cache_path is None or desc is None:
             return False

-        # load from old pkl format if it exists
-        cached_pkl = os.path.join(self.config.cache_path, f"{sanitized_name}_{desc}.pkl")
-        if os.path.exists(cached_pkl):
-            with open(cached_pkl, 'rb') as f:
-                try:
-                    data = pickle.load(f)
-                    if data.get('hash') == hash:
-                        self.labels = data['labels']
-                        self.embeds = data['embeds']
-                        return True
-                except Exception as e:
-                    print(f"Error loading cached table {desc}: {e}")
-
-        # load from new safetensors format if it exists
         cached_safetensors = os.path.join(self.config.cache_path, f"{sanitized_name}_{desc}.safetensors")
+
+        if self.config.download_cache and not os.path.exists(cached_safetensors):
+            download_url = CACHE_URL_BASE + f"{sanitized_name}_{desc}.safetensors"
+            try:
+                os.makedirs(self.config.cache_path, exist_ok=True)
+                _download_file(download_url, cached_safetensors, quiet=self.config.quiet)
+            except Exception as e:
+                print(f"Failed to download {download_url}")
+                print(e)
+                return False
+
         if os.path.exists(cached_safetensors):
-            tensors = load_file(cached_safetensors)
+            try:
+                tensors = load_file(cached_safetensors)
+            except Exception as e:
+                print(f"Failed to load {cached_safetensors}")
+                print(e)
+                return False
             if 'hash' in tensors and 'embeds' in tensors:
                 if np.array_equal(tensors['hash'], np.array([ord(c) for c in hash], dtype=np.int8)):
                     self.embeds = tensors['embeds']
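The cache file stores the label-list hash as an int8 tensor of character codes alongside the embedding matrix, which is exactly what the `np.array_equal` check above verifies. A sketch of a compatible writer (the in-repo save path lies outside this hunk; names here are illustrative):

    # Sketch: writer counterpart for the layout _load_cached expects.
    import numpy as np
    from safetensors.numpy import save_file

    def save_cached(path: str, embeds: np.ndarray, hash: str) -> None:
        save_file({
            'embeds': embeds,  # (num_labels, embed_dim) float array
            'hash': np.array([ord(c) for c in hash], dtype=np.int8),
        }, path)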
@@ -377,7 +376,6 @@ class LabelTable():
                     return True

         return False
-

     def _rank(self, image_features: torch.Tensor, text_embeds: torch.Tensor, top_count: int=1, reverse: bool=False) -> str:
         top_count = min(top_count, len(text_embeds))
@@ -409,8 +407,11 @@ class LabelTable():
         return [top_labels[i] for i in tops]


-def _download_file(url: str, filepath: str, chunk_size: int = 64*1024, quiet: bool = False):
+def _download_file(url: str, filepath: str, chunk_size: int = 4*1024*1024, quiet: bool = False):
     r = requests.get(url, stream=True)
+    if r.status_code != 200:
+        return
+
     file_size = int(r.headers.get("Content-Length", 0))
     filename = url.split("/")[-1]
     progress = tqdm(total=file_size, unit="B", unit_scale=True, desc=filename, disable=quiet)
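The larger default chunk (4 MiB instead of 64 KiB) reduces per-chunk Python overhead on the multi-megabyte safetensors files, and the early return on non-200 responses keeps an HTML error page from being written into the cache. The rest of the function falls outside this hunk; the streaming body it sets up typically looks like this (sketch for reference, not part of the diff):

    # Sketch: consume the response chunk_size bytes at a time, updating tqdm.
    with open(filepath, 'wb') as f:
        for chunk in r.iter_content(chunk_size=chunk_size):
            f.write(chunk)
            progress.update(len(chunk))
    progress.close()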