diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index c27d032a..ce7180d9 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -163,13 +163,17 @@ class CrossAttentionBirchSan(nn.Module): nn.Dropout(dropout) ) - def forward(self, x, context=None, mask=None): + def forward(self, x, context=None, value=None, mask=None): h = self.heads query = self.to_q(x) context = default(context, x) key = self.to_k(context) - value = self.to_v(context) + if value is not None: + value = self.to_v(value) + else: + value = self.to_v(context) + del context, x query = query.unflatten(-1, (self.heads, -1)).transpose(1,2).flatten(end_dim=1) @@ -256,13 +260,17 @@ class CrossAttentionDoggettx(nn.Module): nn.Dropout(dropout) ) - def forward(self, x, context=None, mask=None): + def forward(self, x, context=None, value=None, mask=None): h = self.heads q_in = self.to_q(x) context = default(context, x) k_in = self.to_k(context) - v_in = self.to_v(context) + if value is not None: + v_in = self.to_v(value) + del value + else: + v_in = self.to_v(context) del context, x q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) @@ -350,13 +358,17 @@ class CrossAttention(nn.Module): nn.Dropout(dropout) ) - def forward(self, x, context=None, mask=None): + def forward(self, x, context=None, value=None, mask=None): h = self.heads q = self.to_q(x) context = default(context, x) k = self.to_k(context) - v = self.to_v(context) + if value is not None: + v = self.to_v(value) + del value + else: + v = self.to_v(context) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) @@ -402,11 +414,15 @@ class MemoryEfficientCrossAttention(nn.Module): self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) self.attention_op: Optional[Any] = None - def forward(self, x, context=None, mask=None): + def forward(self, x, context=None, value=None, mask=None): q = self.to_q(x) context = default(context, x) k = self.to_k(context) - v = self.to_v(context) + if value is not None: + v = self.to_v(value) + del value + else: + v = self.to_v(context) b, _, _ = q.shape q, k, v = map( @@ -447,11 +463,15 @@ class CrossAttentionPytorch(nn.Module): self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) self.attention_op: Optional[Any] = None - def forward(self, x, context=None, mask=None): + def forward(self, x, context=None, value=None, mask=None): q = self.to_q(x) context = default(context, x) k = self.to_k(context) - v = self.to_v(context) + if value is not None: + v = self.to_v(value) + del value + else: + v = self.to_v(context) b, _, _ = q.shape q, k, v = map( @@ -512,11 +532,25 @@ class BasicTransformerBlock(nn.Module): transformer_patches = {} n = self.norm1(x) + if self.disable_self_attn: + context_attn1 = context + else: + context_attn1 = None + value_attn1 = None + + if "attn1_patch" in transformer_patches: + patch = transformer_patches["attn1_patch"] + if context_attn1 is None: + context_attn1 = n + value_attn1 = context_attn1 + for p in patch: + n, context_attn1, value_attn1 = p(current_index, n, context_attn1, value_attn1) + if "tomesd" in transformer_options: m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"]) - n = u(self.attn1(m(n), context=context if self.disable_self_attn else None)) + n = u(self.attn1(m(n), context=context_attn1, value=value_attn1)) else: - n = self.attn1(n, context=context if self.disable_self_attn else None) + n = self.attn1(n, context=context_attn1, value=value_attn1) x += n if "middle_patch" in transformer_patches: @@ -525,7 +559,16 @@ class BasicTransformerBlock(nn.Module): x = p(current_index, x) n = self.norm2(x) - n = self.attn2(n, context=context) + + context_attn2 = context + value_attn2 = None + if "attn2_patch" in transformer_patches: + patch = transformer_patches["attn2_patch"] + value_attn2 = context_attn2 + for p in patch: + n, context_attn2, value_attn2 = p(current_index, n, context_attn2, value_attn2) + + n = self.attn2(n, context=context_attn2, value=value_attn2) x += n x = self.ff(self.norm3(x)) + x diff --git a/comfy/model_management.py b/comfy/model_management.py index a0d1313d..6e3a0353 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -133,6 +133,7 @@ def unload_model(): #never unload models from GPU on high vram if vram_state != VRAMState.HIGH_VRAM: current_loaded_model.model.cpu() + current_loaded_model.model_patches_to("cpu") current_loaded_model.unpatch_model() current_loaded_model = None @@ -156,6 +157,8 @@ def load_model_gpu(model): except Exception as e: model.unpatch_model() raise e + + model.model_patches_to(get_torch_device()) current_loaded_model = model if vram_state == VRAMState.CPU: pass diff --git a/comfy/samplers.py b/comfy/samplers.py index 15527224..b860f25f 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -197,7 +197,15 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con transformer_options = model_options['transformer_options'].copy() if patches is not None: - transformer_options["patches"] = patches + if "patches" in transformer_options: + cur_patches = transformer_options["patches"].copy() + for p in patches: + if p in cur_patches: + cur_patches[p] = cur_patches[p] + patches[p] + else: + cur_patches[p] = patches[p] + else: + transformer_options["patches"] = patches c['transformer_options'] = transformer_options diff --git a/comfy/sd.py b/comfy/sd.py index 211acd70..92dbb931 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -254,6 +254,29 @@ class ModelPatcher: def set_model_sampler_cfg_function(self, sampler_cfg_function): self.model_options["sampler_cfg_function"] = sampler_cfg_function + + def set_model_patch(self, patch, name): + to = self.model_options["transformer_options"] + if "patches" not in to: + to["patches"] = {} + to["patches"][name] = to["patches"].get(name, []) + [patch] + + def set_model_attn1_patch(self, patch): + self.set_model_patch(patch, "attn1_patch") + + def set_model_attn2_patch(self, patch): + self.set_model_patch(patch, "attn2_patch") + + def model_patches_to(self, device): + to = self.model_options["transformer_options"] + if "patches" in to: + patches = to["patches"] + for name in patches: + patch_list = patches[name] + for i in range(len(patch_list)): + if hasattr(patch_list[i], "to"): + patch_list[i] = patch_list[i].to(device) + def model_dtype(self): return self.model.diffusion_model.dtype diff --git a/comfy/utils.py b/comfy/utils.py index 0380b91d..68f93403 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1,11 +1,14 @@ import torch -def load_torch_file(ckpt): +def load_torch_file(ckpt, safe_load=False): if ckpt.lower().endswith(".safetensors"): import safetensors.torch sd = safetensors.torch.load_file(ckpt, device="cpu") else: - pl_sd = torch.load(ckpt, map_location="cpu") + if safe_load: + pl_sd = torch.load(ckpt, map_location="cpu", weights_only=True) + else: + pl_sd = torch.load(ckpt, map_location="cpu") if "global_step" in pl_sd: print(f"Global Step: {pl_sd['global_step']}") if "state_dict" in pl_sd: diff --git a/comfy_extras/nodes_hypernetwork.py b/comfy_extras/nodes_hypernetwork.py new file mode 100644 index 00000000..db2f8695 --- /dev/null +++ b/comfy_extras/nodes_hypernetwork.py @@ -0,0 +1,87 @@ +import comfy.utils +import folder_paths +import torch + +def load_hypernetwork_patch(path, strength): + sd = comfy.utils.load_torch_file(path, safe_load=True) + activation_func = sd.get('activation_func', 'linear') + is_layer_norm = sd.get('is_layer_norm', False) + use_dropout = sd.get('use_dropout', False) + activate_output = sd.get('activate_output', False) + last_layer_dropout = sd.get('last_layer_dropout', False) + + if activation_func != 'linear' or is_layer_norm != False or use_dropout != False or activate_output != False or last_layer_dropout != False: + print("Unsupported Hypernetwork format, if you report it I might implement it.", path, " ", activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout) + return None + + out = {} + + for d in sd: + try: + dim = int(d) + except: + continue + + output = [] + for index in [0, 1]: + attn_weights = sd[dim][index] + keys = attn_weights.keys() + + linears = filter(lambda a: a.endswith(".weight"), keys) + linears = sorted(list(map(lambda a: a[:-len(".weight")], linears))) + layers = [] + + for lin_name in linears: + lin_weight = attn_weights['{}.weight'.format(lin_name)] + lin_bias = attn_weights['{}.bias'.format(lin_name)] + layer = torch.nn.Linear(lin_weight.shape[1], lin_weight.shape[0]) + layer.load_state_dict({"weight": lin_weight, "bias": lin_bias}) + layers += [layer] + + output.append(torch.nn.Sequential(*layers)) + out[dim] = torch.nn.ModuleList(output) + + class hypernetwork_patch: + def __init__(self, hypernet, strength): + self.hypernet = hypernet + self.strength = strength + def __call__(self, current_index, q, k, v): + dim = k.shape[-1] + if dim in self.hypernet: + hn = self.hypernet[dim] + k = k + hn[0](k) * self.strength + v = v + hn[1](v) * self.strength + + return q, k, v + + def to(self, device): + for d in self.hypernet.keys(): + self.hypernet[d] = self.hypernet[d].to(device) + return self + + return hypernetwork_patch(out, strength) + +class HypernetworkLoader: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "hypernetwork_name": (folder_paths.get_filename_list("hypernetworks"), ), + "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "load_hypernetwork" + + CATEGORY = "_for_testing" + + def load_hypernetwork(self, model, hypernetwork_name, strength): + hypernetwork_path = folder_paths.get_full_path("hypernetworks", hypernetwork_name) + model_hypernetwork = model.clone() + patch = load_hypernetwork_patch(hypernetwork_path, strength) + if patch is not None: + model_hypernetwork.set_model_attn1_patch(patch) + model_hypernetwork.set_model_attn2_patch(patch) + return (model_hypernetwork,) + +NODE_CLASS_MAPPINGS = { + "HypernetworkLoader": HypernetworkLoader +} diff --git a/folder_paths.py b/folder_paths.py index 3c4ad371..bb0d6552 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -32,6 +32,7 @@ folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_m folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes")], []) +folder_names_and_paths["hypernetworks"] = ([os.path.join(models_dir, "hypernetworks")], supported_pt_extensions) output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output") temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp") diff --git a/models/hypernetworks/put_hypernetworks_here b/models/hypernetworks/put_hypernetworks_here new file mode 100644 index 00000000..e69de29b diff --git a/nodes.py b/nodes.py index 48c3ee9c..6ca73fa0 100644 --- a/nodes.py +++ b/nodes.py @@ -1268,6 +1268,7 @@ def load_custom_nodes(): def init_custom_nodes(): load_custom_nodes() + load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_hypernetwork.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py"))