|
|
|
@ -202,6 +202,14 @@ def model_lora_keys_unet(model, key_map={}):
|
|
|
|
|
key_map["lora_unet_{}".format(key_lora)] = "diffusion_model.{}".format(diffusers_keys[k]) |
|
|
|
|
return key_map |
|
|
|
|
|
|
|
|
|
def set_attr(obj, attr, value): |
|
|
|
|
attrs = attr.split(".") |
|
|
|
|
for name in attrs[:-1]: |
|
|
|
|
obj = getattr(obj, name) |
|
|
|
|
prev = getattr(obj, attrs[-1]) |
|
|
|
|
setattr(obj, attrs[-1], torch.nn.Parameter(value)) |
|
|
|
|
del prev |
|
|
|
|
|
|
|
|
|
class ModelPatcher: |
|
|
|
|
def __init__(self, model, load_device, offload_device, size=0): |
|
|
|
|
self.size = size |
|
|
|
@ -340,10 +348,11 @@ class ModelPatcher:
|
|
|
|
|
weight = model_sd[key] |
|
|
|
|
|
|
|
|
|
if key not in self.backup: |
|
|
|
|
self.backup[key] = weight.to(self.offload_device, copy=True) |
|
|
|
|
self.backup[key] = weight.to(self.offload_device) |
|
|
|
|
|
|
|
|
|
temp_weight = weight.to(torch.float32, copy=True) |
|
|
|
|
weight[:] = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype) |
|
|
|
|
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype) |
|
|
|
|
set_attr(self.model, key, out_weight) |
|
|
|
|
del temp_weight |
|
|
|
|
return self.model |
|
|
|
|
|
|
|
|
@ -439,13 +448,6 @@ class ModelPatcher:
|
|
|
|
|
|
|
|
|
|
def unpatch_model(self): |
|
|
|
|
keys = list(self.backup.keys()) |
|
|
|
|
def set_attr(obj, attr, value): |
|
|
|
|
attrs = attr.split(".") |
|
|
|
|
for name in attrs[:-1]: |
|
|
|
|
obj = getattr(obj, name) |
|
|
|
|
prev = getattr(obj, attrs[-1]) |
|
|
|
|
setattr(obj, attrs[-1], torch.nn.Parameter(value)) |
|
|
|
|
del prev |
|
|
|
|
|
|
|
|
|
for k in keys: |
|
|
|
|
set_attr(self.model, k, self.backup[k]) |
|
|
|
|