@ -8,32 +8,54 @@ import zipfile
from . import model_management
import contextlib
def gen_empty_tokens ( special_tokens , length ) :
start_token = special_tokens . get ( " start " , None )
end_token = special_tokens . get ( " end " , None )
pad_token = special_tokens . get ( " pad " )
output = [ ]
if start_token is not None :
output . append ( start_token )
if end_token is not None :
output . append ( end_token )
output + = [ pad_token ] * ( length - len ( output ) )
return output
class ClipTokenWeightEncoder :
def encode_token_weights ( self , token_weight_pairs ) :
to_encode = list ( self . empty_tokens )
to_encode = list ( )
max_token_len = 0
has_weights = False
for x in token_weight_pairs :
tokens = list ( map ( lambda a : a [ 0 ] , x ) )
max_token_len = max ( len ( tokens ) , max_token_len )
has_weights = has_weights or not all ( map ( lambda a : a [ 1 ] == 1.0 , x ) )
to_encode . append ( tokens )
sections = len ( to_encode )
if has_weights or sections == 0 :
to_encode . append ( gen_empty_tokens ( self . special_tokens , max_token_len ) )
out , pooled = self . encode ( to_encode )
z_empty = out [ 0 : 1 ]
if pooled . shape [ 0 ] > 1 :
first_pooled = pooled [ 1 : 2 ]
if pooled is not None :
first_pooled = pooled [ 0 : 1 ] . cpu ( )
else :
first_pooled = pooled [ 0 : 1 ]
first_pooled = pooled
output = [ ]
for k in range ( 1 , out . shape [ 0 ] ) :
for k in range ( 0 , sections ) :
z = out [ k : k + 1 ]
if has_weights :
z_empty = out [ - 1 ]
for i in range ( len ( z ) ) :
for j in range ( len ( z [ i ] ) ) :
weight = token_weight_pairs [ k - 1 ] [ j ] [ 1 ]
z [ i ] [ j ] = ( z [ i ] [ j ] - z_empty [ 0 ] [ j ] ) * weight + z_empty [ 0 ] [ j ]
weight = token_weight_pairs [ k ] [ j ] [ 1 ]
if weight != 1.0 :
z [ i ] [ j ] = ( z [ i ] [ j ] - z_empty [ j ] ) * weight + z_empty [ j ]
output . append ( z )
if ( len ( output ) == 0 ) :
return z_empty . cpu ( ) , first_pooled . cpu ( )
return torch . cat ( output , dim = - 2 ) . cpu ( ) , first_pooled . cpu ( )
return out [ - 1 : ] . cpu ( ) , first_pooled
return torch . cat ( output , dim = - 2 ) . cpu ( ) , first_pooled
class SDClipModel ( torch . nn . Module , ClipTokenWeightEncoder ) :
""" Uses the CLIP transformer encoder for text (from huggingface) """
@ -43,37 +65,43 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
" hidden "
]
def __init__ ( self , version = " openai/clip-vit-large-patch14 " , device = " cpu " , max_length = 77 ,
freeze = True , layer = " last " , layer_idx = None , textmodel_json_config = None , textmodel_path = None , dtype = None ) : # clip-vit-base-patch32
freeze = True , layer = " last " , layer_idx = None , textmodel_json_config = None , textmodel_path = None , dtype = None ,
special_tokens = { " start " : 49406 , " end " : 49407 , " pad " : 49407 } , layer_norm_hidden_state = True , config_class = CLIPTextConfig ,
model_class = CLIPTextModel , inner_name = " text_model " ) : # clip-vit-base-patch32
super ( ) . __init__ ( )
assert layer in self . LAYERS
self . num_layers = 12
if textmodel_path is not None :
self . transformer = CLIPTextModel . from_pretrained ( textmodel_path )
self . transformer = model_class . from_pretrained ( textmodel_path )
else :
if textmodel_json_config is None :
textmodel_json_config = os . path . join ( os . path . dirname ( os . path . realpath ( __file__ ) ) , " sd1_clip_config.json " )
config = CLIPTextConfig . from_json_file ( textmodel_json_config )
config = config_class . from_json_file ( textmodel_json_config )
self . num_layers = config . num_hidden_layers
with comfy . ops . use_comfy_ops ( device , dtype ) :
with modeling_utils . no_init_weights ( ) :
self . transformer = CLIPTextModel ( config )
self . transformer = model_class ( config )
self . inner_name = inner_name
if dtype is not None :
self . transformer . to ( dtype )
self . transformer . text_model . embeddings . token_embedding . to ( torch . float32 )
self . transformer . text_model . embeddings . position_embedding . to ( torch . float32 )
inner_model = getattr ( self . transformer , self . inner_name )
if hasattr ( inner_model , " embeddings " ) :
inner_model . embeddings . to ( torch . float32 )
else :
self . transformer . set_input_embeddings ( self . transformer . get_input_embeddings ( ) . to ( torch . float32 ) )
self . max_length = max_length
if freeze :
self . freeze ( )
self . layer = layer
self . layer_idx = None
self . empty_tokens = [ [ 49406 ] + [ 49407 ] * 76 ]
self . special_tokens = special_tokens
self . text_projection = torch . nn . Parameter ( torch . eye ( self . transformer . get_input_embeddings ( ) . weight . shape [ 1 ] ) )
self . logit_scale = torch . nn . Parameter ( torch . tensor ( 4.6055 ) )
self . enable_attention_masks = False
self . layer_norm_hidden_state = Tru e
self . layer_norm_hidden_state = layer_norm_hidden_stat e
if layer == " hidden " :
assert layer_idx is not None
assert abs ( layer_idx ) < = self . num_layers
@ -117,7 +145,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
else :
print ( " WARNING: shape mismatch when trying to apply embedding, embedding will be ignored " , y . shape [ 0 ] , current_embeds . weight . shape [ 1 ] )
while len ( tokens_temp ) < len ( x ) :
tokens_temp + = [ self . empty_tokens [ 0 ] [ - 1 ] ]
tokens_temp + = [ self . special_tokens [ " pad " ] ]
out_tokens + = [ tokens_temp ]
n = token_dict_size
@ -142,7 +170,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
tokens = self . set_up_textual_embeddings ( tokens , backup_embeds )
tokens = torch . LongTensor ( tokens ) . to ( device )
if self . transformer . text_model . final_layer_norm . weight . dtype != torch . float32 :
if getattr ( self . transformer , self . inner_name ) . final_layer_norm . weight . dtype != torch . float32 :
precision_scope = torch . autocast
else :
precision_scope = lambda a , b : contextlib . nullcontext ( a )
@ -168,12 +196,16 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
else :
z = outputs . hidden_states [ self . layer_idx ]
if self . layer_norm_hidden_state :
z = self . transformer . text_model . final_layer_norm ( z )
z = getattr ( self . transformer , self . inner_name ) . final_layer_norm ( z )
if hasattr ( outputs , " pooler_output " ) :
pooled_output = outputs . pooler_output . float ( )
else :
pooled_output = None
pooled_output = outputs . pooler_output
if self . text_projection is not None :
if self . text_projection is not None and pooled_output is not None :
pooled_output = pooled_output . float ( ) . to ( self . text_projection . device ) @ self . text_projection . float ( )
return z . float ( ) , pooled_output . float ( )
return z . float ( ) , pooled_output
def encode ( self , tokens ) :
return self ( tokens )