From 94a7c895f41944d60fc3f99355064fac8347b006 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 23 Mar 2023 03:40:12 -0400 Subject: [PATCH] Add loha support. --- README.md | 2 +- comfy/sd.py | 52 ++++++++++++++++++++++++++++++++++++++-------------- 2 files changed, 39 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index d83174e3..4ff03133 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin - Works even if you don't have a GPU with: ```--cpu``` (slow) - Can load both ckpt and safetensors models/checkpoints. Standalone VAEs and CLIP models. - Embeddings/Textual inversion -- [Loras (regular and locon)](https://comfyanonymous.github.io/ComfyUI_examples/lora/) +- [Loras (regular, locon and loha)](https://comfyanonymous.github.io/ComfyUI_examples/lora/) - Loading full workflows (with seeds) from generated PNG files. - Saving/Loading workflows as Json files. - Nodes interface can be used to create complex workflows like one for [Hires fix](https://comfyanonymous.github.io/ComfyUI_examples/2_pass_txt2img/) or much more advanced ones. diff --git a/comfy/sd.py b/comfy/sd.py index b344cbec..714fa66b 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -126,15 +126,17 @@ def load_lora(path, to_load): patch_dict = {} loaded_keys = set() for x in to_load: + alpha_name = "{}.alpha".format(x) + alpha = None + if alpha_name in lora.keys(): + alpha = lora[alpha_name].item() + loaded_keys.add(alpha_name) + A_name = "{}.lora_up.weight".format(x) B_name = "{}.lora_down.weight".format(x) - alpha_name = "{}.alpha".format(x) mid_name = "{}.lora_mid.weight".format(x) + if A_name in lora.keys(): - alpha = None - if alpha_name in lora.keys(): - alpha = lora[alpha_name].item() - loaded_keys.add(alpha_name) mid = None if mid_name in lora.keys(): mid = lora[mid_name] @@ -142,6 +144,18 @@ def load_lora(path, to_load): patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha, mid) loaded_keys.add(A_name) loaded_keys.add(B_name) + + hada_w1_a_name = "{}.hada_w1_a".format(x) + hada_w1_b_name = "{}.hada_w1_b".format(x) + hada_w2_a_name = "{}.hada_w2_a".format(x) + hada_w2_b_name = "{}.hada_w2_b".format(x) + if hada_w1_a_name in lora.keys(): + patch_dict[to_load[x]] = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name]) + loaded_keys.add(hada_w1_a_name) + loaded_keys.add(hada_w1_b_name) + loaded_keys.add(hada_w2_a_name) + loaded_keys.add(hada_w2_b_name) + for x in lora.keys(): if x not in loaded_keys: print("lora key not loaded", x) @@ -280,15 +294,25 @@ class ModelPatcher: self.backup[key] = weight.clone() alpha = p[0] - mat1 = v[0] - mat2 = v[1] - if v[2] is not None: - alpha *= v[2] / mat2.shape[0] - if v[3] is not None: - #locon mid weights, hopefully the math is fine because I didn't properly test it - final_shape = [mat2.shape[1], mat2.shape[0], v[3].shape[2], v[3].shape[3]] - mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1).float(), v[3].transpose(0, 1).flatten(start_dim=1).float()).reshape(final_shape).transpose(0, 1) - weight += (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())).reshape(weight.shape).type(weight.dtype).to(weight.device) + + if len(v) == 4: #lora/locon + mat1 = v[0] + mat2 = v[1] + if v[2] is not None: + alpha *= v[2] / mat2.shape[0] + if v[3] is not None: + #locon mid weights, hopefully the math is fine because I didn't properly test it + final_shape = [mat2.shape[1], mat2.shape[0], v[3].shape[2], v[3].shape[3]] + mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1).float(), v[3].transpose(0, 1).flatten(start_dim=1).float()).reshape(final_shape).transpose(0, 1) + weight += (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())).reshape(weight.shape).type(weight.dtype).to(weight.device) + else: #loha + w1a = v[0] + w1b = v[1] + if v[2] is not None: + alpha *= v[2] / w1b.shape[0] + w2a = v[3] + w2b = v[4] + weight += (alpha * torch.mm(w1a.float(), w1b.float()) * torch.mm(w2a.float(), w2b.float())).reshape(weight.shape).type(weight.dtype).to(weight.device) return self.model def unpatch_model(self): model_sd = self.model.state_dict()