@@ -24,6 +24,7 @@ class ModelPatcher:
             self.current_device = current_device
 
         self.weight_inplace_update = weight_inplace_update
+        self.model_lowvram = False
 
     def model_size(self):
         if self.size > 0:
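Note: the new `self.model_lowvram` flag records that the low-VRAM path was taken, so `unpatch_model()` (last hunk) knows it must also strip the forward-time wrappers. A hedged usage sketch of the overall lifecycle; `model`, `lora_patches`, and the device choices are placeholders, and `add_patches()` is assumed to be the usual ModelPatcher entry point for queueing weight deltas:

```python
import torch

# placeholders: a torch.nn.Module and a {param_key: delta_tensors} patch dict
patcher = ModelPatcher(model, load_device=torch.device("cuda"),
                       offload_device=torch.device("cpu"))
patcher.add_patches(lora_patches)          # queue LoRA deltas per parameter key

# patch what fits in a ~2 GiB budget now; patch the rest lazily at forward time
patcher.patch_model_lowvram(device_to=torch.device("cuda"),
                            lowvram_model_memory=2 * 1024 ** 3)
# ... run inference ...
patcher.unpatch_model()   # restores backups and, since model_lowvram is set,
                          # also removes the low-VRAM wrappers
```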
@@ -178,6 +179,27 @@ class ModelPatcher:
                 sd.pop(k)
         return sd
 
+    def patch_weight_to_device(self, key, device_to=None):
+        if key not in self.patches:
+            return
+
+        weight = comfy.utils.get_attr(self.model, key)
+
+        inplace_update = self.weight_inplace_update
+
+        if key not in self.backup:
+            self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
+
+        if device_to is not None:
+            temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
+        else:
+            temp_weight = weight.to(torch.float32, copy=True)
+        out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
+        if inplace_update:
+            comfy.utils.copy_to_param(self.model, key, out_weight)
+        else:
+            comfy.utils.set_attr_param(self.model, key, out_weight)
+
     def patch_model(self, device_to=None, patch_weights=True):
         for k in self.object_patches:
             old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
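The new helper centralizes per-parameter patching: back up the original tensor once, do the patch math on a temporary `float32` copy (optionally on `device_to`), then write the result back in the weight's storage dtype. A self-contained sketch of that pattern; `patch_param`, `backup`, and `delta` are illustrative names, not ComfyUI API:

```python
import torch

backup: dict[str, torch.Tensor] = {}

def patch_param(model: torch.nn.Module, key: str, delta: torch.Tensor) -> None:
    weight = model.get_parameter(key)
    if key not in backup:                        # keep one pristine copy
        backup[key] = weight.detach().to("cpu", copy=True)
    temp = weight.to(torch.float32, copy=True)   # patch math in fp32 for accuracy
    patched = temp + delta.to(temp.device, torch.float32)
    with torch.no_grad():                        # write back in storage dtype
        weight.copy_(patched.to(weight.dtype))
```

Doing the arithmetic in `float32` matters because many checkpoints store fp16/bf16 weights, where applying LoRA deltas directly in the storage dtype would lose precision.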
@@ -191,23 +213,7 @@ class ModelPatcher:
                     logging.warning("could not patch. key doesn't exist in model: {}".format(key))
                     continue
 
-                weight = model_sd[key]
-
-                inplace_update = self.weight_inplace_update
-
-                if key not in self.backup:
-                    self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
-
-                if device_to is not None:
-                    temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
-                else:
-                    temp_weight = weight.to(torch.float32, copy=True)
-                out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
-                if inplace_update:
-                    comfy.utils.copy_to_param(self.model, key, out_weight)
-                else:
-                    comfy.utils.set_attr_param(self.model, key, out_weight)
-                del temp_weight
+                self.patch_weight_to_device(key, device_to)
 
         if device_to is not None:
             self.model.to(device_to)
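With the helper in place, `patch_model()` shrinks to one call per key. One detail worth noting: the backup uses `copy=inplace_update` because an in-place update overwrites the parameter's existing storage (so the backup must be an independent copy), while the rebind path replaces the attribute and leaves the old tensor untouched. Simplified stand-ins for the two `comfy.utils` helpers, sketched to show that split:

```python
import torch

def copy_to_param(model: torch.nn.Module, key: str, value: torch.Tensor) -> None:
    # in-place: overwrite the parameter's storage; every live reference
    # to this tensor observes the new values
    with torch.no_grad():
        model.get_parameter(key).copy_(value)

def set_attr_param(model: torch.nn.Module, key: str, value: torch.Tensor) -> None:
    # rebind: point the attribute at a fresh Parameter; the old tensor
    # survives wherever it is still referenced (e.g. as a backup)
    obj = model
    *parents, leaf = key.split(".")
    for name in parents:
        obj = getattr(obj, name)
    setattr(obj, leaf, torch.nn.Parameter(value, requires_grad=False))
```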
@@ -215,6 +221,47 @@ class ModelPatcher:
 
         return self.model
 
+    def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0):
+        self.patch_model(device_to, patch_weights=False)
+
+        logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024)))
+        class LowVramPatch:
+            def __init__(self, key, model_patcher):
+                self.key = key
+                self.model_patcher = model_patcher
+            def __call__(self, weight):
+                return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
+
+        mem_counter = 0
+        for n, m in self.model.named_modules():
+            lowvram_weight = False
+            if hasattr(m, "comfy_cast_weights"):
+                module_mem = comfy.model_management.module_size(m)
+                if mem_counter + module_mem >= lowvram_model_memory:
+                    lowvram_weight = True
+
+            weight_key = "{}.weight".format(n)
+            bias_key = "{}.bias".format(n)
+
+            if lowvram_weight:
+                if weight_key in self.patches:
+                    m.weight_function = LowVramPatch(weight_key, self)
+                if bias_key in self.patches:
+                    m.bias_function = LowVramPatch(bias_key, self)
+
+                m.prev_comfy_cast_weights = m.comfy_cast_weights
+                m.comfy_cast_weights = True
+            else:
+                if hasattr(m, "weight"):
+                    self.patch_weight_to_device(weight_key, device_to)
+                    self.patch_weight_to_device(bias_key, device_to)
+                    m.to(device_to)
+                    mem_counter += comfy.model_management.module_size(m)
+                    logging.debug("lowvram: loaded module regularly {}".format(m))
+
+        self.model_lowvram = True
+        return self.model
+
     def calculate_weight(self, patches, weight, key):
         for p in patches:
             alpha = p[0]
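The low-VRAM strategy: walk modules in order, patch and move them to `device_to` until `lowvram_model_memory` is spent, and leave everything past the budget unpatched on the offload device. For those modules, `comfy_cast_weights = True` makes the op cast its weight to the compute device per forward pass, and `LowVramPatch` hooks that cast so `calculate_weight()` is applied to the temporary copy on the fly. A minimal sketch of the idea behind such a cast-weights layer; `CastWeightsLinear` is illustrative, ComfyUI's real versions (in `comfy.ops`) handle more cases:

```python
import torch

class CastWeightsLinear(torch.nn.Linear):
    comfy_cast_weights = True
    weight_function = None   # patch_model_lowvram installs LowVramPatch here
    bias_function = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # cast a temporary copy to the compute device/dtype...
        w = self.weight.to(device=x.device, dtype=x.dtype)
        if self.weight_function is not None:
            w = self.weight_function(w)          # ...and patch it on the fly
        b = self.bias
        if b is not None:
            b = b.to(device=x.device, dtype=x.dtype)
            if self.bias_function is not None:
                b = self.bias_function(b)
        return torch.nn.functional.linear(x, w, b)
```

The trade-off is the usual one: no extra VRAM is held for over-budget modules, at the cost of recomputing the patched weight on every forward call.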
@@ -341,6 +388,16 @@ class ModelPatcher:
         return weight
 
     def unpatch_model(self, device_to=None):
+        if self.model_lowvram:
+            for m in self.model.modules():
+                if hasattr(m, "prev_comfy_cast_weights"):
+                    m.comfy_cast_weights = m.prev_comfy_cast_weights
+                    del m.prev_comfy_cast_weights
+                m.weight_function = None
+                m.bias_function = None
+
+            self.model_lowvram = False
+
         keys = list(self.backup.keys())
 
         if self.weight_inplace_update:
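The excerpt ends mid-hunk; the code that follows iterates `keys` and writes each backup tensor back, honoring the same in-place/rebind split used when patching. A sketch of that restore loop, reusing the stand-in helpers from the earlier note:

```python
for key in keys:
    if weight_inplace_update:
        copy_to_param(model, key, backup[key])   # overwrite storage in place
    else:
        set_attr_param(model, key, backup[key])  # rebind to the saved tensor
backup.clear()
```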