|
|
|
@ -174,40 +174,41 @@ class ModelPatcher:
|
|
|
|
|
sd.pop(k) |
|
|
|
|
return sd |
|
|
|
|
|
|
|
|
|
def patch_model(self, device_to=None): |
|
|
|
|
def patch_model(self, device_to=None, patch_weights=True): |
|
|
|
|
for k in self.object_patches: |
|
|
|
|
old = getattr(self.model, k) |
|
|
|
|
if k not in self.object_patches_backup: |
|
|
|
|
self.object_patches_backup[k] = old |
|
|
|
|
setattr(self.model, k, self.object_patches[k]) |
|
|
|
|
|
|
|
|
|
model_sd = self.model_state_dict() |
|
|
|
|
for key in self.patches: |
|
|
|
|
if key not in model_sd: |
|
|
|
|
print("could not patch. key doesn't exist in model:", key) |
|
|
|
|
continue |
|
|
|
|
if patch_weights: |
|
|
|
|
model_sd = self.model_state_dict() |
|
|
|
|
for key in self.patches: |
|
|
|
|
if key not in model_sd: |
|
|
|
|
print("could not patch. key doesn't exist in model:", key) |
|
|
|
|
continue |
|
|
|
|
|
|
|
|
|
weight = model_sd[key] |
|
|
|
|
weight = model_sd[key] |
|
|
|
|
|
|
|
|
|
inplace_update = self.weight_inplace_update |
|
|
|
|
inplace_update = self.weight_inplace_update |
|
|
|
|
|
|
|
|
|
if key not in self.backup: |
|
|
|
|
self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update) |
|
|
|
|
if key not in self.backup: |
|
|
|
|
self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update) |
|
|
|
|
|
|
|
|
|
if device_to is not None: |
|
|
|
|
temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True) |
|
|
|
|
else: |
|
|
|
|
temp_weight = weight.to(torch.float32, copy=True) |
|
|
|
|
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype) |
|
|
|
|
if inplace_update: |
|
|
|
|
comfy.utils.copy_to_param(self.model, key, out_weight) |
|
|
|
|
else: |
|
|
|
|
comfy.utils.set_attr(self.model, key, out_weight) |
|
|
|
|
del temp_weight |
|
|
|
|
if device_to is not None: |
|
|
|
|
temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True) |
|
|
|
|
else: |
|
|
|
|
temp_weight = weight.to(torch.float32, copy=True) |
|
|
|
|
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype) |
|
|
|
|
if inplace_update: |
|
|
|
|
comfy.utils.copy_to_param(self.model, key, out_weight) |
|
|
|
|
else: |
|
|
|
|
comfy.utils.set_attr(self.model, key, out_weight) |
|
|
|
|
del temp_weight |
|
|
|
|
|
|
|
|
|
if device_to is not None: |
|
|
|
|
self.model.to(device_to) |
|
|
|
|
self.current_device = device_to |
|
|
|
|
if device_to is not None: |
|
|
|
|
self.model.to(device_to) |
|
|
|
|
self.current_device = device_to |
|
|
|
|
|
|
|
|
|
return self.model |
|
|
|
|
|
|
|
|
|