Browse Source

Add ControlNet support.

pull/11/head
comfyanonymous 2 years ago
parent
commit
4efa67fa12
  1. 286
      comfy/cldm/cldm.py
  2. 4
      comfy/extra_samplers/uni_pc.py
  3. 18
      comfy/ldm/models/diffusion/ddpm.py
  4. 10
      comfy/ldm/modules/diffusionmodules/openaimodel.py
  5. 17
      comfy/model_management.py
  6. 121
      comfy/samplers.py
  7. 76
      comfy/sd.py
  8. 18
      comfy/utils.py
  9. 93
      nodes.py

286
comfy/cldm/cldm.py

@ -0,0 +1,286 @@
#taken from: https://github.com/lllyasviel/ControlNet
#and modified
import einops
import torch
import torch as th
import torch.nn as nn
from ldm.modules.diffusionmodules.util import (
conv_nd,
linear,
zero_module,
timestep_embedding,
)
from einops import rearrange, repeat
from torchvision.utils import make_grid
from ldm.modules.attention import SpatialTransformer
from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
from ldm.models.diffusion.ddpm import LatentDiffusion
from ldm.util import log_txt_as_img, exists, instantiate_from_config
class ControlledUnetModel(UNetModel):
#implemented in the ldm unet
pass
class ControlNet(nn.Module):
def __init__(
self,
image_size,
in_channels,
model_channels,
hint_channels,
num_res_blocks,
attention_resolutions,
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
use_checkpoint=False,
use_fp16=False,
num_heads=-1,
num_head_channels=-1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
resblock_updown=False,
use_new_attention_order=False,
use_spatial_transformer=False, # custom transformer support
transformer_depth=1, # custom transformer support
context_dim=None, # custom transformer support
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
legacy=True,
disable_self_attentions=None,
num_attention_blocks=None,
disable_middle_self_attn=False,
use_linear_in_transformer=False,
):
super().__init__()
if use_spatial_transformer:
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
if context_dim is not None:
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
from omegaconf.listconfig import ListConfig
if type(context_dim) == ListConfig:
context_dim = list(context_dim)
if num_heads_upsample == -1:
num_heads_upsample = num_heads
if num_heads == -1:
assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
if num_head_channels == -1:
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
self.dims = dims
self.image_size = image_size
self.in_channels = in_channels
self.model_channels = model_channels
if isinstance(num_res_blocks, int):
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
else:
if len(num_res_blocks) != len(channel_mult):
raise ValueError("provide num_res_blocks either as an int (globally constant) or "
"as a list/tuple (per-level) with the same length as channel_mult")
self.num_res_blocks = num_res_blocks
if disable_self_attentions is not None:
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
assert len(disable_self_attentions) == len(channel_mult)
if num_attention_blocks is not None:
assert len(num_attention_blocks) == len(self.num_res_blocks)
assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
f"attention will still not be set.")
self.attention_resolutions = attention_resolutions
self.dropout = dropout
self.channel_mult = channel_mult
self.conv_resample = conv_resample
self.use_checkpoint = use_checkpoint
self.dtype = th.float16 if use_fp16 else th.float32
self.num_heads = num_heads
self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample
self.predict_codebook_ids = n_embed is not None
time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, 3, padding=1)
)
]
)
self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
self.input_hint_block = TimestepEmbedSequential(
conv_nd(dims, hint_channels, 16, 3, padding=1),
nn.SiLU(),
conv_nd(dims, 16, 16, 3, padding=1),
nn.SiLU(),
conv_nd(dims, 16, 32, 3, padding=1, stride=2),
nn.SiLU(),
conv_nd(dims, 32, 32, 3, padding=1),
nn.SiLU(),
conv_nd(dims, 32, 96, 3, padding=1, stride=2),
nn.SiLU(),
conv_nd(dims, 96, 96, 3, padding=1),
nn.SiLU(),
conv_nd(dims, 96, 256, 3, padding=1, stride=2),
nn.SiLU(),
zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
)
self._feature_size = model_channels
input_block_chans = [model_channels]
ch = model_channels
ds = 1
for level, mult in enumerate(channel_mult):
for nr in range(self.num_res_blocks[level]):
layers = [
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=mult * model_channels,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
)
]
ch = mult * model_channels
if ds in attention_resolutions:
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
#num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
if exists(disable_self_attentions):
disabled_sa = disable_self_attentions[level]
else:
disabled_sa = False
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
layers.append(
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
self.zero_convs.append(self.make_zero_conv(ch))
self._feature_size += ch
input_block_chans.append(ch)
if level != len(channel_mult) - 1:
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
down=True,
)
if resblock_updown
else Downsample(
ch, conv_resample, dims=dims, out_channels=out_ch
)
)
)
ch = out_ch
input_block_chans.append(ch)
self.zero_convs.append(self.make_zero_conv(ch))
ds *= 2
self._feature_size += ch
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
#num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
self.middle_block = TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint
),
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
)
self.middle_block_out = self.make_zero_conv(ch)
self._feature_size += ch
def make_zero_conv(self, channels):
return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
def forward(self, x, hint, timesteps, context, **kwargs):
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
emb = self.time_embed(t_emb)
guided_hint = self.input_hint_block(hint, emb, context)
outs = []
h = x.type(self.dtype)
for module, zero_conv in zip(self.input_blocks, self.zero_convs):
if guided_hint is not None:
h = module(h, emb, context)
h += guided_hint
guided_hint = None
else:
h = module(h, emb, context)
outs.append(zero_conv(h, emb, context))
h = self.middle_block(h, emb, context)
outs.append(self.middle_block_out(h, emb, context))
return outs

4
comfy/extra_samplers/uni_pc.py

@ -856,13 +856,13 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, extra_args=None
device = noise.device device = noise.device
if model.inner_model.parameterization == "v": if model.parameterization == "v":
model_type = "v" model_type = "v"
else: else:
model_type = "noise" model_type = "noise"
model_fn = model_wrapper( model_fn = model_wrapper(
model.inner_model.apply_model, model.inner_model.inner_model.apply_model,
sampling_function, sampling_function,
ns, ns,
model_type=model_type, model_type=model_type,

18
comfy/ldm/models/diffusion/ddpm.py

@ -1320,12 +1320,12 @@ class DiffusionWrapper(torch.nn.Module):
self.conditioning_key = conditioning_key self.conditioning_key = conditioning_key
assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm'] assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']
def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None): def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None, control=None):
if self.conditioning_key is None: if self.conditioning_key is None:
out = self.diffusion_model(x, t) out = self.diffusion_model(x, t, control=control)
elif self.conditioning_key == 'concat': elif self.conditioning_key == 'concat':
xc = torch.cat([x] + c_concat, dim=1) xc = torch.cat([x] + c_concat, dim=1)
out = self.diffusion_model(xc, t) out = self.diffusion_model(xc, t, control=control)
elif self.conditioning_key == 'crossattn': elif self.conditioning_key == 'crossattn':
if not self.sequential_cross_attn: if not self.sequential_cross_attn:
cc = torch.cat(c_crossattn, 1) cc = torch.cat(c_crossattn, 1)
@ -1335,25 +1335,25 @@ class DiffusionWrapper(torch.nn.Module):
# TorchScript changes names of the arguments # TorchScript changes names of the arguments
# with argument cc defined as context=cc scripted model will produce # with argument cc defined as context=cc scripted model will produce
# an error: RuntimeError: forward() is missing value for argument 'argument_3'. # an error: RuntimeError: forward() is missing value for argument 'argument_3'.
out = self.scripted_diffusion_model(x, t, cc) out = self.scripted_diffusion_model(x, t, cc, control=control)
else: else:
out = self.diffusion_model(x, t, context=cc) out = self.diffusion_model(x, t, context=cc, control=control)
elif self.conditioning_key == 'hybrid': elif self.conditioning_key == 'hybrid':
xc = torch.cat([x] + c_concat, dim=1) xc = torch.cat([x] + c_concat, dim=1)
cc = torch.cat(c_crossattn, 1) cc = torch.cat(c_crossattn, 1)
out = self.diffusion_model(xc, t, context=cc) out = self.diffusion_model(xc, t, context=cc, control=control)
elif self.conditioning_key == 'hybrid-adm': elif self.conditioning_key == 'hybrid-adm':
assert c_adm is not None assert c_adm is not None
xc = torch.cat([x] + c_concat, dim=1) xc = torch.cat([x] + c_concat, dim=1)
cc = torch.cat(c_crossattn, 1) cc = torch.cat(c_crossattn, 1)
out = self.diffusion_model(xc, t, context=cc, y=c_adm) out = self.diffusion_model(xc, t, context=cc, y=c_adm, control=control)
elif self.conditioning_key == 'crossattn-adm': elif self.conditioning_key == 'crossattn-adm':
assert c_adm is not None assert c_adm is not None
cc = torch.cat(c_crossattn, 1) cc = torch.cat(c_crossattn, 1)
out = self.diffusion_model(x, t, context=cc, y=c_adm) out = self.diffusion_model(x, t, context=cc, y=c_adm, control=control)
elif self.conditioning_key == 'adm': elif self.conditioning_key == 'adm':
cc = c_crossattn[0] cc = c_crossattn[0]
out = self.diffusion_model(x, t, y=cc) out = self.diffusion_model(x, t, y=cc, control=control)
else: else:
raise NotImplementedError() raise NotImplementedError()

10
comfy/ldm/modules/diffusionmodules/openaimodel.py

@ -753,7 +753,7 @@ class UNetModel(nn.Module):
self.middle_block.apply(convert_module_to_f32) self.middle_block.apply(convert_module_to_f32)
self.output_blocks.apply(convert_module_to_f32) self.output_blocks.apply(convert_module_to_f32)
def forward(self, x, timesteps=None, context=None, y=None,**kwargs): def forward(self, x, timesteps=None, context=None, y=None, control=None, **kwargs):
""" """
Apply the model to an input batch. Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs. :param x: an [N x C x ...] Tensor of inputs.
@ -778,8 +778,14 @@ class UNetModel(nn.Module):
h = module(h, emb, context) h = module(h, emb, context)
hs.append(h) hs.append(h)
h = self.middle_block(h, emb, context) h = self.middle_block(h, emb, context)
if control is not None:
h += control.pop()
for module in self.output_blocks: for module in self.output_blocks:
h = th.cat([h, hs.pop()], dim=1) hsp = hs.pop()
if control is not None:
hsp += control.pop()
h = th.cat([h, hsp], dim=1)
h = module(h, emb, context) h = module(h, emb, context)
h = h.type(x.dtype) h = h.type(x.dtype)
if self.predict_codebook_ids: if self.predict_codebook_ids:

17
comfy/model_management.py

@ -48,7 +48,7 @@ print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM"][vram_s
current_loaded_model = None current_loaded_model = None
current_gpu_controlnets = []
model_accelerated = False model_accelerated = False
@ -56,6 +56,7 @@ model_accelerated = False
def unload_model(): def unload_model():
global current_loaded_model global current_loaded_model
global model_accelerated global model_accelerated
global current_gpu_controlnets
if current_loaded_model is not None: if current_loaded_model is not None:
if model_accelerated: if model_accelerated:
accelerate.hooks.remove_hook_from_submodules(current_loaded_model.model) accelerate.hooks.remove_hook_from_submodules(current_loaded_model.model)
@ -64,6 +65,10 @@ def unload_model():
current_loaded_model.model.cpu() current_loaded_model.model.cpu()
current_loaded_model.unpatch_model() current_loaded_model.unpatch_model()
current_loaded_model = None current_loaded_model = None
if len(current_gpu_controlnets) > 0:
for n in current_gpu_controlnets:
n.cpu()
current_gpu_controlnets = []
def load_model_gpu(model): def load_model_gpu(model):
@ -95,6 +100,16 @@ def load_model_gpu(model):
model_accelerated = True model_accelerated = True
return current_loaded_model return current_loaded_model
def load_controlnet_gpu(models):
global current_gpu_controlnets
for m in current_gpu_controlnets:
if m not in models:
m.cpu()
current_gpu_controlnets = []
for m in models:
current_gpu_controlnets.append(m.cuda())
def get_free_memory(): def get_free_memory():
dev = torch.cuda.current_device() dev = torch.cuda.current_device()

121
comfy/samplers.py

@ -21,12 +21,13 @@ class CFGDenoiser(torch.nn.Module):
uncond = self.inner_model(x, sigma, cond=uncond) uncond = self.inner_model(x, sigma, cond=uncond)
return uncond + (cond - uncond) * cond_scale return uncond + (cond - uncond) * cond_scale
def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_concat=None):
def get_area_and_mult(cond, x_in, cond_concat_in): #The main sampling function shared by all the samplers
#Returns predicted noise
def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None):
def get_area_and_mult(cond, x_in, cond_concat_in, timestep_in):
area = (x_in.shape[2], x_in.shape[3], 0, 0) area = (x_in.shape[2], x_in.shape[3], 0, 0)
strength = 1.0 strength = 1.0
min_sigma = 0.0
max_sigma = 999.0
if 'area' in cond[1]: if 'area' in cond[1]:
area = cond[1]['area'] area = cond[1]['area']
if 'strength' in cond[1]: if 'strength' in cond[1]:
@ -56,9 +57,15 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c
cr = x[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] cr = x[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
cropped.append(cr) cropped.append(cr)
conditionning['c_concat'] = torch.cat(cropped, dim=1) conditionning['c_concat'] = torch.cat(cropped, dim=1)
return (input_x, mult, conditionning, area)
control = None
if 'control' in cond[1]:
control = cond[1]['control']
return (input_x, mult, conditionning, area, control)
def cond_equal_size(c1, c2): def cond_equal_size(c1, c2):
if c1 is c2:
return True
if c1.keys() != c2.keys(): if c1.keys() != c2.keys():
return False return False
if 'c_crossattn' in c1: if 'c_crossattn' in c1:
@ -69,6 +76,17 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c
return False return False
return True return True
def can_concat_cond(c1, c2):
if c1[0].shape != c2[0].shape:
return False
if (c1[4] is None) != (c2[4] is None):
return False
if c1[4] is not None:
if c1[4] is not c2[4]:
return False
return cond_equal_size(c1[2], c2[2])
def cond_cat(c_list): def cond_cat(c_list):
c_crossattn = [] c_crossattn = []
c_concat = [] c_concat = []
@ -84,7 +102,7 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c
out['c_concat'] = [torch.cat(c_concat)] out['c_concat'] = [torch.cat(c_concat)]
return out return out
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, sigma, max_total_area, cond_concat_in): def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in):
out_cond = torch.zeros_like(x_in) out_cond = torch.zeros_like(x_in)
out_count = torch.ones_like(x_in)/100000.0 out_count = torch.ones_like(x_in)/100000.0
@ -96,13 +114,13 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c
to_run = [] to_run = []
for x in cond: for x in cond:
p = get_area_and_mult(x, x_in, cond_concat_in) p = get_area_and_mult(x, x_in, cond_concat_in, timestep)
if p is None: if p is None:
continue continue
to_run += [(p, COND)] to_run += [(p, COND)]
for x in uncond: for x in uncond:
p = get_area_and_mult(x, x_in, cond_concat_in) p = get_area_and_mult(x, x_in, cond_concat_in, timestep)
if p is None: if p is None:
continue continue
@ -113,9 +131,8 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c
first_shape = first[0][0].shape first_shape = first[0][0].shape
to_batch_temp = [] to_batch_temp = []
for x in range(len(to_run)): for x in range(len(to_run)):
if to_run[x][0][0].shape == first_shape: if can_concat_cond(to_run[x][0], first[0]):
if cond_equal_size(to_run[x][0][2], first[0][2]): to_batch_temp += [x]
to_batch_temp += [x]
to_batch_temp.reverse() to_batch_temp.reverse()
to_batch = to_batch_temp[:1] to_batch = to_batch_temp[:1]
@ -131,6 +148,7 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c
c = [] c = []
cond_or_uncond = [] cond_or_uncond = []
area = [] area = []
control = None
for x in to_batch: for x in to_batch:
o = to_run.pop(x) o = to_run.pop(x)
p = o[0] p = o[0]
@ -139,13 +157,17 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c
c += [p[2]] c += [p[2]]
area += [p[3]] area += [p[3]]
cond_or_uncond += [o[1]] cond_or_uncond += [o[1]]
control = p[4]
batch_chunks = len(cond_or_uncond) batch_chunks = len(cond_or_uncond)
input_x = torch.cat(input_x) input_x = torch.cat(input_x)
c = cond_cat(c) c = cond_cat(c)
sigma_ = torch.cat([sigma] * batch_chunks) timestep_ = torch.cat([timestep] * batch_chunks)
output = model_function(input_x, sigma_, cond=c).chunk(batch_chunks) if control is not None:
c['control'] = control.get_control(input_x, timestep_, c['c_crossattn'])
output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks)
del input_x del input_x
for o in range(batch_chunks): for o in range(batch_chunks):
@ -166,10 +188,29 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c
max_total_area = model_management.maximum_batch_area() max_total_area = model_management.maximum_batch_area()
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, sigma, max_total_area, cond_concat) cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat)
return uncond + (cond - uncond) * cond_scale return uncond + (cond - uncond) * cond_scale
class CFGDenoiserComplex(torch.nn.Module):
class CompVisVDenoiser(k_diffusion_external.DiscreteVDDPMDenoiser):
def __init__(self, model, quantize=False, device='cpu'):
super().__init__(model, model.alphas_cumprod, quantize=quantize)
def get_v(self, x, t, cond, **kwargs):
return self.inner_model.apply_model(x, t, cond, **kwargs)
class CFGNoisePredictor(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.inner_model = model
self.alphas_cumprod = model.alphas_cumprod
def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None):
out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat)
return out
class KSamplerX0Inpaint(torch.nn.Module):
def __init__(self, model): def __init__(self, model):
super().__init__() super().__init__()
self.inner_model = model self.inner_model = model
@ -177,7 +218,7 @@ class CFGDenoiserComplex(torch.nn.Module):
if denoise_mask is not None: if denoise_mask is not None:
latent_mask = 1. - denoise_mask latent_mask = 1. - denoise_mask
x = x * denoise_mask + (self.latent_image + self.noise * sigma) * latent_mask x = x * denoise_mask + (self.latent_image + self.noise * sigma) * latent_mask
out = sampling_function(self.inner_model, x, sigma, uncond, cond, cond_scale, cond_concat) out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat)
if denoise_mask is not None: if denoise_mask is not None:
out *= denoise_mask out *= denoise_mask
@ -196,8 +237,6 @@ def simple_scheduler(model, steps):
def blank_inpaint_image_like(latent_image): def blank_inpaint_image_like(latent_image):
blank_image = torch.ones_like(latent_image) blank_image = torch.ones_like(latent_image)
# these are the values for "zero" in pixel space translated to latent space # these are the values for "zero" in pixel space translated to latent space
# the proper way to do this is to apply the mask to the image in pixel space and then send it through the VAE
# unfortunately that gives zero flexibility so I did things like this instead which hopefully works
blank_image[:,0] *= 0.8223 blank_image[:,0] *= 0.8223
blank_image[:,1] *= -0.6876 blank_image[:,1] *= -0.6876
blank_image[:,2] *= 0.6364 blank_image[:,2] *= 0.6364
@ -234,6 +273,42 @@ def create_cond_with_same_area_if_none(conds, c):
n = c[1].copy() n = c[1].copy()
conds += [[smallest[0], n]] conds += [[smallest[0], n]]
def apply_control_net_to_equal_area(conds, uncond):
cond_cnets = []
cond_other = []
uncond_cnets = []
uncond_other = []
for t in range(len(conds)):
x = conds[t]
if 'area' not in x[1]:
if 'control' in x[1] and x[1]['control'] is not None:
cond_cnets.append(x[1]['control'])
else:
cond_other.append((x, t))
for t in range(len(uncond)):
x = uncond[t]
if 'area' not in x[1]:
if 'control' in x[1] and x[1]['control'] is not None:
uncond_cnets.append(x[1]['control'])
else:
uncond_other.append((x, t))
if len(uncond_cnets) > 0:
return
for x in range(len(cond_cnets)):
temp = uncond_other[x % len(uncond_other)]
o = temp[0]
if 'control' in o[1] and o[1]['control'] is not None:
n = o[1].copy()
n['control'] = cond_cnets[x]
uncond += [[o[0], n]]
else:
n = o[1].copy()
n['control'] = cond_cnets[x]
uncond[temp[1]] = [o[0], n]
class KSampler: class KSampler:
SCHEDULERS = ["karras", "normal", "simple"] SCHEDULERS = ["karras", "normal", "simple"]
SAMPLERS = ["sample_euler", "sample_euler_ancestral", "sample_heun", "sample_dpm_2", "sample_dpm_2_ancestral", SAMPLERS = ["sample_euler", "sample_euler_ancestral", "sample_heun", "sample_dpm_2", "sample_dpm_2_ancestral",
@ -242,11 +317,13 @@ class KSampler:
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None): def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None):
self.model = model self.model = model
self.model_denoise = CFGNoisePredictor(self.model)
if self.model.parameterization == "v": if self.model.parameterization == "v":
self.model_wrap = k_diffusion_external.CompVisVDenoiser(self.model, quantize=True) self.model_wrap = CompVisVDenoiser(self.model_denoise, quantize=True)
else: else:
self.model_wrap = k_diffusion_external.CompVisDenoiser(self.model, quantize=True) self.model_wrap = k_diffusion_external.CompVisDenoiser(self.model_denoise, quantize=True)
self.model_k = CFGDenoiserComplex(self.model_wrap) self.model_wrap.parameterization = self.model.parameterization
self.model_k = KSamplerX0Inpaint(self.model_wrap)
self.device = device self.device = device
if scheduler not in self.SCHEDULERS: if scheduler not in self.SCHEDULERS:
scheduler = self.SCHEDULERS[0] scheduler = self.SCHEDULERS[0]
@ -316,6 +393,8 @@ class KSampler:
for c in negative: for c in negative:
create_cond_with_same_area_if_none(positive, c) create_cond_with_same_area_if_none(positive, c)
apply_control_net_to_equal_area(positive, negative)
if self.model.model.diffusion_model.dtype == torch.float16: if self.model.model.diffusion_model.dtype == torch.float16:
precision_scope = torch.autocast precision_scope = torch.autocast
else: else:

76
comfy/sd.py

@ -6,6 +6,9 @@ import model_management
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config
from ldm.models.autoencoder import AutoencoderKL from ldm.models.autoencoder import AutoencoderKL
from omegaconf import OmegaConf from omegaconf import OmegaConf
from .cldm import cldm
from . import utils
def load_torch_file(ckpt): def load_torch_file(ckpt):
if ckpt.lower().endswith(".safetensors"): if ckpt.lower().endswith(".safetensors"):
@ -323,6 +326,79 @@ class VAE:
samples = samples.cpu() samples = samples.cpu()
return samples return samples
class ControlNet:
def __init__(self, control_model):
self.control_model = control_model
self.cond_hint_original = None
self.cond_hint = None
def get_control(self, x_noisy, t, cond_txt):
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
if self.cond_hint is not None:
del self.cond_hint
self.cond_hint = None
self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(x_noisy.device)
print("set cond_hint", self.cond_hint.shape)
control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=cond_txt)
return control
def set_cond_hint(self, cond_hint):
self.cond_hint_original = cond_hint
return self
def cleanup(self):
if self.cond_hint is not None:
del self.cond_hint
self.cond_hint = None
def copy(self):
c = ControlNet(self.control_model)
c.cond_hint_original = self.cond_hint_original
return c
def load_controlnet(ckpt_path):
controlnet_data = load_torch_file(ckpt_path)
pth_key = 'control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'
pth = False
sd2 = False
key = 'input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'
if pth_key in controlnet_data:
pth = True
key = pth_key
elif key in controlnet_data:
pass
else:
print("error checkpoint does not contain controlnet data", ckpt_path)
return None
context_dim = controlnet_data[key].shape[1]
control_model = cldm.ControlNet(image_size=32,
in_channels=4,
hint_channels=3,
model_channels=320,
attention_resolutions=[ 4, 2, 1 ],
num_res_blocks=2,
channel_mult=[ 1, 2, 4, 4 ],
num_heads=8,
use_spatial_transformer=True,
transformer_depth=1,
context_dim=context_dim,
use_checkpoint=True,
legacy=False)
if pth:
class WeightsLoader(torch.nn.Module):
pass
w = WeightsLoader()
w.control_model = control_model
w.load_state_dict(controlnet_data, strict=False)
else:
control_model.load_state_dict(controlnet_data, strict=False)
control = ControlNet(control_model)
return control
def load_clip(ckpt_path, embedding_directory=None): def load_clip(ckpt_path, embedding_directory=None):
clip_data = load_torch_file(ckpt_path) clip_data = load_torch_file(ckpt_path)
config = {} config = {}

18
comfy/utils.py

@ -0,0 +1,18 @@
import torch
def common_upscale(samples, width, height, upscale_method, crop):
if crop == "center":
old_width = samples.shape[3]
old_height = samples.shape[2]
old_aspect = old_width / old_height
new_aspect = width / height
x = 0
y = 0
if old_aspect > new_aspect:
x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
elif old_aspect < new_aspect:
y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
s = samples[:,:,y:old_height-y,x:old_width-x]
else:
s = samples
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)

93
nodes.py

@ -15,10 +15,12 @@ sys.path.insert(0, os.path.join(sys.path[0], "comfy"))
import comfy.samplers import comfy.samplers
import comfy.sd import comfy.sd
import comfy.utils
import model_management import model_management
supported_ckpt_extensions = ['.ckpt'] supported_ckpt_extensions = ['.ckpt', '.pth']
supported_pt_extensions = ['.ckpt', '.pt', '.bin'] supported_pt_extensions = ['.ckpt', '.pt', '.bin', '.pth']
try: try:
import safetensors.torch import safetensors.torch
supported_ckpt_extensions += ['.safetensors'] supported_ckpt_extensions += ['.safetensors']
@ -77,12 +79,14 @@ class ConditioningSetArea:
CATEGORY = "conditioning" CATEGORY = "conditioning"
def append(self, conditioning, width, height, x, y, strength, min_sigma=0.0, max_sigma=99.0): def append(self, conditioning, width, height, x, y, strength, min_sigma=0.0, max_sigma=99.0):
c = copy.deepcopy(conditioning) c = []
for t in c: for t in conditioning:
t[1]['area'] = (height // 8, width // 8, y // 8, x // 8) n = [t[0], t[1].copy()]
t[1]['strength'] = strength n[1]['area'] = (height // 8, width // 8, y // 8, x // 8)
t[1]['min_sigma'] = min_sigma n[1]['strength'] = strength
t[1]['max_sigma'] = max_sigma n[1]['min_sigma'] = min_sigma
n[1]['max_sigma'] = max_sigma
c.append(n)
return (c, ) return (c, )
class VAEDecode: class VAEDecode:
@ -134,7 +138,6 @@ class VAEEncodeForInpaint:
CATEGORY = "latent/inpaint" CATEGORY = "latent/inpaint"
def encode(self, vae, pixels, mask): def encode(self, vae, pixels, mask):
print(pixels.shape, mask.shape)
x = (pixels.shape[1] // 64) * 64 x = (pixels.shape[1] // 64) * 64
y = (pixels.shape[2] // 64) * 64 y = (pixels.shape[2] // 64) * 64
if pixels.shape[1] != x or pixels.shape[2] != y: if pixels.shape[1] != x or pixels.shape[2] != y:
@ -144,7 +147,6 @@ class VAEEncodeForInpaint:
#shave off a few pixels to keep things seamless #shave off a few pixels to keep things seamless
kernel_tensor = torch.ones((1, 1, 6, 6)) kernel_tensor = torch.ones((1, 1, 6, 6))
mask_erosion = torch.clamp(torch.nn.functional.conv2d((1.0 - mask.round())[None], kernel_tensor, padding=3), 0, 1) mask_erosion = torch.clamp(torch.nn.functional.conv2d((1.0 - mask.round())[None], kernel_tensor, padding=3), 0, 1)
print(mask_erosion.shape, pixels.shape)
for i in range(3): for i in range(3):
pixels[:,:,:,i] -= 0.5 pixels[:,:,:,i] -= 0.5
pixels[:,:,:,i] *= mask_erosion[0][:x,:y].round() pixels[:,:,:,i] *= mask_erosion[0][:x,:y].round()
@ -211,6 +213,44 @@ class VAELoader:
vae = comfy.sd.VAE(ckpt_path=vae_path) vae = comfy.sd.VAE(ckpt_path=vae_path)
return (vae,) return (vae,)
class ControlNetLoader:
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
controlnet_dir = os.path.join(models_dir, "controlnet")
@classmethod
def INPUT_TYPES(s):
return {"required": { "control_net_name": (filter_files_extensions(recursive_search(s.controlnet_dir), supported_pt_extensions), )}}
RETURN_TYPES = ("CONTROL_NET",)
FUNCTION = "load_controlnet"
CATEGORY = "loaders"
def load_controlnet(self, control_net_name):
controlnet_path = os.path.join(self.controlnet_dir, control_net_name)
controlnet = comfy.sd.load_controlnet(controlnet_path)
return (controlnet,)
class ControlNetApply:
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning": ("CONDITIONING", ), "control_net": ("CONTROL_NET", ), "image": ("IMAGE", )}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "apply_controlnet"
CATEGORY = "conditioning"
def apply_controlnet(self, conditioning, control_net, image):
c = []
control_hint = image.movedim(-1,1)
print(control_hint.shape)
for t in conditioning:
n = [t[0], t[1].copy()]
n[1]['control'] = control_net.copy().set_cond_hint(control_hint)
c.append(n)
return (c, )
class CLIPLoader: class CLIPLoader:
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
clip_dir = os.path.join(models_dir, "clip") clip_dir = os.path.join(models_dir, "clip")
@ -248,22 +288,7 @@ class EmptyLatentImage:
latent = torch.zeros([batch_size, 4, height // 8, width // 8]) latent = torch.zeros([batch_size, 4, height // 8, width // 8])
return ({"samples":latent}, ) return ({"samples":latent}, )
def common_upscale(samples, width, height, upscale_method, crop):
if crop == "center":
old_width = samples.shape[3]
old_height = samples.shape[2]
old_aspect = old_width / old_height
new_aspect = width / height
x = 0
y = 0
if old_aspect > new_aspect:
x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
elif old_aspect < new_aspect:
y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
s = samples[:,:,y:old_height-y,x:old_width-x]
else:
s = samples
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
class LatentUpscale: class LatentUpscale:
upscale_methods = ["nearest-exact", "bilinear", "area"] upscale_methods = ["nearest-exact", "bilinear", "area"]
@ -282,7 +307,7 @@ class LatentUpscale:
def upscale(self, samples, upscale_method, width, height, crop): def upscale(self, samples, upscale_method, width, height, crop):
s = samples.copy() s = samples.copy()
s["samples"] = common_upscale(samples["samples"], width // 8, height // 8, upscale_method, crop) s["samples"] = comfy.utils.common_upscale(samples["samples"], width // 8, height // 8, upscale_method, crop)
return (s,) return (s,)
class LatentRotate: class LatentRotate:
@ -461,19 +486,26 @@ def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, po
positive_copy = [] positive_copy = []
negative_copy = [] negative_copy = []
control_nets = []
for p in positive: for p in positive:
t = p[0] t = p[0]
if t.shape[0] < noise.shape[0]: if t.shape[0] < noise.shape[0]:
t = torch.cat([t] * noise.shape[0]) t = torch.cat([t] * noise.shape[0])
t = t.to(device) t = t.to(device)
if 'control' in p[1]:
control_nets += [p[1]['control']]
positive_copy += [[t] + p[1:]] positive_copy += [[t] + p[1:]]
for n in negative: for n in negative:
t = n[0] t = n[0]
if t.shape[0] < noise.shape[0]: if t.shape[0] < noise.shape[0]:
t = torch.cat([t] * noise.shape[0]) t = torch.cat([t] * noise.shape[0])
t = t.to(device) t = t.to(device)
if 'control' in p[1]:
control_nets += [p[1]['control']]
negative_copy += [[t] + n[1:]] negative_copy += [[t] + n[1:]]
model_management.load_controlnet_gpu(list(map(lambda a: a.control_model, control_nets)))
if sampler_name in comfy.samplers.KSampler.SAMPLERS: if sampler_name in comfy.samplers.KSampler.SAMPLERS:
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise) sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise)
else: else:
@ -482,6 +514,9 @@ def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, po
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask) samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask)
samples = samples.cpu() samples = samples.cpu()
for c in control_nets:
c.cleanup()
out = latent.copy() out = latent.copy()
out["samples"] = samples out["samples"] = samples
return (out, ) return (out, )
@ -676,7 +711,7 @@ class ImageScale:
def upscale(self, image, upscale_method, width, height, crop): def upscale(self, image, upscale_method, width, height, crop):
samples = image.movedim(-1,1) samples = image.movedim(-1,1)
s = common_upscale(samples, width, height, upscale_method, crop) s = comfy.utils.common_upscale(samples, width, height, upscale_method, crop)
s = s.movedim(1,-1) s = s.movedim(1,-1)
return (s,) return (s,)
@ -704,6 +739,8 @@ NODE_CLASS_MAPPINGS = {
"LatentCrop": LatentCrop, "LatentCrop": LatentCrop,
"LoraLoader": LoraLoader, "LoraLoader": LoraLoader,
"CLIPLoader": CLIPLoader, "CLIPLoader": CLIPLoader,
"ControlNetApply": ControlNetApply,
"ControlNetLoader": ControlNetLoader,
} }

Loading…
Cancel
Save