@ -63,9 +63,38 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
self . layer = " hidden "
self . layer = " hidden "
self . layer_idx = layer_idx
self . layer_idx = layer_idx
def set_up_textual_embeddings ( self , tokens , current_embeds ) :
out_tokens = [ ]
next_new_token = token_dict_size = current_embeds . weight . shape [ 0 ]
embedding_weights = [ ]
for x in tokens :
tokens_temp = [ ]
for y in x :
if isinstance ( y , int ) :
tokens_temp + = [ y ]
else :
embedding_weights + = [ y ]
tokens_temp + = [ next_new_token ]
next_new_token + = 1
out_tokens + = [ tokens_temp ]
if len ( embedding_weights ) > 0 :
new_embedding = torch . nn . Embedding ( next_new_token , current_embeds . weight . shape [ 1 ] )
new_embedding . weight [ : token_dict_size ] = current_embeds . weight [ : ]
n = token_dict_size
for x in embedding_weights :
new_embedding . weight [ n ] = x
n + = 1
self . transformer . set_input_embeddings ( new_embedding )
return out_tokens
def forward ( self , tokens ) :
def forward ( self , tokens ) :
backup_embeds = self . transformer . get_input_embeddings ( )
tokens = self . set_up_textual_embeddings ( tokens , backup_embeds )
tokens = torch . LongTensor ( tokens ) . to ( self . device )
tokens = torch . LongTensor ( tokens ) . to ( self . device )
outputs = self . transformer ( input_ids = tokens , output_hidden_states = self . layer == " hidden " )
outputs = self . transformer ( input_ids = tokens , output_hidden_states = self . layer == " hidden " )
self . transformer . set_input_embeddings ( backup_embeds )
if self . layer == " last " :
if self . layer == " last " :
z = outputs . last_hidden_state
z = outputs . last_hidden_state
@ -138,18 +167,49 @@ def unescape_important(text):
text = text . replace ( " \0 \2 " , " ( " )
text = text . replace ( " \0 \2 " , " ( " )
return text
return text
def load_embed ( embedding_name , embedding_directory ) :
embed_path = os . path . join ( embedding_directory , embedding_name )
if not os . path . isfile ( embed_path ) :
extensions = [ ' .safetensors ' , ' .pt ' , ' .bin ' ]
valid_file = None
for x in extensions :
t = embed_path + x
if os . path . isfile ( t ) :
valid_file = t
break
if valid_file is None :
print ( " warning, embedding {} does not exist, ignoring " . format ( embed_path ) )
return None
else :
embed_path = valid_file
if embed_path . lower ( ) . endswith ( " .safetensors " ) :
import safetensors . torch
embed = safetensors . torch . load_file ( embed_path , device = " cpu " )
else :
embed = torch . load ( embed_path , weights_only = True , map_location = " cpu " )
if ' string_to_param ' in embed :
values = embed [ ' string_to_param ' ] . values ( )
else :
values = embed . values ( )
return next ( iter ( values ) )
class SD1Tokenizer :
class SD1Tokenizer :
def __init__ ( self , tokenizer_path = None , max_length = 77 , pad_with_end = True ) :
def __init__ ( self , tokenizer_path = None , max_length = 77 , pad_with_end = True , embedding_directory = None ) :
if tokenizer_path is None :
if tokenizer_path is None :
tokenizer_path = os . path . join ( os . path . dirname ( os . path . realpath ( __file__ ) ) , " sd1_tokenizer " )
tokenizer_path = os . path . join ( os . path . dirname ( os . path . realpath ( __file__ ) ) , " sd1_tokenizer " )
self . tokenizer = CLIPTokenizer . from_pretrained ( tokenizer_path )
self . tokenizer = CLIPTokenizer . from_pretrained ( tokenizer_path )
self . max_length = max_length
self . max_length = max_length
self . max_tokens_per_section = self . max_length - 2
empty = self . tokenizer ( ' ' ) [ " input_ids " ]
empty = self . tokenizer ( ' ' ) [ " input_ids " ]
self . start_token = empty [ 0 ]
self . start_token = empty [ 0 ]
self . end_token = empty [ 1 ]
self . end_token = empty [ 1 ]
self . pad_with_end = pad_with_end
self . pad_with_end = pad_with_end
vocab = self . tokenizer . get_vocab ( )
vocab = self . tokenizer . get_vocab ( )
self . inv_vocab = { v : k for k , v in vocab . items ( ) }
self . inv_vocab = { v : k for k , v in vocab . items ( ) }
self . embedding_directory = embedding_directory
self . max_word_length = 8
def tokenize_with_weights ( self , text ) :
def tokenize_with_weights ( self , text ) :
text = escape_important ( text )
text = escape_important ( text )
@ -157,13 +217,34 @@ class SD1Tokenizer:
tokens = [ ]
tokens = [ ]
for t in parsed_weights :
for t in parsed_weights :
tt = self . tokenizer ( unescape_important ( t [ 0 ] ) ) [ " input_ids " ] [ 1 : - 1 ]
to_tokenize = unescape_important ( t [ 0 ] ) . split ( ' ' )
for word in to_tokenize :
temp_tokens = [ ]
embedding_identifier = " embedding: "
if word . startswith ( embedding_identifier ) and self . embedding_directory is not None :
embedding_name = word [ len ( embedding_identifier ) : ] . strip ( ' \n ' )
embed = load_embed ( embedding_name , self . embedding_directory )
if embed is not None :
if len ( embed . shape ) == 1 :
temp_tokens + = [ ( embed , t [ 1 ] ) ]
else :
for x in range ( embed . shape [ 0 ] ) :
temp_tokens + = [ ( embed [ x ] , t [ 1 ] ) ]
elif len ( word ) > 0 :
tt = self . tokenizer ( word ) [ " input_ids " ] [ 1 : - 1 ]
for x in tt :
for x in tt :
tokens + = [ ( x , t [ 1 ] ) ]
temp_tokens + = [ ( x , t [ 1 ] ) ]
tokens_left = self . max_tokens_per_section - ( len ( tokens ) % self . max_tokens_per_section )
#try not to split words in different sections
if tokens_left < len ( temp_tokens ) and len ( temp_tokens ) < ( self . max_word_length ) :
for x in range ( tokens_left ) :
tokens + = [ ( self . end_token , 1.0 ) ]
tokens + = temp_tokens
out_tokens = [ ]
out_tokens = [ ]
for x in range ( 0 , len ( tokens ) , self . max_length - 2 ) :
for x in range ( 0 , len ( tokens ) , self . max_tokens_per_section ) :
o_token = [ ( self . start_token , 1.0 ) ] + tokens [ x : min ( self . max_length - 2 + x , len ( tokens ) ) ]
o_token = [ ( self . start_token , 1.0 ) ] + tokens [ x : min ( self . max_tokens_per_section + x , len ( tokens ) ) ]
o_token + = [ ( self . end_token , 1.0 ) ]
o_token + = [ ( self . end_token , 1.0 ) ]
if self . pad_with_end :
if self . pad_with_end :
o_token + = [ ( self . end_token , 1.0 ) ] * ( self . max_length - len ( o_token ) )
o_token + = [ ( self . end_token , 1.0 ) ] * ( self . max_length - len ( o_token ) )