|
|
|
@ -78,7 +78,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|
|
|
|
with open(textmodel_json_config) as f: |
|
|
|
|
config = json.load(f) |
|
|
|
|
|
|
|
|
|
self.transformer = model_class(config, dtype, device, comfy.ops) |
|
|
|
|
self.transformer = model_class(config, dtype, device, comfy.ops.manual_cast) |
|
|
|
|
self.num_layers = self.transformer.num_layers |
|
|
|
|
|
|
|
|
|
self.max_length = max_length |
|
|
|
@ -160,37 +160,31 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|
|
|
|
tokens = self.set_up_textual_embeddings(tokens, backup_embeds) |
|
|
|
|
tokens = torch.LongTensor(tokens).to(device) |
|
|
|
|
|
|
|
|
|
if self.transformer.dtype != torch.float32: |
|
|
|
|
precision_scope = torch.autocast |
|
|
|
|
attention_mask = None |
|
|
|
|
if self.enable_attention_masks: |
|
|
|
|
attention_mask = torch.zeros_like(tokens) |
|
|
|
|
max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1 |
|
|
|
|
for x in range(attention_mask.shape[0]): |
|
|
|
|
for y in range(attention_mask.shape[1]): |
|
|
|
|
attention_mask[x, y] = 1 |
|
|
|
|
if tokens[x, y] == max_token: |
|
|
|
|
break |
|
|
|
|
|
|
|
|
|
outputs = self.transformer(tokens, attention_mask, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state) |
|
|
|
|
self.transformer.set_input_embeddings(backup_embeds) |
|
|
|
|
|
|
|
|
|
if self.layer == "last": |
|
|
|
|
z = outputs[0] |
|
|
|
|
else: |
|
|
|
|
precision_scope = lambda a, dtype: contextlib.nullcontext(a) |
|
|
|
|
|
|
|
|
|
with precision_scope(model_management.get_autocast_device(device), dtype=torch.float32): |
|
|
|
|
attention_mask = None |
|
|
|
|
if self.enable_attention_masks: |
|
|
|
|
attention_mask = torch.zeros_like(tokens) |
|
|
|
|
max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1 |
|
|
|
|
for x in range(attention_mask.shape[0]): |
|
|
|
|
for y in range(attention_mask.shape[1]): |
|
|
|
|
attention_mask[x, y] = 1 |
|
|
|
|
if tokens[x, y] == max_token: |
|
|
|
|
break |
|
|
|
|
|
|
|
|
|
outputs = self.transformer(tokens, attention_mask, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state) |
|
|
|
|
self.transformer.set_input_embeddings(backup_embeds) |
|
|
|
|
|
|
|
|
|
if self.layer == "last": |
|
|
|
|
z = outputs[0] |
|
|
|
|
else: |
|
|
|
|
z = outputs[1] |
|
|
|
|
z = outputs[1] |
|
|
|
|
|
|
|
|
|
if outputs[2] is not None: |
|
|
|
|
pooled_output = outputs[2].float() |
|
|
|
|
else: |
|
|
|
|
pooled_output = None |
|
|
|
|
if outputs[2] is not None: |
|
|
|
|
pooled_output = outputs[2].float() |
|
|
|
|
else: |
|
|
|
|
pooled_output = None |
|
|
|
|
|
|
|
|
|
if self.text_projection is not None and pooled_output is not None: |
|
|
|
|
pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float() |
|
|
|
|
if self.text_projection is not None and pooled_output is not None: |
|
|
|
|
pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float() |
|
|
|
|
return z.float(), pooled_output |
|
|
|
|
|
|
|
|
|
def encode(self, tokens): |
|
|
|
|