diff --git a/comfy/sd.py b/comfy/sd.py index a7887a82..c5314da7 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -376,7 +376,10 @@ class ModelPatcher: mat3 = v[3].float().to(weight.device) final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]] mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1) - weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype) + try: + weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype) + except Exception as e: + print("ERROR", key, e) elif len(v) == 8: #lokr w1 = v[0] w2 = v[1] @@ -407,7 +410,10 @@ class ModelPatcher: if v[2] is not None and dim is not None: alpha *= v[2] / dim - weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype) + try: + weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype) + except Exception as e: + print("ERROR", key, e) else: #loha w1a = v[0] w1b = v[1] @@ -424,7 +430,11 @@ class ModelPatcher: m1 = torch.mm(w1a.float().to(weight.device), w1b.float().to(weight.device)) m2 = torch.mm(w2a.float().to(weight.device), w2b.float().to(weight.device)) - weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype) + try: + weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype) + except Exception as e: + print("ERROR", key, e) + return weight def unpatch_model(self):