|
|
|
@ -340,7 +340,7 @@ class ModelPatcher:
|
|
|
|
|
weight = model_sd[key] |
|
|
|
|
|
|
|
|
|
if key not in self.backup: |
|
|
|
|
self.backup[key] = weight.clone() |
|
|
|
|
self.backup[key] = weight.to(self.offload_device, copy=True) |
|
|
|
|
|
|
|
|
|
temp_weight = weight.to(torch.float32, copy=True) |
|
|
|
|
weight[:] = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype) |
|
|
|
@ -367,15 +367,16 @@ class ModelPatcher:
|
|
|
|
|
else: |
|
|
|
|
weight += alpha * w1.type(weight.dtype).to(weight.device) |
|
|
|
|
elif len(v) == 4: #lora/locon |
|
|
|
|
mat1 = v[0] |
|
|
|
|
mat2 = v[1] |
|
|
|
|
mat1 = v[0].float().to(weight.device) |
|
|
|
|
mat2 = v[1].float().to(weight.device) |
|
|
|
|
if v[2] is not None: |
|
|
|
|
alpha *= v[2] / mat2.shape[0] |
|
|
|
|
if v[3] is not None: |
|
|
|
|
#locon mid weights, hopefully the math is fine because I didn't properly test it |
|
|
|
|
final_shape = [mat2.shape[1], mat2.shape[0], v[3].shape[2], v[3].shape[3]] |
|
|
|
|
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1).float(), v[3].transpose(0, 1).flatten(start_dim=1).float()).reshape(final_shape).transpose(0, 1) |
|
|
|
|
weight += (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())).reshape(weight.shape).type(weight.dtype).to(weight.device) |
|
|
|
|
mat3 = v[3].float().to(weight.device) |
|
|
|
|
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]] |
|
|
|
|
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1) |
|
|
|
|
weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype) |
|
|
|
|
elif len(v) == 8: #lokr |
|
|
|
|
w1 = v[0] |
|
|
|
|
w2 = v[1] |
|
|
|
@ -389,20 +390,24 @@ class ModelPatcher:
|
|
|
|
|
if w1 is None: |
|
|
|
|
dim = w1_b.shape[0] |
|
|
|
|
w1 = torch.mm(w1_a.float(), w1_b.float()) |
|
|
|
|
else: |
|
|
|
|
w1 = w1.float().to(weight.device) |
|
|
|
|
|
|
|
|
|
if w2 is None: |
|
|
|
|
dim = w2_b.shape[0] |
|
|
|
|
if t2 is None: |
|
|
|
|
w2 = torch.mm(w2_a.float(), w2_b.float()) |
|
|
|
|
w2 = torch.mm(w2_a.float().to(weight.device), w2_b.float().to(weight.device)) |
|
|
|
|
else: |
|
|
|
|
w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float(), w2_b.float(), w2_a.float()) |
|
|
|
|
w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2_b.float().to(weight.device), w2_a.float().to(weight.device)) |
|
|
|
|
else: |
|
|
|
|
w2 = w2.float().to(weight.device) |
|
|
|
|
|
|
|
|
|
if len(w2.shape) == 4: |
|
|
|
|
w1 = w1.unsqueeze(2).unsqueeze(2) |
|
|
|
|
if v[2] is not None and dim is not None: |
|
|
|
|
alpha *= v[2] / dim |
|
|
|
|
|
|
|
|
|
weight += alpha * torch.kron(w1.float(), w2.float()).reshape(weight.shape).type(weight.dtype).to(weight.device) |
|
|
|
|
weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype) |
|
|
|
|
else: #loha |
|
|
|
|
w1a = v[0] |
|
|
|
|
w1b = v[1] |
|
|
|
@ -413,13 +418,13 @@ class ModelPatcher:
|
|
|
|
|
if v[5] is not None: #cp decomposition |
|
|
|
|
t1 = v[5] |
|
|
|
|
t2 = v[6] |
|
|
|
|
m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.float(), w1b.float(), w1a.float()) |
|
|
|
|
m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float(), w2b.float(), w2a.float()) |
|
|
|
|
m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.float().to(weight.device), w1b.float().to(weight.device), w1a.float().to(weight.device)) |
|
|
|
|
m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2b.float().to(weight.device), w2a.float().to(weight.device)) |
|
|
|
|
else: |
|
|
|
|
m1 = torch.mm(w1a.float(), w1b.float()) |
|
|
|
|
m2 = torch.mm(w2a.float(), w2b.float()) |
|
|
|
|
m1 = torch.mm(w1a.float().to(weight.device), w1b.float().to(weight.device)) |
|
|
|
|
m2 = torch.mm(w2a.float().to(weight.device), w2b.float().to(weight.device)) |
|
|
|
|
|
|
|
|
|
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype).to(weight.device) |
|
|
|
|
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype) |
|
|
|
|
return weight |
|
|
|
|
|
|
|
|
|
def unpatch_model(self): |
|
|
|
|