|
|
|
@ -58,6 +58,7 @@ class ModelPatcher:
|
|
|
|
|
|
|
|
|
|
self.weight_inplace_update = weight_inplace_update |
|
|
|
|
self.model_lowvram = False |
|
|
|
|
self.lowvram_patch_counter = 0 |
|
|
|
|
self.patches_uuid = uuid.uuid4() |
|
|
|
|
|
|
|
|
|
def model_size(self): |
|
|
|
@ -284,6 +285,7 @@ class ModelPatcher:
|
|
|
|
|
return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key) |
|
|
|
|
|
|
|
|
|
mem_counter = 0 |
|
|
|
|
patch_counter = 0 |
|
|
|
|
for n, m in self.model.named_modules(): |
|
|
|
|
lowvram_weight = False |
|
|
|
|
if hasattr(m, "comfy_cast_weights"): |
|
|
|
@ -300,11 +302,13 @@ class ModelPatcher:
|
|
|
|
|
self.patch_weight_to_device(weight_key) |
|
|
|
|
else: |
|
|
|
|
m.weight_function = LowVramPatch(weight_key, self) |
|
|
|
|
patch_counter += 1 |
|
|
|
|
if bias_key in self.patches: |
|
|
|
|
if force_patch_weights: |
|
|
|
|
self.patch_weight_to_device(bias_key) |
|
|
|
|
else: |
|
|
|
|
m.bias_function = LowVramPatch(bias_key, self) |
|
|
|
|
patch_counter += 1 |
|
|
|
|
|
|
|
|
|
m.prev_comfy_cast_weights = m.comfy_cast_weights |
|
|
|
|
m.comfy_cast_weights = True |
|
|
|
@ -317,6 +321,7 @@ class ModelPatcher:
|
|
|
|
|
logging.debug("lowvram: loaded module regularly {}".format(m)) |
|
|
|
|
|
|
|
|
|
self.model_lowvram = True |
|
|
|
|
self.lowvram_patch_counter = patch_counter |
|
|
|
|
return self.model |
|
|
|
|
|
|
|
|
|
def calculate_weight(self, patches, weight, key): |
|
|
|
@ -468,6 +473,7 @@ class ModelPatcher:
|
|
|
|
|
m.bias_function = None |
|
|
|
|
|
|
|
|
|
self.model_lowvram = False |
|
|
|
|
self.lowvram_patch_counter = 0 |
|
|
|
|
|
|
|
|
|
keys = list(self.backup.keys()) |
|
|
|
|
|
|
|
|
|