import os
from transformers import CLIPTokenizer
import comfy . ops
import torch
import traceback
import zipfile
from . import model_management
import comfy . clip_model
import json
import logging
def gen_empty_tokens(special_tokens, length):
    """Build a token id list of ``length`` that contains only special tokens.

    The optional "start" and "end" ids (in that order, skipped when absent or
    None) form the prefix; the remaining slots are filled with the "pad" id.
    If ``length`` is shorter than the prefix, no padding is added and the
    prefix is returned as-is.
    """
    candidates = (special_tokens.get("start", None), special_tokens.get("end", None))
    prefix = [token for token in candidates if token is not None]
    padding = [special_tokens.get("pad")] * (length - len(prefix))
    return prefix + padding
class ClipTokenWeightEncoder:
    """Mixin that turns weighted token sections into conditioning tensors.

    The host class must provide ``self.encode(token_lists) -> (out, pooled)``
    and a ``self.special_tokens`` mapping accepted by ``gen_empty_tokens``.
    """

    def encode_token_weights(self, token_weight_pairs):
        """Encode sections of ``(token, weight, ...)`` tuples and apply weights.

        All sections are encoded with a single ``self.encode`` call. When any
        weight differs from 1.0, or there are no sections at all, an extra
        sequence made only of special tokens is encoded too. Positions whose
        weight is not 1.0 are then rewritten in place on the encoder output as
        ``(z - empty) * weight + empty``, where ``empty`` is that extra row.

        Returns:
            ``(cond, pooled)``: ``cond`` concatenates the per-section outputs
            along dim -2 (or is the empty-prompt encoding when there are no
            sections); ``pooled`` is row 0 of the pooled output or None. Both
            tensors are moved to ``model_management.intermediate_device()``.
        """
        section_tokens = [[pair[0] for pair in section] for section in token_weight_pairs]
        longest = max((len(tokens) for tokens in section_tokens), default=0)
        weighted = any(not (pair[1] == 1.0) for section in token_weight_pairs for pair in section)
        section_count = len(section_tokens)

        if weighted or not section_count:
            # Reference "empty prompt" encoding used as the weighting anchor
            # (and as the result when there is nothing else to encode).
            section_tokens.append(gen_empty_tokens(self.special_tokens, longest))

        out, pooled = self.encode(section_tokens)
        first_pooled = pooled
        if pooled is not None:
            first_pooled = pooled[0:1].to(model_management.intermediate_device())

        # assumes out is (batch, seq, dim) with the empty-prompt row last when weighted — TODO confirm
        empty = out[-1] if weighted else None
        pieces = []
        for k in range(section_count):
            z = out[k:k + 1]
            if weighted:
                section = token_weight_pairs[k]
                for i in range(len(z)):
                    row = z[i]
                    for j in range(len(row)):
                        w = section[j][1]
                        if w != 1.0:
                            # In-place write through the view into ``out``.
                            row[j] = (row[j] - empty[j]) * w + empty[j]
            pieces.append(z)

        if not pieces:
            return out[-1:].to(model_management.intermediate_device()), first_pooled
        return torch.cat(pieces, dim=-2).to(model_management.intermediate_device()), first_pooled
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
    """Uses the CLIP transformer encoder for text (from huggingface)

    Wraps ``model_class`` (built from a JSON config) and provides the
    ``encode`` hook used by ``ClipTokenWeightEncoder``. Embedding tensors that
    appear inline in the token lists are spliced into a temporary
    input-embedding table for the duration of a single forward pass.
    """
    LAYERS = [
        "last",
        "pooled",
        "hidden"
    ]

    def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
                 freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
                 special_tokens=None, layer_norm_hidden_state=True, enable_attention_masks=False, return_projected_pooled=True):  # clip-vit-base-patch32
        """Build the text encoder.

        Args:
            version: not used by this constructor; kept for API compatibility.
            device, dtype: forwarded to ``model_class``.
            max_length: stored as ``self.max_length``.
            freeze: put the transformer in eval mode and disable gradients.
            layer: one of ``LAYERS``; "hidden" requires ``layer_idx``.
            layer_idx: intermediate layer index used when ``layer == "hidden"``.
            textmodel_json_config: path to the model JSON config; defaults to
                ``sd1_clip_config.json`` next to this file.
            model_class: called as ``model_class(config, dtype, device, operations)``.
            special_tokens: mapping with "start"/"end"/"pad" ids; defaults to
                ``{"start": 49406, "end": 49407, "pad": 49407}`` (a fresh dict
                per instance).
            layer_norm_hidden_state: forwarded as
                ``final_layer_norm_intermediate`` in ``forward``.
            enable_attention_masks: build a padding mask in ``forward``.
            return_projected_pooled: prefer ``outputs[2]`` over ``outputs[3]``
                for the pooled output.
        """
        super().__init__()
        assert layer in self.LAYERS
        if special_tokens is None:
            special_tokens = {"start": 49406, "end": 49407, "pad": 49407}

        if textmodel_json_config is None:
            textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")

        # JSON is UTF-8 by spec; don't depend on the platform default encoding.
        with open(textmodel_json_config, encoding="utf-8") as f:
            config = json.load(f)

        self.transformer = model_class(config, dtype, device, comfy.ops.manual_cast)
        self.num_layers = self.transformer.num_layers

        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        self.layer_idx = None
        self.special_tokens = special_tokens

        self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
        self.enable_attention_masks = enable_attention_masks

        self.layer_norm_hidden_state = layer_norm_hidden_state
        self.return_projected_pooled = return_projected_pooled

        if layer == "hidden":
            assert layer_idx is not None
            assert abs(layer_idx) < self.num_layers
            self.set_clip_options({"layer": layer_idx})
        self.options_default = (self.layer, self.layer_idx, self.return_projected_pooled)

    def freeze(self):
        """Put the transformer in eval mode and disable gradients on all current parameters."""
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False

    def set_clip_options(self, options):
        """Apply runtime options.

        ``options["layer"]`` selects an intermediate layer (None or an index
        with ``abs(idx) > num_layers`` falls back to "last");
        ``options["projected_pooled"]`` toggles ``return_projected_pooled``.
        """
        layer_idx = options.get("layer", self.layer_idx)
        self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
        if layer_idx is None or abs(layer_idx) > self.num_layers:
            self.layer = "last"
        else:
            self.layer = "hidden"
            self.layer_idx = layer_idx

    def reset_clip_options(self):
        """Restore layer selection and pooled-output choice captured at construction."""
        self.layer, self.layer_idx, self.return_projected_pooled = self.options_default

    def set_up_textual_embeddings(self, tokens, current_embeds):
        """Replace inline embedding tensors in ``tokens`` with new token ids.

        Integer tokens are kept. Tensors whose length matches the embedding
        width get consecutive ids starting at the old EOS id, and a larger
        ``nn.Embedding`` is installed on the transformer with those rows
        followed by the original EOS row. EOS is therefore always the largest
        id, and every EOS reference (including padding) is remapped to it.
        Mismatched tensors are dropped with a warning and the list is re-padded.

        Returns:
            The token lists with only integer ids.
        """
        out_tokens = []
        # The EOS token is assumed to be the last row of the embedding table.
        next_new_token = token_dict_size = current_embeds.weight.shape[0] - 1
        embedding_weights = []
        pad = self.special_tokens["pad"]
        # Pad ids must go through the same EOS remap as regular tokens, otherwise a
        # pad equal to the old EOS id would point at the first custom embedding.
        pad_marker = -1 if pad == token_dict_size else pad

        for x in tokens:
            tokens_temp = []
            for y in x:
                if isinstance(y, int):
                    if y == token_dict_size:  # EOS token, remapped below
                        y = -1
                    tokens_temp.append(y)
                else:
                    if y.shape[0] == current_embeds.weight.shape[1]:
                        embedding_weights.append(y)
                        tokens_temp.append(next_new_token)
                        next_new_token += 1
                    else:
                        logging.warning("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored {} != {}".format(y.shape[0], current_embeds.weight.shape[1]))
            while len(tokens_temp) < len(x):
                tokens_temp.append(pad_marker)
            out_tokens.append(tokens_temp)

        n = token_dict_size
        if len(embedding_weights) > 0:
            new_embedding = torch.nn.Embedding(next_new_token + 1, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype)
            # In-place writes into a fresh Parameter are only legal without autograd.
            with torch.no_grad():
                new_embedding.weight[:token_dict_size] = current_embeds.weight[:-1]
                for x in embedding_weights:
                    new_embedding.weight[n] = x
                    n += 1
                new_embedding.weight[n] = current_embeds.weight[-1]  # EOS embedding
            self.transformer.set_input_embeddings(new_embedding)

        # The EOS token should always be the largest one.
        return [[n if a == -1 else a for a in x] for x in out_tokens]

    def forward(self, tokens):
        """Encode token lists.

        Returns:
            ``(z, pooled)``: ``z`` is the last or selected intermediate hidden
            state as float32; ``pooled`` is the (projected or unprojected)
            pooled output as float32, or None when the transformer does not
            return one.
        """
        backup_embeds = self.transformer.get_input_embeddings()
        device = backup_embeds.weight.device
        try:
            tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
            tokens = torch.LongTensor(tokens).to(device)

            attention_mask = None
            if self.enable_attention_masks:
                max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1
                # Attend up to and including the first EOS of each row:
                # a position is masked once an EOS occurred strictly before it.
                is_end = (tokens == max_token).long()
                ends_before = torch.cumsum(is_end, dim=1) - is_end
                attention_mask = (ends_before == 0).to(tokens.dtype)

            outputs = self.transformer(tokens, attention_mask, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
        finally:
            # Always undo the temporary embedding table, even if encoding failed.
            self.transformer.set_input_embeddings(backup_embeds)

        if self.layer == "last":
            z = outputs[0]
        else:
            z = outputs[1]

        pooled_output = None
        if len(outputs) >= 3:
            if not self.return_projected_pooled and len(outputs) >= 4 and outputs[3] is not None:
                pooled_output = outputs[3].float()
            elif outputs[2] is not None:
                pooled_output = outputs[2].float()

        return z.float(), pooled_output

    def encode(self, tokens):
        """Hook used by ``ClipTokenWeightEncoder``; equivalent to calling the module."""
        return self(tokens)

    def load_sd(self, sd):
        """Load a state dict into the transformer non-strictly; returns torch's load result."""
        return self.transformer.load_state_dict(sd, strict=False)
def parse_parentheses(string):
    """Split ``string`` into top-level plain-text and parenthesised segments.

    Parenthesised segments keep their outer parentheses and may contain nested
    groups, e.g. ``"a (b (c)) d"`` -> ``["a ", "(b (c))", " d"]``. An unclosed
    ``(`` keeps the rest of the string as one segment. A ``)`` with no
    matching ``(`` is ordinary text; previously it drove the nesting level
    negative and disabled group detection for the remainder of the prompt.
    """
    result = []
    current_item = ""
    nesting_level = 0
    for char in string:
        if char == "(":
            if nesting_level == 0:
                if current_item:
                    result.append(current_item)
                current_item = "("
            else:
                current_item += char
            nesting_level += 1
        elif char == ")" and nesting_level > 0:
            nesting_level -= 1
            current_item += char
            if nesting_level == 0:
                result.append(current_item)
                current_item = ""
        else:
            current_item += char
    if current_item:
        result.append(current_item)
    return result


def token_weights(string, current_weight):
    """Parse prompt emphasis syntax into ``(text, weight)`` pairs.

    Each ``(...)`` group multiplies the enclosing weight by 1.1, unless its
    content ends in ``:<number>`` (e.g. ``(word:1.3)``), in which case that
    number replaces the weight and is removed from the text. Groups nest
    recursively; text outside any group keeps ``current_weight``.
    """
    out = []
    for segment in parse_parentheses(string):
        if len(segment) >= 2 and segment[-1] == ')' and segment[0] == '(':
            segment = segment[1:-1]
            weight = current_weight * 1.1
            colon = segment.rfind(":")
            if colon > 0:
                try:
                    weight = float(segment[colon + 1:])
                    segment = segment[:colon]
                except ValueError:
                    # Not a number after the colon: keep the text and the 1.1x weight.
                    pass
            out += token_weights(segment, weight)
        else:
            out.append((segment, current_weight))
    return out
def escape_important(text):
    """Hide backslash-escaped parentheses from the emphasis parser.

    Each escaped closing and opening parenthesis is replaced by a private
    two-character marker starting with a NUL character, so the weight parser
    treats it as plain text. ``unescape_important`` reverses this.
    """
    for escaped, marker in (("\\)", "\0\1"), ("\\(", "\0\2")):
        text = text.replace(escaped, marker)
    return text
def unescape_important(text):
    """Turn the private markers written by ``escape_important`` back into
    literal (unescaped) parentheses."""
    for marker, literal in (("\0\1", ")"), ("\0\2", "(")):
        text = text.replace(marker, literal)
    return text
def safe_load_embed_zip(embed_path):
    """Read an embedding matrix from a zip-format checkpoint without unpickling.

    Archive members whose name contains ``data/`` are scanned starting from
    the last one. The first member holding at least 768 float32 values is
    returned as a ``(num_embeds, width)`` tensor. ``width`` is 768 (sd1.x) when
    the value count is a multiple of 768, otherwise 1024 (sd2.x). Smaller
    members are skipped.

    Returns:
        The embedding tensor, or None when no member is large enough.
    """
    with zipfile.ZipFile(embed_path) as myzip:
        names = [name for name in myzip.namelist() if "data/" in name]
        for name in reversed(names):
            with myzip.open(name) as member:
                data = member.read()
            number = len(data) // 4
            if number < 768:
                continue
            length_embed = 768 if number % 768 == 0 else 1024  # sd1.x : sd2.x
            num_embeds = number // length_embed
            # A bytearray is writable, so torch.frombuffer doesn't warn about
            # wrapping read-only memory; clone() detaches from the buffer.
            embed = torch.frombuffer(bytearray(data), dtype=torch.float)
            return embed.reshape((num_embeds, length_embed)).clone()
    return None
def expand_directory_list(directories):
    """Return ``directories`` plus every directory found beneath them.

    Symlinked directories are followed (note: os.walk does not detect symlink
    cycles). Duplicates are removed while keeping first-seen order: each given
    directory followed by its ``os.walk`` order. The result is therefore
    stable across runs, unlike the previous ``list(set(...))``, whose order
    depended on string-hash randomisation and made the search order of
    ``load_embed`` (and which duplicate embedding is picked) nondeterministic.
    Nonexistent directories are returned as given.
    """
    seen = {}  # dict as an insertion-ordered set
    for top in directories:
        seen.setdefault(top, None)
        for root, _subdirs, _files in os.walk(top, followlinks=True):
            seen.setdefault(root, None)
    return list(seen)
def _find_embed_file(embedding_name, embedding_directories):
    """Return the first existing file for ``embedding_name`` in the directories, or None."""
    for embed_dir in embedding_directories:
        embed_path = os.path.abspath(os.path.join(embed_dir, embedding_name))
        embed_dir = os.path.abspath(embed_dir)
        try:
            # Reject names that escape the embedding directory (e.g. "../x").
            inside = os.path.commonpath((embed_dir, embed_path)) == embed_dir
        except ValueError:
            # e.g. paths on different drives on Windows
            continue
        # A name resolving to the directory itself ("" or ".") would make the
        # extension probe below look at siblings outside the directory.
        if not inside or embed_path == embed_dir:
            continue
        if os.path.isfile(embed_path):
            return embed_path
        for extension in ('.safetensors', '.pt', '.bin'):
            candidate = embed_path + extension
            if os.path.isfile(candidate):
                return candidate
    return None


def _extract_embed_tensor(embed, embedding_size, embed_key):
    """Pick the embedding tensor out of a loaded checkpoint object (raises on malformed input)."""
    if 'string_to_param' in embed:
        # textual-inversion style checkpoint: {'string_to_param': {token: tensor}}
        return next(iter(embed['string_to_param'].values()))
    if isinstance(embed, list):
        # list of dicts: keep every tensor whose last dim matches embedding_size
        out_list = []
        for entry in embed:
            for k in entry:
                t = entry[k]
                if t.shape[-1] != embedding_size:
                    continue
                out_list.append(t.reshape(-1, t.shape[-1]))
        return torch.cat(out_list, dim=0)
    if embed_key is not None and embed_key in embed:
        return embed[embed_key]
    return next(iter(embed.values()))


def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=None):
    """Locate and load an embedding by name.

    ``embedding_directory`` (a path or list of paths) is expanded to include
    subdirectories. The name is tried as-is and with ``.safetensors``, ``.pt``
    and ``.bin`` appended; paths escaping a directory are ignored.

    Returns:
        The embedding tensor, or None when no file is found or when loading
        or extracting it fails (the failure is logged and skipped).
    """
    if isinstance(embedding_directory, str):
        embedding_directory = [embedding_directory]

    embedding_directory = expand_directory_list(embedding_directory)

    embed_path = _find_embed_file(embedding_name, embedding_directory)
    if embed_path is None:
        return None

    try:
        if embed_path.lower().endswith(".safetensors"):
            import safetensors.torch
            embed = safetensors.torch.load_file(embed_path, device="cpu")
        elif 'weights_only' in torch.load.__code__.co_varnames:
            try:
                embed = torch.load(embed_path, weights_only=True, map_location="cpu")
            except Exception:
                # weights_only rejects some legacy pickles; read raw tensor data instead.
                embed_out = safe_load_embed_zip(embed_path)
                if embed_out is None:
                    raise ValueError("no embedding tensor found in {}".format(embed_path))
                return embed_out
        else:
            # NOTE(review): torch without weights_only unpickles arbitrary objects;
            # unsafe for untrusted embedding files.
            embed = torch.load(embed_path, map_location="cpu")
        # Extraction is inside the try so malformed checkpoints are skipped, not raised.
        return _extract_embed_tensor(embed, embedding_size, embed_key)
    except Exception:
        logging.warning("{}\n\nerror loading embedding, skipping loading: {}".format(traceback.format_exc(), embedding_name))
        return None
class SDTokenizer:
    """Turns weighted prompts into fixed-size batches of CLIP token ids.

    Supports ``(text:weight)`` emphasis (via ``token_weights``) and inline
    ``embedding:<name>`` references resolved with ``load_embed``.
    """

    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True, min_length=None):
        if tokenizer_path is None:
            tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
        self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path)
        self.max_length = max_length
        self.min_length = min_length

        # Tokenizing the empty string yields just the special tokens, which is
        # how the start/end ids are discovered.
        empty = self.tokenizer('')["input_ids"]
        if has_start_token:
            self.tokens_start = 1
            self.start_token = empty[0]
            self.end_token = empty[1]
        else:
            self.tokens_start = 0
            self.start_token = None
            self.end_token = empty[0]
        self.pad_with_end = pad_with_end
        self.pad_to_max_length = pad_to_max_length

        vocab = self.tokenizer.get_vocab()
        self.inv_vocab = {v: k for k, v in vocab.items()}
        self.embedding_directory = embedding_directory
        # Words with at least this many tokens may be split across batches.
        self.max_word_length = 8
        self.embedding_identifier = "embedding:"
        self.embedding_size = embedding_size
        self.embedding_key = embedding_key

    def _try_get_embedding(self, embedding_name: str):
        '''
        Takes a potential embedding name and tries to retrieve it.
        Returns a Tuple consisting of the embedding and any leftover string, embedding can be None.
        If the name fails to load, it is retried with surrounding commas
        removed; trailing commas are then returned as leftover text.
        '''
        embed = load_embed(embedding_name, self.embedding_directory, self.embedding_size, self.embedding_key)
        if embed is None:
            stripped = embedding_name.strip(',')
            if len(stripped) < len(embedding_name):
                embed = load_embed(stripped, self.embedding_directory, self.embedding_size, self.embedding_key)
                # Only trailing commas follow the name; slicing by len(stripped)
                # was wrong whenever leading commas were stripped as well.
                leftover = embedding_name[len(embedding_name.rstrip(',')):]
                return (embed, leftover)
        return (embed, "")

    def tokenize_with_weights(self, text: str, return_word_ids=False):
        '''
        Takes a prompt and converts it to a list of (token, weight, word id) elements.
        Tokens can both be integer tokens and pre computed CLIP tensors.
        Word id values are unique per word and embedding, where the id 0 is reserved for non word tokens.
        Returned list has the dimensions NxM where M is the input size of CLIP
        '''
        if self.pad_with_end:
            pad_token = self.end_token
        else:
            pad_token = 0

        text = escape_important(text)
        parsed_weights = token_weights(text, 1.0)

        # tokenize words
        tokens = []
        for weighted_segment, weight in parsed_weights:
            to_tokenize = unescape_important(weighted_segment).replace("\n", " ").split(' ')
            to_tokenize = [x for x in to_tokenize if x != ""]
            for word in to_tokenize:
                # if we find an embedding, deal with the embedding
                if word.startswith(self.embedding_identifier) and self.embedding_directory is not None:
                    embedding_name = word[len(self.embedding_identifier):].strip('\n')
                    embed, leftover = self._try_get_embedding(embedding_name)
                    if embed is None:
                        logging.warning(f"warning, embedding:{embedding_name} does not exist, ignoring")
                    else:
                        if len(embed.shape) == 1:
                            tokens.append([(embed, weight)])
                        else:
                            tokens.append([(embed[x], weight) for x in range(embed.shape[0])])
                    # if we accidentally have leftover text, continue parsing using leftover, else move on to next word
                    if leftover != "":
                        word = leftover
                    else:
                        continue
                # parse word (drop the tokenizer's own start/end tokens)
                tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][self.tokens_start:-1]])

        # reshape token array to CLIP input size
        batched_tokens = []
        batch = []
        if self.start_token is not None:
            batch.append((self.start_token, 1.0, 0))
        batched_tokens.append(batch)
        for i, t_group in enumerate(tokens):
            # determine if we're going to try and keep the tokens in a single batch
            is_large = len(t_group) >= self.max_word_length

            while len(t_group) > 0:
                if len(t_group) + len(batch) > self.max_length - 1:
                    remaining_length = self.max_length - len(batch) - 1
                    # break word in two and add end token
                    if is_large:
                        batch.extend([(t, w, i + 1) for t, w in t_group[:remaining_length]])
                        batch.append((self.end_token, 1.0, 0))
                        t_group = t_group[remaining_length:]
                    # add end token and pad
                    else:
                        batch.append((self.end_token, 1.0, 0))
                        if self.pad_to_max_length:
                            batch.extend([(pad_token, 1.0, 0)] * (remaining_length))
                    # start new batch
                    batch = []
                    if self.start_token is not None:
                        batch.append((self.start_token, 1.0, 0))
                    batched_tokens.append(batch)
                else:
                    batch.extend([(t, w, i + 1) for t, w in t_group])
                    t_group = []

        # fill last batch
        batch.append((self.end_token, 1.0, 0))
        if self.pad_to_max_length:
            batch.extend([(pad_token, 1.0, 0)] * (self.max_length - len(batch)))
        if self.min_length is not None and len(batch) < self.min_length:
            batch.extend([(pad_token, 1.0, 0)] * (self.min_length - len(batch)))

        if not return_word_ids:
            batched_tokens = [[(t, w) for t, w, _ in x] for x in batched_tokens]

        return batched_tokens

    def untokenize(self, token_weight_pair):
        """Pair each ``(token, weight, ...)`` entry with its vocabulary string."""
        return [(pair, self.inv_vocab[pair[0]]) for pair in token_weight_pair]
class SD1Tokenizer:
    """Wrapper exposing one named tokenizer as attribute ``clip_<clip_name>``.

    Results of ``tokenize_with_weights`` are keyed by ``clip_name``.
    """

    def __init__(self, embedding_directory=None, clip_name="l", tokenizer=SDTokenizer):
        self.clip_name = clip_name
        self.clip = f"clip_{self.clip_name}"
        inner = tokenizer(embedding_directory=embedding_directory)
        setattr(self, self.clip, inner)

    def tokenize_with_weights(self, text: str, return_word_ids=False):
        """Return ``{clip_name: batches}`` from the wrapped tokenizer."""
        inner = getattr(self, self.clip)
        return {self.clip_name: inner.tokenize_with_weights(text, return_word_ids)}

    def untokenize(self, token_weight_pair):
        """Delegate to the wrapped tokenizer's ``untokenize``."""
        inner = getattr(self, self.clip)
        return inner.untokenize(token_weight_pair)
class SD1ClipModel(torch.nn.Module):
    """Module holding one named text encoder as submodule ``clip_<clip_name>``.

    ``encode_token_weights`` expects a dict keyed by ``clip_name`` (as produced
    by ``SD1Tokenizer.tokenize_with_weights``).
    """

    def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, **kwargs):
        super().__init__()
        self.clip_name = clip_name
        self.clip = f"clip_{self.clip_name}"
        encoder = clip_model(device=device, dtype=dtype, **kwargs)
        setattr(self, self.clip, encoder)

    def set_clip_options(self, options):
        """Forward options to the wrapped encoder."""
        encoder = getattr(self, self.clip)
        encoder.set_clip_options(options)

    def reset_clip_options(self):
        """Reset the wrapped encoder's options to their defaults."""
        encoder = getattr(self, self.clip)
        encoder.reset_clip_options()

    def encode_token_weights(self, token_weight_pairs):
        """Encode this encoder's entry of ``token_weight_pairs``; returns ``(cond, pooled)``."""
        encoder = getattr(self, self.clip)
        cond, pooled = encoder.encode_token_weights(token_weight_pairs[self.clip_name])
        return cond, pooled

    def load_sd(self, sd):
        """Load a state dict into the wrapped encoder."""
        encoder = getattr(self, self.clip)
        return encoder.load_sd(sd)