|
|
|
@ -91,11 +91,13 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|
|
|
|
self.enable_attention_masks = enable_attention_masks |
|
|
|
|
|
|
|
|
|
self.layer_norm_hidden_state = layer_norm_hidden_state |
|
|
|
|
self.return_projected_pooled = True |
|
|
|
|
|
|
|
|
|
if layer == "hidden": |
|
|
|
|
assert layer_idx is not None |
|
|
|
|
assert abs(layer_idx) < self.num_layers |
|
|
|
|
self.clip_layer(layer_idx) |
|
|
|
|
self.layer_default = (self.layer, self.layer_idx) |
|
|
|
|
self.set_clip_options({"layer": layer_idx}) |
|
|
|
|
self.options_default = (self.layer, self.layer_idx, self.return_projected_pooled) |
|
|
|
|
|
|
|
|
|
def freeze(self): |
|
|
|
|
self.transformer = self.transformer.eval() |
|
|
|
@ -103,16 +105,19 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|
|
|
|
for param in self.parameters(): |
|
|
|
|
param.requires_grad = False |
|
|
|
|
|
|
|
|
|
def clip_layer(self, layer_idx): |
|
|
|
|
if abs(layer_idx) > self.num_layers: |
|
|
|
|
def set_clip_options(self, options): |
|
|
|
|
layer_idx = options.get("layer", self.layer_idx) |
|
|
|
|
self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled) |
|
|
|
|
if layer_idx is None or abs(layer_idx) > self.num_layers: |
|
|
|
|
self.layer = "last" |
|
|
|
|
else: |
|
|
|
|
self.layer = "hidden" |
|
|
|
|
self.layer_idx = layer_idx |
|
|
|
|
|
|
|
|
|
def reset_clip_layer(self): |
|
|
|
|
self.layer = self.layer_default[0] |
|
|
|
|
self.layer_idx = self.layer_default[1] |
|
|
|
|
def reset_clip_options(self): |
|
|
|
|
self.layer = self.options_default[0] |
|
|
|
|
self.layer_idx = self.options_default[1] |
|
|
|
|
self.return_projected_pooled = self.options_default[2] |
|
|
|
|
|
|
|
|
|
def set_up_textual_embeddings(self, tokens, current_embeds): |
|
|
|
|
out_tokens = [] |
|
|
|
@ -177,10 +182,12 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|
|
|
|
else: |
|
|
|
|
z = outputs[1] |
|
|
|
|
|
|
|
|
|
if outputs[2] is not None: |
|
|
|
|
pooled_output = outputs[2].float() |
|
|
|
|
else: |
|
|
|
|
pooled_output = None |
|
|
|
|
pooled_output = None |
|
|
|
|
if len(outputs) >= 3: |
|
|
|
|
if not self.return_projected_pooled and len(outputs) >= 4 and outputs[3] is not None: |
|
|
|
|
pooled_output = outputs[3].float() |
|
|
|
|
elif outputs[2] is not None: |
|
|
|
|
pooled_output = outputs[2].float() |
|
|
|
|
|
|
|
|
|
return z.float(), pooled_output |
|
|
|
|
|
|
|
|
@ -497,11 +504,11 @@ class SD1ClipModel(torch.nn.Module):
|
|
|
|
|
self.clip = "clip_{}".format(self.clip_name) |
|
|
|
|
setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs)) |
|
|
|
|
|
|
|
|
|
def clip_layer(self, layer_idx): |
|
|
|
|
getattr(self, self.clip).clip_layer(layer_idx) |
|
|
|
|
def set_clip_options(self, options): |
|
|
|
|
getattr(self, self.clip).set_clip_options(options) |
|
|
|
|
|
|
|
|
|
def reset_clip_layer(self): |
|
|
|
|
getattr(self, self.clip).reset_clip_layer() |
|
|
|
|
def reset_clip_options(self): |
|
|
|
|
getattr(self, self.clip).reset_clip_options() |
|
|
|
|
|
|
|
|
|
def encode_token_weights(self, token_weight_pairs): |
|
|
|
|
token_weight_pairs = token_weight_pairs[self.clip_name] |
|
|
|
|