Browse Source

dora_scale support for lora file.

pull/3157/head
comfyanonymous 8 months ago
parent
commit
ae77590b4e
  1. 14
      comfy/lora.py
  2. 25
      comfy/model_patcher.py

14
comfy/lora.py

@ -21,6 +21,12 @@ def load_lora(lora, to_load):
alpha = lora[alpha_name].item()
loaded_keys.add(alpha_name)
dora_scale_name = "{}.dora_scale".format(x)
dora_scale = None
if dora_scale_name in lora.keys():
dora_scale = lora[dora_scale_name]
loaded_keys.add(dora_scale_name)
regular_lora = "{}.lora_up.weight".format(x)
diffusers_lora = "{}_lora.up.weight".format(x)
transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
@ -44,7 +50,7 @@ def load_lora(lora, to_load):
if mid_name is not None and mid_name in lora.keys():
mid = lora[mid_name]
loaded_keys.add(mid_name)
patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid))
patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale))
loaded_keys.add(A_name)
loaded_keys.add(B_name)
@ -65,7 +71,7 @@ def load_lora(lora, to_load):
loaded_keys.add(hada_t1_name)
loaded_keys.add(hada_t2_name)
patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2))
patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale))
loaded_keys.add(hada_w1_a_name)
loaded_keys.add(hada_w1_b_name)
loaded_keys.add(hada_w2_a_name)
@ -117,7 +123,7 @@ def load_lora(lora, to_load):
loaded_keys.add(lokr_t2_name)
if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2))
patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale))
#glora
a1_name = "{}.a1.weight".format(x)
@ -125,7 +131,7 @@ def load_lora(lora, to_load):
b1_name = "{}.b1.weight".format(x)
b2_name = "{}.b2.weight".format(x)
if a1_name in lora:
patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha))
patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale))
loaded_keys.add(a1_name)
loaded_keys.add(a2_name)
loaded_keys.add(b1_name)

25
comfy/model_patcher.py

@ -7,6 +7,18 @@ import uuid
import comfy.utils
import comfy.model_management
def apply_weight_decompose(dora_scale, weight):
weight_norm = (
weight.transpose(0, 1)
.reshape(weight.shape[1], -1)
.norm(dim=1, keepdim=True)
.reshape(weight.shape[1], *[1] * (weight.dim() - 1))
.transpose(0, 1)
)
return weight * (dora_scale / weight_norm)
class ModelPatcher:
def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
self.size = size
@ -309,6 +321,7 @@ class ModelPatcher:
elif patch_type == "lora": #lora/locon
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32)
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32)
dora_scale = v[4]
if v[2] is not None:
alpha *= v[2] / mat2.shape[0]
if v[3] is not None:
@ -318,6 +331,8 @@ class ModelPatcher:
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
try:
weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
if dora_scale is not None:
weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "lokr":
@ -328,6 +343,7 @@ class ModelPatcher:
w2_a = v[5]
w2_b = v[6]
t2 = v[7]
dora_scale = v[8]
dim = None
if w1 is None:
@ -357,6 +373,8 @@ class ModelPatcher:
try:
weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
if dora_scale is not None:
weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "loha":
@ -366,6 +384,7 @@ class ModelPatcher:
alpha *= v[2] / w1b.shape[0]
w2a = v[3]
w2b = v[4]
dora_scale = v[7]
if v[5] is not None: #cp decomposition
t1 = v[5]
t2 = v[6]
@ -386,12 +405,16 @@ class ModelPatcher:
try:
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
if dora_scale is not None:
weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "glora":
if v[4] is not None:
alpha *= v[4] / v[0].shape[0]
dora_scale = v[5]
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32)
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32)
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
@ -399,6 +422,8 @@ class ModelPatcher:
try:
weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype)
if dora_scale is not None:
weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
else:

Loading…
Cancel
Save