
Merge remote-tracking branch 'origin/master' into custom_routes

Branch: pull/354/head
Author: pythongosssss, 2 years ago
Commit: 38833ceb62
9 changed files:
 1. comfy/ldm/models/diffusion/ddim.py (14 lines changed)
 2. comfy/ldm/models/diffusion/ddpm.py (18 lines changed)
 3. comfy/ldm/modules/attention.py (25 lines changed)
 4. comfy/ldm/modules/diffusionmodules/openaimodel.py (13 lines changed)
 5. comfy/ldm/modules/tomesd.py (117 lines changed)
 6. comfy/samplers.py (24 lines changed)
 7. comfy/sd.py (9 lines changed)
 8. nodes.py (19 lines changed)
 9. web/extensions/core/widgetInputs.js (2 lines changed)
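Taken together, this merge pulls the new ToMe (Token Merging, "tomesd") support into the custom_routes branch: a new comfy/ldm/modules/tomesd.py implements the merging itself, a transformer_options dict is threaded from ModelPatcher through the samplers into every attention block, the DDIM sampler's fixed cond_concat parameter is generalized into an extra_args dict, and a TomePatchModel node plus a small widget-conversion fix round it out.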

comfy/ldm/models/diffusion/ddim.py (14 lines changed)

@@ -78,7 +78,7 @@ class DDIMSampler(object):
                dynamic_threshold=None,
                ucg_schedule=None,
                denoise_function=None,
-               cond_concat=None,
+               extra_args=None,
                to_zero=True,
                end_step=None,
                **kwargs
@@ -101,7 +101,7 @@ class DDIMSampler(object):
                                  dynamic_threshold=dynamic_threshold,
                                  ucg_schedule=ucg_schedule,
                                  denoise_function=denoise_function,
-                                 cond_concat=cond_concat,
+                                 extra_args=extra_args,
                                  to_zero=to_zero,
                                  end_step=end_step
                                  )
@@ -174,7 +174,7 @@ class DDIMSampler(object):
                                  dynamic_threshold=dynamic_threshold,
                                  ucg_schedule=ucg_schedule,
                                  denoise_function=None,
-                                 cond_concat=None
+                                 extra_args=None
                                  )
         return samples, intermediates
@@ -185,7 +185,7 @@ class DDIMSampler(object):
                       mask=None, x0=None, img_callback=None, log_every_t=100,
                       temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                       unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
-                      ucg_schedule=None, denoise_function=None, cond_concat=None, to_zero=True, end_step=None):
+                      ucg_schedule=None, denoise_function=None, extra_args=None, to_zero=True, end_step=None):
         device = self.model.betas.device
         b = shape[0]
         if x_T is None:
@@ -225,7 +225,7 @@ class DDIMSampler(object):
                                       corrector_kwargs=corrector_kwargs,
                                       unconditional_guidance_scale=unconditional_guidance_scale,
                                       unconditional_conditioning=unconditional_conditioning,
-                                      dynamic_threshold=dynamic_threshold, denoise_function=denoise_function, cond_concat=cond_concat)
+                                      dynamic_threshold=dynamic_threshold, denoise_function=denoise_function, extra_args=extra_args)
             img, pred_x0 = outs
             if callback: callback(i)
             if img_callback: img_callback(pred_x0, i)
@@ -249,11 +249,11 @@ class DDIMSampler(object):
     def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                       temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                       unconditional_guidance_scale=1., unconditional_conditioning=None,
-                      dynamic_threshold=None, denoise_function=None, cond_concat=None):
+                      dynamic_threshold=None, denoise_function=None, extra_args=None):
         b, *_, device = *x.shape, x.device

         if denoise_function is not None:
-            model_output = denoise_function(self.model.apply_model, x, t, unconditional_conditioning, c, unconditional_guidance_scale, cond_concat)
+            model_output = denoise_function(self.model.apply_model, x, t, **extra_args)
         elif unconditional_conditioning is None or unconditional_guidance_scale == 1.:
             model_output = self.model.apply_model(x, t, c)
         else:
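In short, the DDIM sampler stops hard-coding the external denoiser's argument list and instead forwards an opaque dict it unpacks as keyword arguments. A minimal sketch of the new calling convention, with toy stand-ins (toy_denoise and toy_apply_model are illustrative names, not part of this commit):

# Minimal sketch of the extra_args convention (toy names, not from this commit).
def toy_denoise(apply_model, x, t, cond=None, uncond=None, cond_scale=1.0, model_options={}):
    # a real denoiser would apply CFG: uncond_out + (cond_out - uncond_out) * cond_scale
    return apply_model(x, t, cond)

def toy_apply_model(x, t, c):
    return x  # identity stand-in for the UNet

extra_args = {"cond": 1.0, "uncond": 0.0, "cond_scale": 8.0, "model_options": {}}
out = toy_denoise(toy_apply_model, x=0.5, t=0, **extra_args)  # mirrors the p_sample_ddim call

Because the sampler never inspects extra_args, new options (like model_options below) ride along without further DDIM changes.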

comfy/ldm/models/diffusion/ddpm.py (18 lines changed)

@@ -1317,12 +1317,12 @@ class DiffusionWrapper(torch.nn.Module):
         self.conditioning_key = conditioning_key
         assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']

-    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None, control=None):
+    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None, control=None, transformer_options={}):
         if self.conditioning_key is None:
-            out = self.diffusion_model(x, t, control=control)
+            out = self.diffusion_model(x, t, control=control, transformer_options=transformer_options)
         elif self.conditioning_key == 'concat':
             xc = torch.cat([x] + c_concat, dim=1)
-            out = self.diffusion_model(xc, t, control=control)
+            out = self.diffusion_model(xc, t, control=control, transformer_options=transformer_options)
         elif self.conditioning_key == 'crossattn':
             if not self.sequential_cross_attn:
                 cc = torch.cat(c_crossattn, 1)
@@ -1332,25 +1332,25 @@ class DiffusionWrapper(torch.nn.Module):
                 # TorchScript changes names of the arguments
                 # with argument cc defined as context=cc scripted model will produce
                 # an error: RuntimeError: forward() is missing value for argument 'argument_3'.
-                out = self.scripted_diffusion_model(x, t, cc, control=control)
+                out = self.scripted_diffusion_model(x, t, cc, control=control, transformer_options=transformer_options)
             else:
-                out = self.diffusion_model(x, t, context=cc, control=control)
+                out = self.diffusion_model(x, t, context=cc, control=control, transformer_options=transformer_options)
         elif self.conditioning_key == 'hybrid':
             xc = torch.cat([x] + c_concat, dim=1)
             cc = torch.cat(c_crossattn, 1)
-            out = self.diffusion_model(xc, t, context=cc, control=control)
+            out = self.diffusion_model(xc, t, context=cc, control=control, transformer_options=transformer_options)
         elif self.conditioning_key == 'hybrid-adm':
             assert c_adm is not None
             xc = torch.cat([x] + c_concat, dim=1)
             cc = torch.cat(c_crossattn, 1)
-            out = self.diffusion_model(xc, t, context=cc, y=c_adm, control=control)
+            out = self.diffusion_model(xc, t, context=cc, y=c_adm, control=control, transformer_options=transformer_options)
         elif self.conditioning_key == 'crossattn-adm':
             assert c_adm is not None
             cc = torch.cat(c_crossattn, 1)
-            out = self.diffusion_model(x, t, context=cc, y=c_adm, control=control)
+            out = self.diffusion_model(x, t, context=cc, y=c_adm, control=control, transformer_options=transformer_options)
         elif self.conditioning_key == 'adm':
             cc = c_crossattn[0]
-            out = self.diffusion_model(x, t, y=cc, control=control)
+            out = self.diffusion_model(x, t, y=cc, control=control, transformer_options=transformer_options)
         else:
             raise NotImplementedError()
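The change here is purely mechanical: DiffusionWrapper.forward grows a transformer_options parameter and forwards it to self.diffusion_model in every conditioning branch ('concat', 'crossattn', 'hybrid', 'adm', and their combinations), so the options reach the UNet regardless of how the model is conditioned.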

comfy/ldm/modules/attention.py (25 lines changed)

@@ -11,6 +11,7 @@ from .sub_quadratic_attention import efficient_dot_product_attention

 import model_management
+from . import tomesd

 if model_management.xformers_enabled():
     import xformers
@@ -504,12 +505,22 @@ class BasicTransformerBlock(nn.Module):
         self.norm3 = nn.LayerNorm(dim)
         self.checkpoint = checkpoint

-    def forward(self, x, context=None):
-        return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
+    def forward(self, x, context=None, transformer_options={}):
+        return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)

-    def _forward(self, x, context=None):
-        x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
-        x = self.attn2(self.norm2(x), context=context) + x
+    def _forward(self, x, context=None, transformer_options={}):
+        n = self.norm1(x)
+        if "tomesd" in transformer_options:
+            m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"])
+            n = u(self.attn1(m(n), context=context if self.disable_self_attn else None))
+        else:
+            n = self.attn1(n, context=context if self.disable_self_attn else None)
+
+        x += n
+        n = self.norm2(x)
+        n = self.attn2(n, context=context)
+        x += n
         x = self.ff(self.norm3(x)) + x
         return x
@@ -557,7 +568,7 @@ class SpatialTransformer(nn.Module):
         self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
         self.use_linear = use_linear

-    def forward(self, x, context=None):
+    def forward(self, x, context=None, transformer_options={}):
         # note: if no context is given, cross-attention defaults to self-attention
         if not isinstance(context, list):
             context = [context]
@@ -570,7 +581,7 @@ class SpatialTransformer(nn.Module):
         if self.use_linear:
             x = self.proj_in(x)
         for i, block in enumerate(self.transformer_blocks):
-            x = block(x, context=context[i])
+            x = block(x, context=context[i], transformer_options=transformer_options)
         if self.use_linear:
             x = self.proj_out(x)
         x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
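The heart of the feature is in _forward: when transformer_options carries a "tomesd" entry, the normalized hidden states are merged (m) before self-attention and unmerged (u) right after it, so only attn1 runs on the reduced token count; cross-attention (attn2) and the feed-forward still see every token, and the residual stream keeps its full length. The rewrite from x = self.attn1(...) + x to x += n is the same residual addition, just split across statements to make room for the merge/unmerge wrapping.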

comfy/ldm/modules/diffusionmodules/openaimodel.py (13 lines changed)

@@ -76,12 +76,12 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
     support it as an extra input.
     """

-    def forward(self, x, emb, context=None):
+    def forward(self, x, emb, context=None, transformer_options={}):
         for layer in self:
             if isinstance(layer, TimestepBlock):
                 x = layer(x, emb)
             elif isinstance(layer, SpatialTransformer):
-                x = layer(x, context)
+                x = layer(x, context, transformer_options)
             else:
                 x = layer(x)
         return x
@@ -753,7 +753,7 @@ class UNetModel(nn.Module):
         self.middle_block.apply(convert_module_to_f32)
         self.output_blocks.apply(convert_module_to_f32)

-    def forward(self, x, timesteps=None, context=None, y=None, control=None, **kwargs):
+    def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
         """
         Apply the model to an input batch.
         :param x: an [N x C x ...] Tensor of inputs.
@@ -762,6 +762,7 @@ class UNetModel(nn.Module):
         :param y: an [N] Tensor of labels, if class-conditional.
         :return: an [N x C x ...] Tensor of outputs.
         """
+        transformer_options["original_shape"] = list(x.shape)
         assert (y is not None) == (
             self.num_classes is not None
         ), "must specify y if and only if the model is class-conditional"
@@ -775,13 +776,13 @@ class UNetModel(nn.Module):
         h = x.type(self.dtype)
         for id, module in enumerate(self.input_blocks):
-            h = module(h, emb, context)
+            h = module(h, emb, context, transformer_options)
             if control is not None and 'input' in control and len(control['input']) > 0:
                 ctrl = control['input'].pop()
                 if ctrl is not None:
                     h += ctrl
             hs.append(h)
-        h = self.middle_block(h, emb, context)
+        h = self.middle_block(h, emb, context, transformer_options)
         if control is not None and 'middle' in control and len(control['middle']) > 0:
             h += control['middle'].pop()
@@ -793,7 +794,7 @@ class UNetModel(nn.Module):
             hsp += ctrl
             h = th.cat([h, hsp], dim=1)
             del hsp
-            h = module(h, emb, context)
+            h = module(h, emb, context, transformer_options)
         h = h.type(x.dtype)
         if self.predict_codebook_ids:
             return self.id_predictor(h)
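Two details matter here. UNetModel.forward records list(x.shape) as transformer_options["original_shape"] before any downsampling, which is what lets tomesd.get_functions later recover the token grid's width and height at each resolution. And TimestepEmbedSequential passes the options only to SpatialTransformer layers, since ResBlocks and up/downsamplers have no use for them.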

comfy/ldm/modules/tomesd.py (117 lines added, new file)

@@ -0,0 +1,117 @@
import torch
from typing import Tuple, Callable
import math


def do_nothing(x: torch.Tensor, mode: str = None):
    return x


def bipartite_soft_matching_random2d(metric: torch.Tensor,
                                     w: int, h: int, sx: int, sy: int, r: int,
                                     no_rand: bool = False) -> Tuple[Callable, Callable]:
    """
    Partitions the tokens into src and dst and merges r tokens from src to dst.
    Dst tokens are partitioned by choosing one randomly in each (sx, sy) region.

    Args:
     - metric [B, N, C]: metric to use for similarity
     - w: image width in tokens
     - h: image height in tokens
     - sx: stride in the x dimension for dst, must divide w
     - sy: stride in the y dimension for dst, must divide h
     - r: number of tokens to remove (by merging)
     - no_rand: if true, disable randomness (use top left corner only)
    """
    B, N, _ = metric.shape

    if r <= 0:
        return do_nothing, do_nothing

    with torch.no_grad():
        hsy, wsx = h // sy, w // sx

        # For each sy by sx kernel, randomly assign one token to be dst and the rest src
        idx_buffer = torch.zeros(1, hsy, wsx, sy*sx, 1, device=metric.device)

        if no_rand:
            rand_idx = torch.zeros(1, hsy, wsx, 1, 1, device=metric.device, dtype=torch.int64)
        else:
            rand_idx = torch.randint(sy*sx, size=(1, hsy, wsx, 1, 1), device=metric.device)

        idx_buffer.scatter_(dim=3, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=idx_buffer.dtype))
        idx_buffer = idx_buffer.view(1, hsy, wsx, sy, sx, 1).transpose(2, 3).reshape(1, N, 1)
        rand_idx = idx_buffer.argsort(dim=1)

        num_dst = int((1 / (sx*sy)) * N)
        a_idx = rand_idx[:, num_dst:, :]  # src
        b_idx = rand_idx[:, :num_dst, :]  # dst

        def split(x):
            C = x.shape[-1]
            src = x.gather(dim=1, index=a_idx.expand(B, N - num_dst, C))
            dst = x.gather(dim=1, index=b_idx.expand(B, num_dst, C))
            return src, dst

        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = split(metric)
        scores = a @ b.transpose(-1, -2)

        # Can't reduce more than the # tokens in src
        r = min(a.shape[1], r)

        node_max, node_idx = scores.max(dim=-1)
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

        unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
        src_idx = edge_idx[..., :r, :]  # Merged Tokens
        dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx)

    def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
        src, dst = split(x)
        n, t1, c = src.shape
        unm = src.gather(dim=-2, index=unm_idx.expand(n, t1 - r, c))
        src = src.gather(dim=-2, index=src_idx.expand(n, r, c))
        dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)

        return torch.cat([unm, dst], dim=1)

    def unmerge(x: torch.Tensor) -> torch.Tensor:
        unm_len = unm_idx.shape[1]
        unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
        _, _, c = unm.shape

        src = dst.gather(dim=-2, index=dst_idx.expand(B, r, c))

        # Combine back to the original shape
        out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype)
        out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst)
        out.scatter_(dim=-2, index=a_idx.expand(B, a_idx.shape[1], 1).gather(dim=1, index=unm_idx).expand(B, unm_len, c), src=unm)
        out.scatter_(dim=-2, index=a_idx.expand(B, a_idx.shape[1], 1).gather(dim=1, index=src_idx).expand(B, r, c), src=src)

        return out

    return merge, unmerge


def get_functions(x, ratio, original_shape):
    b, c, original_h, original_w = original_shape
    original_tokens = original_h * original_w
    downsample = int(math.sqrt(original_tokens // x.shape[1]))
    stride_x = 2
    stride_y = 2
    max_downsample = 1

    if downsample <= max_downsample:
        w = original_w // downsample
        h = original_h // downsample
        r = int(x.shape[1] * ratio)
        no_rand = False
        m, u = bipartite_soft_matching_random2d(x, w, h, stride_x, stride_y, r, no_rand)
        return m, u

    nothing = lambda y: y
    return nothing, nothing
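As a quick sanity check, a hedged usage sketch (assuming comfy is on sys.path so the module imports as comfy.ldm.modules.tomesd, and shapes follow the docstring above): for a 64x64 token grid with ratio 0.3, merge() shrinks the token axis and unmerge() restores it.

import torch
from comfy.ldm.modules import tomesd  # path introduced by this commit

x = torch.randn(2, 64 * 64, 320)   # [B, N, C] tokens at the UNet's full latent resolution
m, u = tomesd.get_functions(x, 0.3, original_shape=[2, 4, 64, 64])
merged = m(x)                      # token axis: 4096 -> 2868 (r = 1228 tokens merged away)
restored = u(merged)               # back to [2, 4096, 320]; merged tokens now share values

Note that get_functions only activates at the UNet's highest resolution (max_downsample = 1); at deeper levels it returns identity functions, so only the most expensive attention layers are reduced.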

comfy/samplers.py (24 lines changed)

@@ -26,7 +26,7 @@ class CFGDenoiser(torch.nn.Module):

 #The main sampling function shared by all the samplers
 #Returns predicted noise
-def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None):
+def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None, model_options={}):
     def get_area_and_mult(cond, x_in, cond_concat_in, timestep_in):
         area = (x_in.shape[2], x_in.shape[3], 0, 0)
         strength = 1.0
@@ -104,7 +104,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
             out['c_concat'] = [torch.cat(c_concat)]
         return out

-    def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in):
+    def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in, model_options):
         out_cond = torch.zeros_like(x_in)
         out_count = torch.ones_like(x_in)/100000.0
@@ -169,6 +169,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
             if control is not None:
                 c['control'] = control.get_control(input_x, timestep_, c['c_crossattn'], len(cond_or_uncond))

+            if 'transformer_options' in model_options:
+                c['transformer_options'] = model_options['transformer_options']
+
             output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks)
             del input_x
@@ -192,7 +195,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
     max_total_area = model_management.maximum_batch_area()
-    cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat)
+    cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat, model_options)
     return uncond + (cond - uncond) * cond_scale
@@ -209,8 +212,8 @@ class CFGNoisePredictor(torch.nn.Module):
         super().__init__()
         self.inner_model = model
         self.alphas_cumprod = model.alphas_cumprod
-    def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None):
-        out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat)
+    def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None, model_options={}):
+        out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat, model_options=model_options)
         return out
@@ -218,11 +221,11 @@ class KSamplerX0Inpaint(torch.nn.Module):
     def __init__(self, model):
         super().__init__()
         self.inner_model = model
-    def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None):
+    def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None, model_options={}):
         if denoise_mask is not None:
             latent_mask = 1. - denoise_mask
             x = x * denoise_mask + (self.latent_image + self.noise * sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1))) * latent_mask
-        out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat)
+        out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat, model_options=model_options)
         if denoise_mask is not None:
             out *= denoise_mask
@@ -330,7 +333,7 @@ class KSampler:
                 "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde",
                 "dpmpp_2m", "ddim", "uni_pc", "uni_pc_bh2"]

-    def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None):
+    def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}):
        self.model = model
        self.model_denoise = CFGNoisePredictor(self.model)
        if self.model.parameterization == "v":
@@ -350,6 +353,7 @@ class KSampler:
        self.sigma_max=float(self.model_wrap.sigma_max)
        self.set_steps(steps, denoise)
        self.denoise = denoise
+       self.model_options = model_options

    def _calculate_sigmas(self, steps):
        sigmas = None
@@ -418,7 +422,7 @@ class KSampler:
        else:
            precision_scope = contextlib.nullcontext

-       extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg}
+       extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options}

        cond_concat = None
        if hasattr(self.model, 'concat_keys'):
@@ -467,7 +471,7 @@ class KSampler:
                x_T=z_enc,
                x0=latent_image,
                denoise_function=sampling_function,
-               cond_concat=cond_concat,
+               extra_args=extra_args,
                mask=noise_mask,
                to_zero=sigmas[-1]==0,
                end_step=sigmas.shape[0] - 1)
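End to end, the plumbing reads: KSampler stores model_options and puts it into extra_args; sampling_function copies model_options['transformer_options'] into the per-batch cond dict c; and since LatentDiffusion.apply_model unpacks that dict as keyword arguments, it lands in DiffusionWrapper.forward's new transformer_options parameter and flows on to the UNet and every attention block. The same dict also serves the DDIM path, which is why cond_concat could be replaced by extra_args above.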

comfy/sd.py (9 lines changed)

@@ -1,5 +1,6 @@
 import torch
 import contextlib
+import copy

 import sd1_clip
 import sd2_clip
@@ -274,12 +275,20 @@ class ModelPatcher:
         self.model = model
         self.patches = []
         self.backup = {}
+        self.model_options = {"transformer_options":{}}

     def clone(self):
         n = ModelPatcher(self.model)
         n.patches = self.patches[:]
+        n.model_options = copy.deepcopy(self.model_options)
         return n

+    def set_model_tomesd(self, ratio):
+        self.model_options["transformer_options"]["tomesd"] = {"ratio": ratio}
+
+    def model_dtype(self):
+        return self.model.diffusion_model.dtype
+
     def add_patches(self, patches, strength=1.0):
         p = {}
         model_sd = self.model.state_dict()
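The copy.deepcopy in clone() is what makes per-node patching safe: two clones can carry different transformer_options without sharing the nested dict. A hedged sketch of the intended usage (patcher standing in for any ModelPatcher, e.g. a checkpoint loader's MODEL output):

a = patcher.clone()
b = patcher.clone()
a.set_model_tomesd(0.5)   # only `a` gets {"tomesd": {"ratio": 0.5}}
assert "tomesd" not in b.model_options["transformer_options"]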

nodes.py (19 lines changed)

@@ -254,6 +254,22 @@ class LoraLoader:
         model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip)
         return (model_lora, clip_lora)

+class TomePatchModel:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "model": ("MODEL",),
+                              "ratio": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.01}),
+                              }}
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "patch"
+
+    CATEGORY = "_for_testing"
+
+    def patch(self, model, ratio):
+        m = model.clone()
+        m.set_model_tomesd(ratio)
+        return (m, )
+
 class VAELoader:
     @classmethod
     def INPUT_TYPES(s):
@@ -646,7 +662,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
     model_management.load_controlnet_gpu(control_net_models)

     if sampler_name in comfy.samplers.KSampler.SAMPLERS:
-        sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise)
+        sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
     else:
         #other samplers
         pass
@@ -1016,6 +1032,7 @@ NODE_CLASS_MAPPINGS = {
     "CLIPVisionLoader": CLIPVisionLoader,
     "VAEDecodeTiled": VAEDecodeTiled,
     "VAEEncodeTiled": VAEEncodeTiled,
+    "TomePatchModel": TomePatchModel,
 }

 def load_custom_node(module_path):
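For users, the new node slots between a model loader and a KSampler in the graph: MODEL -> TomePatchModel(ratio) -> KSampler. Because patch() works on a clone, calling it directly behaves the same way (hedged sketch; model here is any ModelPatcher):

node = TomePatchModel()
(patched,) = node.patch(model, ratio=0.3)   # a clone with tomesd enabled
# `patched` then flows into common_ksampler, which hands
# patched.model_options to comfy.samplers.KSampler.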

web/extensions/core/widgetInputs.js (2 lines changed)

@@ -101,7 +101,7 @@ app.registerExtension({
                 callback: () => convertToWidget(this, w),
             });
         } else {
-            const config = nodeData?.input?.required[w.name] || [w.type, w.options || {}];
+            const config = nodeData?.input?.required[w.name] || nodeData?.input?.optional?.[w.name] || [w.type, w.options || {}];
             if (isConvertableWidget(w, config)) {
                 toInput.push({
                     content: `Convert ${w.name} to input`,
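This one-line frontend fix makes the "Convert to input" menu also look up a widget's definition under nodeData.input.optional, so widgets backing optional inputs (not just required ones) can be converted; the optional chaining (?.) keeps the lookup safe on nodes that declare no optional inputs at all.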
