You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
110 lines
4.1 KiB
110 lines
4.1 KiB
import comfy.utils |
|
import folder_paths |
|
import torch |
|
|
|
def load_hypernetwork_patch(path, strength):
    """Build an attention patch from a hypernetwork checkpoint file.

    The checkpoint is read with ``comfy.utils.load_torch_file``. Entries whose
    key converts to an integer are treated as per-dimension weights: each such
    entry is indexed with ``0`` and ``1`` to obtain two weight mappings, and
    one ``torch.nn.Sequential`` is built from each. The pair is stored under
    the integer dimension.

    Args:
        path: Location of the checkpoint, passed straight to
            ``comfy.utils.load_torch_file``.
        strength: Multiplier applied to the hypernetwork output before it is
            added to ``k`` and ``v``.

    Returns:
        A callable patch object ``patch(q, k, v, extra_options) -> (q, k, v)``
        that also exposes ``to(device)``, or ``None`` when the checkpoint
        names an activation function that is not supported here.
    """
    sd = comfy.utils.load_torch_file(path, safe_load=True)
    activation_func = sd.get('activation_func', 'linear')
    is_layer_norm = sd.get('is_layer_norm', False)
    use_dropout = sd.get('use_dropout', False)
    activate_output = sd.get('activate_output', False)
    last_layer_dropout = sd.get('last_layer_dropout', False)

    valid_activation = {
        "linear": torch.nn.Identity,
        "relu": torch.nn.ReLU,
        "leakyrelu": torch.nn.LeakyReLU,
        "elu": torch.nn.ELU,
        "swish": torch.nn.Hardswish,
        "tanh": torch.nn.Tanh,
        "sigmoid": torch.nn.Sigmoid,
        "softsign": torch.nn.Softsign,
    }

    if activation_func not in valid_activation:
        print("Unsupported Hypernetwork format, if you report it I might implement it.", path, " ", activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout)
        return None

    weight_suffix = ".weight"
    out = {}

    for d in sd:
        # Only keys that convert to an integer describe a dimension; the rest
        # are the metadata entries read above (or anything else) and are skipped.
        try:
            dim = int(d)
        except (TypeError, ValueError, OverflowError):
            continue

        output = []
        for index in (0, 1):
            # Index with the original key `d`, not the converted `dim`, so a
            # checkpoint whose keys are numeric strings does not raise KeyError.
            attn_weights = sd[d][index]

            # NOTE(review): every key ending in ".weight" is treated as a
            # Linear layer. If checkpoints saved with is_layer_norm also store
            # LayerNorm parameters under ".weight" keys, they would be misread
            # here -- confirm against the file format.
            linears = [
                key[:-len(weight_suffix)]
                for key in attn_weights.keys()
                if key.endswith(weight_suffix)
            ]
            last_index = len(linears) - 1
            layers = []

            for i, lin_name in enumerate(linears):
                last_layer = (i == last_index)
                penultimate_layer = (i == last_index - 1)

                lin_weight = attn_weights['{}.weight'.format(lin_name)]
                lin_bias = attn_weights['{}.bias'.format(lin_name)]
                # The weight is indexed as (out_features, in_features).
                layer = torch.nn.Linear(lin_weight.shape[1], lin_weight.shape[0])
                layer.load_state_dict({"weight": lin_weight, "bias": lin_bias})
                layers.append(layer)

                # An activation follows every layer except the last one,
                # unless the checkpoint asks for the output to be activated too.
                if activation_func != "linear":
                    if (not last_layer) or activate_output:
                        layers.append(valid_activation[activation_func]())
                if is_layer_norm:
                    # NOTE(review): this LayerNorm keeps its default
                    # initialization; no weights are loaded into it.
                    layers.append(torch.nn.LayerNorm(lin_weight.shape[0]))
                if use_dropout:
                    # Never after the last layer; after the penultimate layer
                    # only when last_layer_dropout is set.
                    # NOTE(review): nothing in this function calls eval() on
                    # the modules; confirm whether dropout is meant to be
                    # active when the patch runs.
                    if (not last_layer) and (not penultimate_layer or last_layer_dropout):
                        layers.append(torch.nn.Dropout(p=0.3))

            output.append(torch.nn.Sequential(*layers))
        out[dim] = torch.nn.ModuleList(output)

    class hypernetwork_patch:
        """Callable that adds the scaled hypernetwork output to ``k`` and ``v``."""

        def __init__(self, hypernet, strength):
            # hypernet maps an integer dimension to a two-element ModuleList:
            # element 0 is applied to k, element 1 to v.
            self.hypernet = hypernet
            self.strength = strength

        def __call__(self, q, k, v, extra_options):
            """Return ``(q, k, v)`` with ``k`` and ``v`` adjusted when the last
            dimension of ``k`` has a matching hypernetwork; ``q`` is untouched."""
            dim = k.shape[-1]
            if dim in self.hypernet:
                hn = self.hypernet[dim]
                k = k + hn[0](k) * self.strength
                v = v + hn[1](v) * self.strength

            return q, k, v

        def to(self, device):
            """Move every stored module to ``device`` and return ``self``."""
            # Only existing keys are reassigned, so the dict does not change
            # size while it is being iterated.
            for d in self.hypernet:
                self.hypernet[d] = self.hypernet[d].to(device)
            return self

    return hypernetwork_patch(out, strength)
|
|
|
class HypernetworkLoader:
    """Node that returns a clone of a model with a hypernetwork patch attached."""

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "load_hypernetwork"

    CATEGORY = "loaders"

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the required inputs: a model, a hypernetwork file name
        chosen from the "hypernetworks" folder listing, and a float strength."""
        available = folder_paths.get_filename_list("hypernetworks")
        strength_spec = {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}
        required = {
            "model": ("MODEL",),
            "hypernetwork_name": (available,),
            "strength": ("FLOAT", strength_spec),
        }
        return {"required": required}

    def load_hypernetwork(self, model, hypernetwork_name, strength):
        """Clone ``model`` and register the hypernetwork as its attn1 and attn2 patch.

        When the patch cannot be built (the loader returned ``None``), the
        unmodified clone is returned. The result is always a one-element tuple.
        """
        full_path = folder_paths.get_full_path("hypernetworks", hypernetwork_name)
        patched_model = model.clone()
        hypernet_patch = load_hypernetwork_patch(full_path, strength)

        if hypernet_patch is None:
            return (patched_model,)

        # The same patch object is registered on both attention patch slots.
        patched_model.set_model_attn1_patch(hypernet_patch)
        patched_model.set_model_attn2_patch(hypernet_patch)
        return (patched_model,)
|
|
|
# Maps each node identifier string to the class that implements it
# (presumably read by the host application's node registry -- confirm).
NODE_CLASS_MAPPINGS = dict(
    HypernetworkLoader=HypernetworkLoader,
)
|
|
|