|
|
@ -89,8 +89,7 @@ LORA_UNET_MAP_RESNET = { |
|
|
|
"skip_connection": "resnets_{}_conv_shortcut" |
|
|
|
"skip_connection": "resnets_{}_conv_shortcut" |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
def load_lora(path, to_load): |
|
|
|
def load_lora(lora, to_load): |
|
|
|
lora = utils.load_torch_file(path, safe_load=True) |
|
|
|
|
|
|
|
patch_dict = {} |
|
|
|
patch_dict = {} |
|
|
|
loaded_keys = set() |
|
|
|
loaded_keys = set() |
|
|
|
for x in to_load: |
|
|
|
for x in to_load: |
|
|
@ -501,10 +500,10 @@ class ModelPatcher: |
|
|
|
|
|
|
|
|
|
|
|
self.backup = {} |
|
|
|
self.backup = {} |
|
|
|
|
|
|
|
|
|
|
|
def load_lora_for_models(model, clip, lora_path, strength_model, strength_clip): |
|
|
|
def load_lora_for_models(model, clip, lora, strength_model, strength_clip): |
|
|
|
key_map = model_lora_keys(model.model) |
|
|
|
key_map = model_lora_keys(model.model) |
|
|
|
key_map = model_lora_keys(clip.cond_stage_model, key_map) |
|
|
|
key_map = model_lora_keys(clip.cond_stage_model, key_map) |
|
|
|
loaded = load_lora(lora_path, key_map) |
|
|
|
loaded = load_lora(lora, key_map) |
|
|
|
new_modelpatcher = model.clone() |
|
|
|
new_modelpatcher = model.clone() |
|
|
|
k = new_modelpatcher.add_patches(loaded, strength_model) |
|
|
|
k = new_modelpatcher.add_patches(loaded, strength_model) |
|
|
|
new_clip = clip.clone() |
|
|
|
new_clip = clip.clone() |
|
|
|