|
|
@ -7,6 +7,18 @@ import uuid |
|
|
|
import comfy.utils |
|
|
|
import comfy.utils |
|
|
|
import comfy.model_management |
|
|
|
import comfy.model_management |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def apply_weight_decompose(dora_scale, weight): |
|
|
|
|
|
|
|
weight_norm = ( |
|
|
|
|
|
|
|
weight.transpose(0, 1) |
|
|
|
|
|
|
|
.reshape(weight.shape[1], -1) |
|
|
|
|
|
|
|
.norm(dim=1, keepdim=True) |
|
|
|
|
|
|
|
.reshape(weight.shape[1], *[1] * (weight.dim() - 1)) |
|
|
|
|
|
|
|
.transpose(0, 1) |
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return weight * (dora_scale / weight_norm) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ModelPatcher: |
|
|
|
class ModelPatcher: |
|
|
|
def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False): |
|
|
|
def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False): |
|
|
|
self.size = size |
|
|
|
self.size = size |
|
|
@ -309,6 +321,7 @@ class ModelPatcher: |
|
|
|
elif patch_type == "lora": #lora/locon |
|
|
|
elif patch_type == "lora": #lora/locon |
|
|
|
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32) |
|
|
|
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32) |
|
|
|
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32) |
|
|
|
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32) |
|
|
|
|
|
|
|
dora_scale = v[4] |
|
|
|
if v[2] is not None: |
|
|
|
if v[2] is not None: |
|
|
|
alpha *= v[2] / mat2.shape[0] |
|
|
|
alpha *= v[2] / mat2.shape[0] |
|
|
|
if v[3] is not None: |
|
|
|
if v[3] is not None: |
|
|
@ -318,6 +331,8 @@ class ModelPatcher: |
|
|
|
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1) |
|
|
|
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1) |
|
|
|
try: |
|
|
|
try: |
|
|
|
weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype) |
|
|
|
weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype) |
|
|
|
|
|
|
|
if dora_scale is not None: |
|
|
|
|
|
|
|
weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight) |
|
|
|
except Exception as e: |
|
|
|
except Exception as e: |
|
|
|
logging.error("ERROR {} {} {}".format(patch_type, key, e)) |
|
|
|
logging.error("ERROR {} {} {}".format(patch_type, key, e)) |
|
|
|
elif patch_type == "lokr": |
|
|
|
elif patch_type == "lokr": |
|
|
@ -328,6 +343,7 @@ class ModelPatcher: |
|
|
|
w2_a = v[5] |
|
|
|
w2_a = v[5] |
|
|
|
w2_b = v[6] |
|
|
|
w2_b = v[6] |
|
|
|
t2 = v[7] |
|
|
|
t2 = v[7] |
|
|
|
|
|
|
|
dora_scale = v[8] |
|
|
|
dim = None |
|
|
|
dim = None |
|
|
|
|
|
|
|
|
|
|
|
if w1 is None: |
|
|
|
if w1 is None: |
|
|
@ -357,6 +373,8 @@ class ModelPatcher: |
|
|
|
|
|
|
|
|
|
|
|
try: |
|
|
|
try: |
|
|
|
weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype) |
|
|
|
weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype) |
|
|
|
|
|
|
|
if dora_scale is not None: |
|
|
|
|
|
|
|
weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight) |
|
|
|
except Exception as e: |
|
|
|
except Exception as e: |
|
|
|
logging.error("ERROR {} {} {}".format(patch_type, key, e)) |
|
|
|
logging.error("ERROR {} {} {}".format(patch_type, key, e)) |
|
|
|
elif patch_type == "loha": |
|
|
|
elif patch_type == "loha": |
|
|
@ -366,6 +384,7 @@ class ModelPatcher: |
|
|
|
alpha *= v[2] / w1b.shape[0] |
|
|
|
alpha *= v[2] / w1b.shape[0] |
|
|
|
w2a = v[3] |
|
|
|
w2a = v[3] |
|
|
|
w2b = v[4] |
|
|
|
w2b = v[4] |
|
|
|
|
|
|
|
dora_scale = v[7] |
|
|
|
if v[5] is not None: #cp decomposition |
|
|
|
if v[5] is not None: #cp decomposition |
|
|
|
t1 = v[5] |
|
|
|
t1 = v[5] |
|
|
|
t2 = v[6] |
|
|
|
t2 = v[6] |
|
|
@ -386,12 +405,16 @@ class ModelPatcher: |
|
|
|
|
|
|
|
|
|
|
|
try: |
|
|
|
try: |
|
|
|
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype) |
|
|
|
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype) |
|
|
|
|
|
|
|
if dora_scale is not None: |
|
|
|
|
|
|
|
weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight) |
|
|
|
except Exception as e: |
|
|
|
except Exception as e: |
|
|
|
logging.error("ERROR {} {} {}".format(patch_type, key, e)) |
|
|
|
logging.error("ERROR {} {} {}".format(patch_type, key, e)) |
|
|
|
elif patch_type == "glora": |
|
|
|
elif patch_type == "glora": |
|
|
|
if v[4] is not None: |
|
|
|
if v[4] is not None: |
|
|
|
alpha *= v[4] / v[0].shape[0] |
|
|
|
alpha *= v[4] / v[0].shape[0] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dora_scale = v[5] |
|
|
|
|
|
|
|
|
|
|
|
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32) |
|
|
|
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32) |
|
|
|
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32) |
|
|
|
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32) |
|
|
|
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32) |
|
|
|
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32) |
|
|
@ -399,6 +422,8 @@ class ModelPatcher: |
|
|
|
|
|
|
|
|
|
|
|
try: |
|
|
|
try: |
|
|
|
weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype) |
|
|
|
weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype) |
|
|
|
|
|
|
|
if dora_scale is not None: |
|
|
|
|
|
|
|
weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight) |
|
|
|
except Exception as e: |
|
|
|
except Exception as e: |
|
|
|
logging.error("ERROR {} {} {}".format(patch_type, key, e)) |
|
|
|
logging.error("ERROR {} {} {}".format(patch_type, key, e)) |
|
|
|
else: |
|
|
|
else: |
|
|
|