
Merge branch 'master' of https://github.com/BlenderNeko/ComfyUI

pull/559/head
BlenderNeko (2 years ago)
commit 8d2de420d3
Changed files (lines changed):

1. comfy/ldm/modules/attention.py (80)
2. comfy/model_management.py (3)
3. comfy/samplers.py (10)
4. comfy/sd.py (23)
5. comfy/utils.py (7)
6. comfy_extras/nodes_hypernetwork.py (87)
7. folder_paths.py (1)
8. models/hypernetworks/put_hypernetworks_here (0)
9. nodes.py (1)
10. web/scripts/api.js (4)

comfy/ldm/modules/attention.py (80 lines changed)

@@ -163,13 +163,17 @@ class CrossAttentionBirchSan(nn.Module):
             nn.Dropout(dropout)
         )

-    def forward(self, x, context=None, mask=None):
+    def forward(self, x, context=None, value=None, mask=None):
         h = self.heads

         query = self.to_q(x)
         context = default(context, x)
         key = self.to_k(context)
-        value = self.to_v(context)
+        if value is not None:
+            value = self.to_v(value)
+        else:
+            value = self.to_v(context)
         del context, x

         query = query.unflatten(-1, (self.heads, -1)).transpose(1,2).flatten(end_dim=1)
@@ -256,13 +260,17 @@ class CrossAttentionDoggettx(nn.Module):
             nn.Dropout(dropout)
         )

-    def forward(self, x, context=None, mask=None):
+    def forward(self, x, context=None, value=None, mask=None):
         h = self.heads

         q_in = self.to_q(x)
         context = default(context, x)
         k_in = self.to_k(context)
-        v_in = self.to_v(context)
+        if value is not None:
+            v_in = self.to_v(value)
+            del value
+        else:
+            v_in = self.to_v(context)
         del context, x

         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
@@ -350,13 +358,17 @@ class CrossAttention(nn.Module):
             nn.Dropout(dropout)
         )

-    def forward(self, x, context=None, mask=None):
+    def forward(self, x, context=None, value=None, mask=None):
         h = self.heads

         q = self.to_q(x)
         context = default(context, x)
         k = self.to_k(context)
-        v = self.to_v(context)
+        if value is not None:
+            v = self.to_v(value)
+            del value
+        else:
+            v = self.to_v(context)

         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
@@ -402,11 +414,15 @@ class MemoryEfficientCrossAttention(nn.Module):
         self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
         self.attention_op: Optional[Any] = None

-    def forward(self, x, context=None, mask=None):
+    def forward(self, x, context=None, value=None, mask=None):
         q = self.to_q(x)
         context = default(context, x)
         k = self.to_k(context)
-        v = self.to_v(context)
+        if value is not None:
+            v = self.to_v(value)
+            del value
+        else:
+            v = self.to_v(context)

         b, _, _ = q.shape
         q, k, v = map(
@@ -447,19 +463,19 @@ class CrossAttentionPytorch(nn.Module):
         self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
         self.attention_op: Optional[Any] = None

-    def forward(self, x, context=None, mask=None):
+    def forward(self, x, context=None, value=None, mask=None):
         q = self.to_q(x)
         context = default(context, x)
         k = self.to_k(context)
-        v = self.to_v(context)
+        if value is not None:
+            v = self.to_v(value)
+            del value
+        else:
+            v = self.to_v(context)

         b, _, _ = q.shape
         q, k, v = map(
-            lambda t: t.unsqueeze(3)
-            .reshape(b, t.shape[1], self.heads, self.dim_head)
-            .permute(0, 2, 1, 3)
-            .reshape(b * self.heads, t.shape[1], self.dim_head)
-            .contiguous(),
+            lambda t: t.view(b, -1, self.heads, self.dim_head).transpose(1, 2),
             (q, k, v),
         )
@@ -468,10 +484,7 @@ class CrossAttentionPytorch(nn.Module):
         if exists(mask):
             raise NotImplementedError

         out = (
-            out.unsqueeze(0)
-            .reshape(b, self.heads, out.shape[1], self.dim_head)
-            .permute(0, 2, 1, 3)
-            .reshape(b, out.shape[1], self.heads * self.dim_head)
+            out.transpose(1, 2).reshape(b, -1, self.heads * self.dim_head)
         )

         return self.to_out(out)
@@ -519,11 +532,25 @@ class BasicTransformerBlock(nn.Module):
             transformer_patches = {}

         n = self.norm1(x)
+        if self.disable_self_attn:
+            context_attn1 = context
+        else:
+            context_attn1 = None
+        value_attn1 = None
+
+        if "attn1_patch" in transformer_patches:
+            patch = transformer_patches["attn1_patch"]
+            if context_attn1 is None:
+                context_attn1 = n
+            value_attn1 = context_attn1
+            for p in patch:
+                n, context_attn1, value_attn1 = p(current_index, n, context_attn1, value_attn1)
+
         if "tomesd" in transformer_options:
             m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"])
-            n = u(self.attn1(m(n), context=context if self.disable_self_attn else None))
+            n = u(self.attn1(m(n), context=context_attn1, value=value_attn1))
         else:
-            n = self.attn1(n, context=context if self.disable_self_attn else None)
+            n = self.attn1(n, context=context_attn1, value=value_attn1)

         x += n
         if "middle_patch" in transformer_patches:
@@ -532,7 +559,16 @@ class BasicTransformerBlock(nn.Module):
                 x = p(current_index, x)

         n = self.norm2(x)
-        n = self.attn2(n, context=context)
+
+        context_attn2 = context
+        value_attn2 = None
+        if "attn2_patch" in transformer_patches:
+            patch = transformer_patches["attn2_patch"]
+            value_attn2 = context_attn2
+            for p in patch:
+                n, context_attn2, value_attn2 = p(current_index, n, context_attn2, value_attn2)
+
+        n = self.attn2(n, context=context_attn2, value=value_attn2)
         x += n
         x = self.ff(self.norm3(x)) + x
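
Taken together, a patch registered under "attn1_patch" or "attn2_patch" is any callable invoked as p(current_index, n, context, value) and returning the possibly modified triple, where the tensors are still pre-projection (n feeds to_q, context feeds to_k, value feeds to_v). A minimal sketch of a conforming patch, assuming only the protocol visible in the hunks above (the class name and scale knob are illustrative, not part of this commit):

# Sketch only: a value-scaling attention patch matching the call sites above.
class ScaleValuePatch:
    def __init__(self, scale):
        self.scale = scale  # illustrative parameter

    def __call__(self, current_index, q, k, v):
        # pass q and k through untouched; scale the tensor that will feed to_v
        return q, k, v * self.scale

Because model_patches_to() (added in comfy/sd.py below) only moves objects exposing a .to() method, a plain callable like this needs no device handling.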

comfy/model_management.py (3 lines changed)

@@ -133,6 +133,7 @@ def unload_model():
         #never unload models from GPU on high vram
         if vram_state != VRAMState.HIGH_VRAM:
             current_loaded_model.model.cpu()
+        current_loaded_model.model_patches_to("cpu")
         current_loaded_model.unpatch_model()
         current_loaded_model = None

@@ -156,6 +157,8 @@ def load_model_gpu(model):
     except Exception as e:
         model.unpatch_model()
         raise e
+
+    model.model_patches_to(get_torch_device())
     current_loaded_model = model
     if vram_state == VRAMState.CPU:
         pass

comfy/samplers.py (10 lines changed)

@@ -197,7 +197,15 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
                 transformer_options = model_options['transformer_options'].copy()

             if patches is not None:
-                transformer_options["patches"] = patches
+                if "patches" in transformer_options:
+                    cur_patches = transformer_options["patches"].copy()
+                    for p in patches:
+                        if p in cur_patches:
+                            cur_patches[p] = cur_patches[p] + patches[p]
+                        else:
+                            cur_patches[p] = patches[p]
+                else:
+                    transformer_options["patches"] = patches

             c['transformer_options'] = transformer_options
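
Note that, as rendered here, the merged cur_patches dict is never assigned back into transformer_options, so the concatenation would be discarded; later ComfyUI versions end the first branch with a write-back. A standalone sketch of the intended semantics, assuming "patches" maps patch names to lists of callables (the names and string stand-ins below are illustrative):

# Per-key list concatenation: per-conditioning patches stack on top of
# model-level ones instead of replacing them.
def merge_patches(base, extra):
    merged = base.copy()
    for name in extra:
        merged[name] = merged.get(name, []) + extra[name]
    return merged

model_level = {"attn2_patch": ["hypernet_patch"]}  # set via ModelPatcher
cond_level = {"attn2_patch": ["cond_patch"], "middle_patch": ["mid_patch"]}
print(merge_patches(model_level, cond_level))
# {'attn2_patch': ['hypernet_patch', 'cond_patch'], 'middle_patch': ['mid_patch']}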

comfy/sd.py (23 lines changed)

@@ -254,6 +254,29 @@ class ModelPatcher:
     def set_model_sampler_cfg_function(self, sampler_cfg_function):
         self.model_options["sampler_cfg_function"] = sampler_cfg_function

+    def set_model_patch(self, patch, name):
+        to = self.model_options["transformer_options"]
+        if "patches" not in to:
+            to["patches"] = {}
+        to["patches"][name] = to["patches"].get(name, []) + [patch]
+
+    def set_model_attn1_patch(self, patch):
+        self.set_model_patch(patch, "attn1_patch")
+
+    def set_model_attn2_patch(self, patch):
+        self.set_model_patch(patch, "attn2_patch")
+
+    def model_patches_to(self, device):
+        to = self.model_options["transformer_options"]
+        if "patches" in to:
+            patches = to["patches"]
+            for name in patches:
+                patch_list = patches[name]
+                for i in range(len(patch_list)):
+                    if hasattr(patch_list[i], "to"):
+                        patch_list[i] = patch_list[i].to(device)
+
     def model_dtype(self):
         return self.model.diffusion_model.dtype
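
These helpers accumulate patches per name, so several patches can stack on the same attention site, and model_patches_to() moves any patch exposing a .to() method (such as the hypernetwork_patch below) between CPU and GPU alongside the model. A hedged usage sketch (noop_patch is illustrative; model is assumed to be an existing ModelPatcher):

# Sketch: registering a do-nothing attn2 patch on a clone, leaving the
# original model untouched.
def noop_patch(current_index, q, k, v):
    return q, k, v  # a real patch would modify k and/or v here

patched = model.clone()
patched.set_model_attn2_patch(noop_patch)
# model_patches_to(device) skips noop_patch: a plain function has no .to()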

comfy/utils.py (7 lines changed)

@@ -1,11 +1,14 @@
 import torch

-def load_torch_file(ckpt):
+def load_torch_file(ckpt, safe_load=False):
     if ckpt.lower().endswith(".safetensors"):
         import safetensors.torch
         sd = safetensors.torch.load_file(ckpt, device="cpu")
     else:
-        pl_sd = torch.load(ckpt, map_location="cpu")
+        if safe_load:
+            pl_sd = torch.load(ckpt, map_location="cpu", weights_only=True)
+        else:
+            pl_sd = torch.load(ckpt, map_location="cpu")
         if "global_step" in pl_sd:
             print(f"Global Step: {pl_sd['global_step']}")
         if "state_dict" in pl_sd:

comfy_extras/nodes_hypernetwork.py (87 lines changed, new file)

@@ -0,0 +1,87 @@
+import comfy.utils
+import folder_paths
+import torch
+
+def load_hypernetwork_patch(path, strength):
+    sd = comfy.utils.load_torch_file(path, safe_load=True)
+    activation_func = sd.get('activation_func', 'linear')
+    is_layer_norm = sd.get('is_layer_norm', False)
+    use_dropout = sd.get('use_dropout', False)
+    activate_output = sd.get('activate_output', False)
+    last_layer_dropout = sd.get('last_layer_dropout', False)
+
+    if activation_func != 'linear' or is_layer_norm != False or use_dropout != False or activate_output != False or last_layer_dropout != False:
+        print("Unsupported Hypernetwork format, if you report it I might implement it.", path, " ", activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout)
+        return None
+
+    out = {}
+
+    for d in sd:
+        try:
+            dim = int(d)
+        except:
+            continue
+
+        output = []
+        for index in [0, 1]:
+            attn_weights = sd[dim][index]
+            keys = attn_weights.keys()
+
+            linears = filter(lambda a: a.endswith(".weight"), keys)
+            linears = sorted(list(map(lambda a: a[:-len(".weight")], linears)))
+            layers = []
+
+            for lin_name in linears:
+                lin_weight = attn_weights['{}.weight'.format(lin_name)]
+                lin_bias = attn_weights['{}.bias'.format(lin_name)]
+                layer = torch.nn.Linear(lin_weight.shape[1], lin_weight.shape[0])
+                layer.load_state_dict({"weight": lin_weight, "bias": lin_bias})
+                layers += [layer]
+            output.append(torch.nn.Sequential(*layers))
+        out[dim] = torch.nn.ModuleList(output)
+
+    class hypernetwork_patch:
+        def __init__(self, hypernet, strength):
+            self.hypernet = hypernet
+            self.strength = strength
+
+        def __call__(self, current_index, q, k, v):
+            dim = k.shape[-1]
+            if dim in self.hypernet:
+                hn = self.hypernet[dim]
+                k = k + hn[0](k) * self.strength
+                v = v + hn[1](v) * self.strength
+
+            return q, k, v
+
+        def to(self, device):
+            for d in self.hypernet.keys():
+                self.hypernet[d] = self.hypernet[d].to(device)
+            return self
+
+    return hypernetwork_patch(out, strength)
+
+class HypernetworkLoader:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "model": ("MODEL",),
+                              "hypernetwork_name": (folder_paths.get_filename_list("hypernetworks"), ),
+                              "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
+                              }}
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "load_hypernetwork"
+
+    CATEGORY = "_for_testing"
+
+    def load_hypernetwork(self, model, hypernetwork_name, strength):
+        hypernetwork_path = folder_paths.get_full_path("hypernetworks", hypernetwork_name)
+        model_hypernetwork = model.clone()
+        patch = load_hypernetwork_patch(hypernetwork_path, strength)
+        if patch is not None:
+            model_hypernetwork.set_model_attn1_patch(patch)
+            model_hypernetwork.set_model_attn2_patch(patch)
+        return (model_hypernetwork,)
+
+NODE_CLASS_MAPPINGS = {
+    "HypernetworkLoader": HypernetworkLoader
+}
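
For reference, load_hypernetwork_patch expects an A1111-style hypernetwork: a dict keyed by attention dimension, each entry holding a pair of MLP state dicts (one applied to keys, one to values) that are added as residuals, k + strength * MLP_k(k) and v + strength * MLP_v(v). A sketch of roughly the smallest file the loader accepts, assuming zero weights so the resulting patch is a no-op (the 768 dim, layer name and path are illustrative):

import torch

def zero_mlp_sd(dim):
    # a single zero-weight Linear stage: k + MLP(k) * strength == k
    return {"linear.weight": torch.zeros(dim, dim),
            "linear.bias": torch.zeros(dim)}

# The outer keys select which attention sizes the hypernetwork applies to;
# 768 is the SD1.x cross-attention context dim.
sd = {768: (zero_mlp_sd(768), zero_mlp_sd(768))}
torch.save(sd, "models/hypernetworks/zero_example.pt")  # path illustrative

Loading zero_example.pt through the HypernetworkLoader node should leave sampling unchanged, which makes it a handy smoke test for the format.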

folder_paths.py (1 line changed)

@@ -32,6 +32,7 @@ folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_m
 folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes")], [])

+folder_names_and_paths["hypernetworks"] = ([os.path.join(models_dir, "hypernetworks")], supported_pt_extensions)

 output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output")
 temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")

models/hypernetworks/put_hypernetworks_here (0 lines changed; new empty placeholder file)

nodes.py (1 line changed)

@@ -1226,6 +1226,7 @@ def load_custom_nodes():
 def init_custom_nodes():
     load_custom_nodes()
+    load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_hypernetwork.py"))
     load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py"))
     load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py"))
     load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py"))

web/scripts/api.js (4 lines changed)

@@ -35,7 +35,7 @@ class ComfyApi extends EventTarget {
 		}

 		let opened = false;
-		let existingSession = sessionStorage["Comfy.SessionId"] || "";
+		let existingSession = window.name;
 		if (existingSession) {
 			existingSession = "?clientId=" + existingSession;
 		}
@@ -75,7 +75,7 @@ class ComfyApi extends EventTarget {
 				case "status":
 					if (msg.data.sid) {
 						this.clientId = msg.data.sid;
-						sessionStorage["Comfy.SessionId"] = this.clientId;
+						window.name = this.clientId;
 					}
 					this.dispatchEvent(new CustomEvent("status", { detail: msg.data.status }));
 					break;
