
Merge branch 'comfyanonymous:master' into handle-exec-interrupt-msg

Branch: pull/3019/head
Author: Suk-Hyun Cho (committed by GitHub, 7 months ago)
Parent commit: 058210f78d
42 changed files (lines changed in parentheses):

  1. README.md (2)
  2. comfy/controlnet.py (14)
  3. comfy/diffusers_convert.py (19)
  4. comfy/k_diffusion/sampling.py (2)
  5. comfy/latent_formats.py (2)
  6. comfy/ldm/cascade/stage_a.py (9)
  7. comfy/lora.py (14)
  8. comfy/model_base.py (109)
  9. comfy/model_detection.py (20)
  10. comfy/model_management.py (116)
  11. comfy/model_patcher.py (171)
  12. comfy/model_sampling.py (3)
  13. comfy/ops.py (26)
  14. comfy/sample.py (90)
  15. comfy/sampler_helpers.py (76)
  16. comfy/samplers.py (273)
  17. comfy/sd.py (5)
  18. comfy/supported_models.py (83)
  19. comfy/supported_models_base.py (11)
  20. comfy_extras/nodes_custom_sampler.py (309)
  21. comfy_extras/nodes_images.py (6)
  22. comfy_extras/nodes_ip2p.py (45)
  23. comfy_extras/nodes_model_merging.py (11)
  24. comfy_extras/nodes_model_merging_model_specific.py (60)
  25. comfy_extras/nodes_perpneg.py (9)
  26. comfy_extras/nodes_post_processing.py (4)
  27. comfy_extras/nodes_sag.py (2)
  28. comfy_extras/nodes_stable3d.py (53)
  29. comfy_extras/nodes_stable_cascade.py (2)
  30. comfy_extras/nodes_video_model.py (28)
  31. cuda_malloc.py (2)
  32. custom_nodes/websocket_image_save.py (6)
  33. execution.py (2)
  34. folder_paths.py (17)
  35. main.py (1)
  36. node_helpers.py (10)
  37. nodes.py (62)
  38. script_examples/websockets_api_example_ws_images.py (159)
  39. tests-ui/tests/groupNode.test.js (2)
  40. web/extensions/core/colorPalette.js (4)
  41. web/lib/litegraph.core.js (2)
  42. web/scripts/pnginfo.js (14)

README.md (2)

@@ -142,7 +142,7 @@ You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS version.
 1. Install pytorch nightly. For instructions, read the [Accelerated PyTorch training on Mac](https://developer.apple.com/metal/pytorch/) Apple Developer guide (make sure to install the latest pytorch nightly).
 1. Follow the [ComfyUI manual installation](#manual-install-windows-linux) instructions for Windows and Linux.
 1. Install the ComfyUI [dependencies](#dependencies). If you have another Stable Diffusion UI [you might be able to reuse the dependencies](#i-already-have-another-ui-for-stable-diffusion-installed-do-i-really-have-to-install-all-of-these-dependencies).
-1. Launch ComfyUI by running `python main.py --force-fp16`. Note that --force-fp16 will only work if you installed the latest pytorch nightly.
+1. Launch ComfyUI by running `python main.py`
 > **Note**: Remember to add your models, VAE, LoRAs etc. to the corresponding Comfy folders, as discussed in [ComfyUI manual installation](#manual-install-windows-linux).

comfy/controlnet.py (14)

@@ -138,11 +138,13 @@ class ControlBase:
         return out

 class ControlNet(ControlBase):
-    def __init__(self, control_model, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None):
+    def __init__(self, control_model=None, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None):
         super().__init__(device)
         self.control_model = control_model
         self.load_device = load_device
-        self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
+        if control_model is not None:
+            self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
+
         self.global_average_pooling = global_average_pooling
         self.model_sampling_current = None
         self.manual_cast_dtype = manual_cast_dtype

@@ -183,7 +185,9 @@ class ControlNet(ControlBase):
         return self.control_merge(None, control, control_prev, output_dtype)

     def copy(self):
-        c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
+        c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
+        c.control_model = self.control_model
+        c.control_model_wrapped = self.control_model_wrapped
         self.copy_to(c)
         return c

@@ -201,7 +205,7 @@ class ControlNet(ControlBase):
         super().cleanup()

 class ControlLoraOps:
-    class Linear(torch.nn.Module):
+    class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp):
         def __init__(self, in_features: int, out_features: int, bias: bool = True,
                      device=None, dtype=None) -> None:
             factory_kwargs = {'device': device, 'dtype': dtype}

@@ -220,7 +224,7 @@ class ControlLoraOps:
         else:
             return torch.nn.functional.linear(input, weight, bias)

-    class Conv2d(torch.nn.Module):
+    class Conv2d(torch.nn.Module, comfy.ops.CastWeightBiasOp):
         def __init__(
             self,
             in_channels,

comfy/diffusers_convert.py (19)

@@ -206,6 +206,21 @@ textenc_pattern = re.compile("|".join(protected.keys()))
 # Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
 code2idx = {"q": 0, "k": 1, "v": 2}

+# This function exists because at the time of writing torch.cat can't do fp8 with cuda
+def cat_tensors(tensors):
+    x = 0
+    for t in tensors:
+        x += t.shape[0]
+
+    shape = [x] + list(tensors[0].shape)[1:]
+    out = torch.empty(shape, device=tensors[0].device, dtype=tensors[0].dtype)
+
+    x = 0
+    for t in tensors:
+        out[x:x + t.shape[0]] = t
+        x += t.shape[0]
+
+    return out
+
 def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
     new_state_dict = {}

@@ -249,13 +264,13 @@ def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
         if None in tensors:
             raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
         relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
-        new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors)
+        new_state_dict[relabelled_key + ".in_proj_weight"] = cat_tensors(tensors)

     for k_pre, tensors in capture_qkv_bias.items():
         if None in tensors:
             raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
         relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
-        new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors)
+        new_state_dict[relabelled_key + ".in_proj_bias"] = cat_tensors(tensors)

     return new_state_dict
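For context: cat_tensors sidesteps the missing fp8 CUDA support in torch.cat by preallocating the output and copying each tensor into its row range. A minimal sketch of the equivalence (shapes illustrative, assuming the function is imported from comfy.diffusers_convert):

```python
import torch
from comfy.diffusers_convert import cat_tensors

# q/k/v slices of a hypothetical attention projection
q, k, v = torch.randn(4, 8), torch.randn(4, 8), torch.randn(4, 8)

out = cat_tensors([q, k, v])  # concatenates along dim 0 without calling torch.cat
assert torch.equal(out, torch.cat([q, k, v]))
```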

comfy/k_diffusion/sampling.py (2)

@@ -748,7 +748,7 @@ def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
         x = denoised
         if sigmas[i + 1] > 0:
-            x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1])
+            x = model.inner_model.inner_model.model_sampling.noise_scaling(sigmas[i + 1], noise_sampler(sigmas[i], sigmas[i + 1]), x)

     return x
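Routing the LCM noise step through model_sampling.noise_scaling lets each model type supply its own scaling rule instead of hardcoding the EPS form. For the default EPS objective the result matches the replaced line; a sketch under that assumption:

```python
# Old step: x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1])
# New step delegates to the model, which for EPS-style sampling reduces to:
def eps_noise_scaling(sigma, noise, latent_image):
    return noise * sigma + latent_image  # identical to the replaced expression
```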

comfy/latent_formats.py (2)

@@ -95,7 +95,7 @@ class SC_Prior(LatentFormat):

 class SC_B(LatentFormat):
     def __init__(self):
-        self.scale_factor = 1.0
+        self.scale_factor = 1.0 / 0.43
         self.latent_rgb_factors = [
             [ 0.1121,  0.2006,  0.1023],
             [-0.2093, -0.0222, -0.0195],

comfy/ldm/cascade/stage_a.py (9)

@@ -163,11 +163,9 @@ class ResBlock(nn.Module):

 class StageA(nn.Module):
-    def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192,
-                 scale_factor=0.43):  # 0.3764
+    def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192):
         super().__init__()
         self.c_latent = c_latent
-        self.scale_factor = scale_factor
         c_levels = [c_hidden // (2 ** i) for i in reversed(range(levels))]

         # Encoder blocks

@@ -214,12 +212,11 @@ class StageA(nn.Module):
         x = self.down_blocks(x)
         if quantize:
             qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1)
-            return qe / self.scale_factor, x / self.scale_factor, indices, vq_loss + commit_loss * 0.25
+            return qe, x, indices, vq_loss + commit_loss * 0.25
         else:
-            return x / self.scale_factor
+            return x

     def decode(self, x):
-        x = x * self.scale_factor
         x = self.up_blocks(x)
         x = self.out_block(x)
         return x
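This change, together with the SC_B scale_factor change above, relocates the 0.43 Stable Cascade scaling from StageA into the latent format. A sketch of the equivalence, assuming LatentFormat.process_in multiplies by scale_factor and process_out divides by it:

```python
class SC_B_like:
    # Stand-in for comfy.latent_formats.SC_B (sketch)
    scale_factor = 1.0 / 0.43

    def process_in(self, latent):
        return latent * self.scale_factor   # x * (1 / 0.43) == old "x / 0.43" in encode()

    def process_out(self, latent):
        return latent / self.scale_factor   # inverse, matching the old decode() scaling
```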

comfy/lora.py (14)

@@ -21,6 +21,12 @@ def load_lora(lora, to_load):
             alpha = lora[alpha_name].item()
             loaded_keys.add(alpha_name)

+        dora_scale_name = "{}.dora_scale".format(x)
+        dora_scale = None
+        if dora_scale_name in lora.keys():
+            dora_scale = lora[dora_scale_name]
+            loaded_keys.add(dora_scale_name)
+
         regular_lora = "{}.lora_up.weight".format(x)
         diffusers_lora = "{}_lora.up.weight".format(x)
         transformers_lora = "{}.lora_linear_layer.up.weight".format(x)

@@ -44,7 +50,7 @@ def load_lora(lora, to_load):
             if mid_name is not None and mid_name in lora.keys():
                 mid = lora[mid_name]
                 loaded_keys.add(mid_name)
-            patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid))
+            patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale))
             loaded_keys.add(A_name)
             loaded_keys.add(B_name)

@@ -65,7 +71,7 @@ def load_lora(lora, to_load):
             loaded_keys.add(hada_t1_name)
             loaded_keys.add(hada_t2_name)
-            patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2))
+            patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale))
             loaded_keys.add(hada_w1_a_name)
             loaded_keys.add(hada_w1_b_name)
             loaded_keys.add(hada_w2_a_name)

@@ -117,7 +123,7 @@ def load_lora(lora, to_load):
             loaded_keys.add(lokr_t2_name)

         if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
-            patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2))
+            patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale))

         #glora
         a1_name = "{}.a1.weight".format(x)

@@ -125,7 +131,7 @@ def load_lora(lora, to_load):
         b1_name = "{}.b1.weight".format(x)
         b2_name = "{}.b2.weight".format(x)
         if a1_name in lora:
-            patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha))
+            patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale))
             loaded_keys.add(a1_name)
             loaded_keys.add(a2_name)
             loaded_keys.add(b1_name)
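The new dora_scale entries thread DoRA (weight-decomposed LoRA) magnitude vectors through every patch type. A sketch of the checkpoint keys the loader now recognizes, with a hypothetical layer prefix and illustrative shapes:

```python
import torch

rank, out_dim, in_dim = 8, 32, 32
# Hypothetical DoRA-style entries for one layer prefix "lora_unet_x":
lora = {
    "lora_unet_x.lora_up.weight": torch.randn(out_dim, rank),
    "lora_unet_x.lora_down.weight": torch.randn(rank, in_dim),
    "lora_unet_x.alpha": torch.tensor(8.0),
    "lora_unet_x.dora_scale": torch.randn(1, in_dim),  # per-column magnitude (shape illustrative)
}
# load_lora() above now picks up ".dora_scale" and appends it to every patch
# tuple, e.g. ("lora", (up, down, alpha, mid, dora_scale)).
```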

comfy/model_base.py (109)

@@ -66,7 +66,8 @@ class BaseModel(torch.nn.Module):
         self.adm_channels = unet_config.get("adm_in_channels", None)
         if self.adm_channels is None:
             self.adm_channels = 0
-        self.inpaint_model = False
+
+        self.concat_keys = ()
         logging.info("model_type {}".format(model_type.name))
         logging.debug("adm {}".format(self.adm_channels))

@@ -107,8 +108,7 @@ class BaseModel(torch.nn.Module):
     def extra_conds(self, **kwargs):
         out = {}
-        if self.inpaint_model:
-            concat_keys = ("mask", "masked_image")
+        if len(self.concat_keys) > 0:
             cond_concat = []
             denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
             concat_latent_image = kwargs.get("concat_latent_image", None)

@@ -125,24 +125,16 @@ class BaseModel(torch.nn.Module):
             concat_latent_image = utils.resize_to_batch_size(concat_latent_image, noise.shape[0])

-            if len(denoise_mask.shape) == len(noise.shape):
-                denoise_mask = denoise_mask[:,:1]
-
-            denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1]))
-            if denoise_mask.shape[-2:] != noise.shape[-2:]:
-                denoise_mask = utils.common_upscale(denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center")
-            denoise_mask = utils.resize_to_batch_size(denoise_mask.round(), noise.shape[0])
-
-            def blank_inpaint_image_like(latent_image):
-                blank_image = torch.ones_like(latent_image)
-                # these are the values for "zero" in pixel space translated to latent space
-                blank_image[:,0] *= 0.8223
-                blank_image[:,1] *= -0.6876
-                blank_image[:,2] *= 0.6364
-                blank_image[:,3] *= 0.1380
-                return blank_image
-
-            for ck in concat_keys:
+            if denoise_mask is not None:
+                if len(denoise_mask.shape) == len(noise.shape):
+                    denoise_mask = denoise_mask[:,:1]
+
+                denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1]))
+                if denoise_mask.shape[-2:] != noise.shape[-2:]:
+                    denoise_mask = utils.common_upscale(denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center")
+                denoise_mask = utils.resize_to_batch_size(denoise_mask.round(), noise.shape[0])
+
+            for ck in self.concat_keys:
                 if denoise_mask is not None:
                     if ck == "mask":
                         cond_concat.append(denoise_mask.to(device))

@@ -152,7 +144,7 @@ class BaseModel(torch.nn.Module):
                     if ck == "mask":
                         cond_concat.append(torch.ones_like(noise)[:,:1])
                     elif ck == "masked_image":
-                        cond_concat.append(blank_inpaint_image_like(noise))
+                        cond_concat.append(self.blank_inpaint_image_like(noise))
             data = torch.cat(cond_concat, dim=1)
             out['c_concat'] = comfy.conds.CONDNoiseShape(data)

@@ -221,7 +213,16 @@ class BaseModel(torch.nn.Module):
         return unet_state_dict

     def set_inpaint(self):
-        self.inpaint_model = True
+        self.concat_keys = ("mask", "masked_image")
+        def blank_inpaint_image_like(latent_image):
+            blank_image = torch.ones_like(latent_image)
+            # these are the values for "zero" in pixel space translated to latent space
+            blank_image[:,0] *= 0.8223
+            blank_image[:,1] *= -0.6876
+            blank_image[:,2] *= 0.6364
+            blank_image[:,3] *= 0.1380
+            return blank_image
+        self.blank_inpaint_image_like = blank_inpaint_image_like

     def memory_required(self, input_shape):
         if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():

@@ -380,6 +381,36 @@ class SVD_img2vid(BaseModel):
         out['num_video_frames'] = comfy.conds.CONDConstant(noise.shape[0])
         return out

+class SV3D_u(SVD_img2vid):
+    def encode_adm(self, **kwargs):
+        augmentation = kwargs.get("augmentation_level", 0)
+
+        out = []
+        out.append(self.embedder(torch.flatten(torch.Tensor([augmentation]))))
+
+        flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0)
+        return flat
+
+class SV3D_p(SVD_img2vid):
+    def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None):
+        super().__init__(model_config, model_type, device=device)
+        self.embedder_512 = Timestep(512)
+
+    def encode_adm(self, **kwargs):
+        augmentation = kwargs.get("augmentation_level", 0)
+        elevation = kwargs.get("elevation", 0) #elevation and azimuth are in degrees here
+        azimuth = kwargs.get("azimuth", 0)
+        noise = kwargs.get("noise", None)
+
+        out = []
+        out.append(self.embedder(torch.flatten(torch.Tensor([augmentation]))))
+        out.append(self.embedder_512(torch.deg2rad(torch.fmod(torch.flatten(90 - torch.Tensor([elevation])), 360.0))))
+        out.append(self.embedder_512(torch.deg2rad(torch.fmod(torch.flatten(torch.Tensor([azimuth])), 360.0))))
+
+        out = list(map(lambda a: utils.resize_to_batch_size(a, noise.shape[0]), out))
+        return torch.cat(out, dim=1)
+
 class Stable_Zero123(BaseModel):
     def __init__(self, model_config, model_type=ModelType.EPS, device=None, cc_projection_weight=None, cc_projection_bias=None):
         super().__init__(model_config, model_type, device=device)

@@ -442,6 +473,42 @@ class SD_X4Upscaler(BaseModel):
         out['y'] = comfy.conds.CONDRegular(noise_level)
         return out

+class IP2P:
+    def extra_conds(self, **kwargs):
+        out = {}
+
+        image = kwargs.get("concat_latent_image", None)
+        noise = kwargs.get("noise", None)
+        device = kwargs["device"]
+
+        if image is None:
+            image = torch.zeros_like(noise)
+
+        if image.shape[1:] != noise.shape[1:]:
+            image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
+
+        image = utils.resize_to_batch_size(image, noise.shape[0])
+
+        out['c_concat'] = comfy.conds.CONDNoiseShape(self.process_ip2p_image_in(image))
+
+        adm = self.encode_adm(**kwargs)
+        if adm is not None:
+            out['y'] = comfy.conds.CONDRegular(adm)
+        return out
+
+class SD15_instructpix2pix(IP2P, BaseModel):
+    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
+        super().__init__(model_config, model_type, device=device)
+        self.process_ip2p_image_in = lambda image: image
+
+class SDXL_instructpix2pix(IP2P, SDXL):
+    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
+        super().__init__(model_config, model_type, device=device)
+        if model_type == ModelType.V_PREDICTION_EDM:
+            self.process_ip2p_image_in = lambda image: comfy.latent_formats.SDXL().process_in(image) #cosxl ip2p
+        else:
+            self.process_ip2p_image_in = lambda image: image #diffusers ip2p
+
 class StableCascade_C(BaseModel):
     def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=StageC)
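For context: the IP2P mixin concatenates the (processed) source-image latent onto the noise along the channel dimension via c_concat, which is why the instructpix2pix UNet config in the model_detection.py hunk below declares in_channels=8 (4 noise channels plus 4 image channels) instead of the usual 4. A sketch of the resulting UNet input:

```python
import torch

noise = torch.randn(1, 4, 64, 64)           # latent being denoised
image = torch.randn(1, 4, 64, 64)           # encoded source image (hypothetical values)
unet_in = torch.cat([noise, image], dim=1)  # shape (1, 8, 64, 64), hence in_channels=8
```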

comfy/model_detection.py (20)

@@ -182,9 +182,9 @@ def detect_unet_config(state_dict, key_prefix):
     return unet_config

-def model_config_from_unet_config(unet_config):
+def model_config_from_unet_config(unet_config, state_dict=None):
     for model_config in comfy.supported_models.models:
-        if model_config.matches(unet_config):
+        if model_config.matches(unet_config, state_dict):
             return model_config(unet_config)

     logging.error("no match {}".format(unet_config))

@@ -192,7 +192,7 @@ def model_config_from_unet_config(unet_config):
 def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False):
     unet_config = detect_unet_config(state_dict, unet_key_prefix)
-    model_config = model_config_from_unet_config(unet_config)
+    model_config = model_config_from_unet_config(unet_config, state_dict)
     if model_config is None and use_base_if_no_match:
         return comfy.supported_models_base.BASE(unet_config)
     else:

@@ -321,6 +321,12 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
                    'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
                    'use_temporal_attention': False, 'use_temporal_resblock': False}

+    SDXL_diffusers_ip2p = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+                           'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 8, 'model_channels': 320,
+                           'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10,
+                           'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
+                           'use_temporal_attention': False, 'use_temporal_resblock': False}
+
     SSD_1B = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
               'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
               'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 4, 4], 'transformer_depth_output': [0, 0, 0, 1, 1, 2, 10, 4, 4],

@@ -345,7 +351,13 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
                  'channel_mult': [1, 2, 4], 'transformer_depth_middle': 6, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
                  'use_temporal_attention': False, 'use_temporal_resblock': False}

-    supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B]
+    SD09_XS = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+               'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [1, 1, 1],
+               'transformer_depth': [1, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': True,
+               'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1],
+               'use_temporal_attention': False, 'use_temporal_resblock': False, 'disable_self_attentions': [True, False, False]}
+
+    supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SDXL_diffusers_ip2p]

     for unet_config in supported_models:
         matches = True
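Passing state_dict into matches() lets a config disambiguate models whose UNet hyperparameters are identical by inspecting the weights themselves. A hypothetical override, sketched under the assumption that matches() is a classmethod taking (unet_config, state_dict=None); the key name is purely illustrative:

```python
import comfy.supported_models_base

class HypotheticalConfig(comfy.supported_models_base.BASE):
    @classmethod
    def matches(cls, unet_config, state_dict=None):
        # Reject the match if the weights lack a telltale key (illustrative name)
        if state_dict is not None and "model.telltale_key" not in state_dict:
            return False
        return super().matches(unet_config, state_dict)
```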

comfy/model_management.py (116)

@@ -272,8 +272,9 @@ def module_size(module):
 class LoadedModel:
     def __init__(self, model):
         self.model = model
-        self.model_accelerated = False
         self.device = model.load_device
+        self.weights_loaded = False
+        self.real_model = None

     def model_memory(self):
         return self.model.model_size()

@@ -285,54 +286,34 @@ class LoadedModel:
         return self.model_memory()

     def model_load(self, lowvram_model_memory=0):
-        patch_model_to = None
-        if lowvram_model_memory == 0:
-            patch_model_to = self.device
+        patch_model_to = self.device

         self.model.model_patches_to(self.device)
         self.model.model_patches_to(self.model.model_dtype())

+        load_weights = not self.weights_loaded
+
         try:
-            self.real_model = self.model.patch_model(device_to=patch_model_to) #TODO: do something with loras and offloading to CPU
+            if lowvram_model_memory > 0 and load_weights:
+                self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory)
+            else:
+                self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights)
         except Exception as e:
             self.model.unpatch_model(self.model.offload_device)
             self.model_unload()
             raise e

-        if lowvram_model_memory > 0:
-            logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024)))
-            mem_counter = 0
-            for m in self.real_model.modules():
-                if hasattr(m, "comfy_cast_weights"):
-                    m.prev_comfy_cast_weights = m.comfy_cast_weights
-                    m.comfy_cast_weights = True
-                    module_mem = module_size(m)
-                    if mem_counter + module_mem < lowvram_model_memory:
-                        m.to(self.device)
-                        mem_counter += module_mem
-                elif hasattr(m, "weight"): #only modules with comfy_cast_weights can be set to lowvram mode
-                    m.to(self.device)
-                    mem_counter += module_size(m)
-                    logging.warning("lowvram: loaded module regularly {}".format(m))
-
-            self.model_accelerated = True
-
         if is_intel_xpu() and not args.disable_ipex_optimize:
             self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True)

+        self.weights_loaded = True
         return self.real_model

-    def model_unload(self):
-        if self.model_accelerated:
-            for m in self.real_model.modules():
-                if hasattr(m, "prev_comfy_cast_weights"):
-                    m.comfy_cast_weights = m.prev_comfy_cast_weights
-                    del m.prev_comfy_cast_weights
-
-            self.model_accelerated = False
-
-        self.model.unpatch_model(self.model.offload_device)
+    def model_unload(self, unpatch_weights=True):
+        self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
         self.model.model_patches_to(self.model.offload_device)
+        self.weights_loaded = self.weights_loaded and not unpatch_weights
+        self.real_model = None

     def __eq__(self, other):
         return self.model is other.model

@@ -340,31 +321,57 @@
 def minimum_inference_memory():
     return (1024 * 1024 * 1024)

-def unload_model_clones(model):
+def unload_model_clones(model, unload_weights_only=True, force_unload=True):
     to_unload = []
     for i in range(len(current_loaded_models)):
         if model.is_clone(current_loaded_models[i].model):
             to_unload = [i] + to_unload

+    if len(to_unload) == 0:
+        return True
+
+    same_weights = 0
     for i in to_unload:
-        logging.debug("unload clone {}".format(i))
-        current_loaded_models.pop(i).model_unload()
+        if model.clone_has_same_weights(current_loaded_models[i].model):
+            same_weights += 1
+
+    if same_weights == len(to_unload):
+        unload_weight = False
+    else:
+        unload_weight = True
+
+    if not force_unload:
+        if unload_weights_only and unload_weight == False:
+            return None
+
+    for i in to_unload:
+        logging.debug("unload clone {} {}".format(i, unload_weight))
+        current_loaded_models.pop(i).model_unload(unpatch_weights=unload_weight)
+
+    return unload_weight

 def free_memory(memory_required, device, keep_loaded=[]):
-    unloaded_model = False
+    unloaded_model = []
+    can_unload = []
+
     for i in range(len(current_loaded_models) -1, -1, -1):
-        if not DISABLE_SMART_MEMORY:
-            if get_free_memory(device) > memory_required:
-                break
         shift_model = current_loaded_models[i]
         if shift_model.device == device:
             if shift_model not in keep_loaded:
-                m = current_loaded_models.pop(i)
-                m.model_unload()
-                del m
-                unloaded_model = True
+                can_unload.append((sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
+
+    for x in sorted(can_unload):
+        i = x[-1]
+        if not DISABLE_SMART_MEMORY:
+            if get_free_memory(device) > memory_required:
+                break
+        current_loaded_models[i].model_unload()
+        unloaded_model.append(i)
+
+    for i in sorted(unloaded_model, reverse=True):
+        current_loaded_models.pop(i)

-    if unloaded_model:
+    if len(unloaded_model) > 0:
         soft_empty_cache()
     else:
         if vram_state != VRAMState.HIGH_VRAM:

@@ -378,6 +385,8 @@ def load_models_gpu(models, memory_required=0):
     inference_memory = minimum_inference_memory()
     extra_mem = max(inference_memory, memory_required)

+    models = set(models)
+
     models_to_load = []
     models_already_loaded = []
     for x in models:

@@ -403,13 +412,18 @@ def load_models_gpu(models, memory_required=0):
     total_memory_required = {}
     for loaded_model in models_to_load:
-        unload_model_clones(loaded_model.model)
-        total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
+        if unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) == True: #unload clones where the weights are different
+            total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)

     for device in total_memory_required:
         if device != torch.device("cpu"):
             free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)

+    for loaded_model in models_to_load:
+        weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) #unload the rest of the clones where the weights can stay loaded
+        if weights_unloaded is not None:
+            loaded_model.weights_loaded = not weights_unloaded
+
     for loaded_model in models_to_load:
         model = loaded_model.model
         torch_dev = model.load_device

@@ -438,11 +452,15 @@ def load_models_gpu(models, memory_required=0):
 def load_model_gpu(model):
     return load_models_gpu([model])

-def cleanup_models():
+def cleanup_models(keep_clone_weights_loaded=False):
     to_delete = []
     for i in range(len(current_loaded_models)):
         if sys.getrefcount(current_loaded_models[i].model) <= 2:
-            to_delete = [i] + to_delete
+            if not keep_clone_weights_loaded:
+                to_delete = [i] + to_delete
+            #TODO: find a less fragile way to do this.
+            elif sys.getrefcount(current_loaded_models[i].real_model) <= 3: #references from .real_model + the .model
+                to_delete = [i] + to_delete

     for i in to_delete:
         x = current_loaded_models.pop(i)
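For context: free_memory now gathers candidates first and frees them in a deterministic order. Since can_unload holds (refcount, memory, index) tuples, sorted() releases the least-referenced models first, breaking ties by size. A toy illustration:

```python
# (sys.getrefcount(model), model_memory(), index) for three loaded models
can_unload = [(3, 2048, 0), (2, 4096, 1), (2, 1024, 2)]

order = [idx for _, _, idx in sorted(can_unload)]
print(order)  # [2, 1, 0]: fewest external references unload first, ties by memory
```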

comfy/model_patcher.py (171)

@@ -2,10 +2,23 @@ import torch
 import copy
 import inspect
 import logging
+import uuid

 import comfy.utils
 import comfy.model_management

+def apply_weight_decompose(dora_scale, weight):
+    weight_norm = (
+        weight.transpose(0, 1)
+        .reshape(weight.shape[1], -1)
+        .norm(dim=1, keepdim=True)
+        .reshape(weight.shape[1], *[1] * (weight.dim() - 1))
+        .transpose(0, 1)
+    )
+
+    return weight * (dora_scale / weight_norm)
+
 class ModelPatcher:
     def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
         self.size = size
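apply_weight_decompose is the DoRA step: after the low-rank patch has been merged into W, the weight is renormalized per input column and rescaled by the learned magnitude vector, i.e. W' = dora_scale * W / ||W||_col, where ||W||_col is the L2 norm over every dimension except dim 1 (which is what the transpose/reshape chain computes). A small numeric check of that reading (shapes illustrative):

```python
import torch
from comfy.model_patcher import apply_weight_decompose

w = torch.randn(8, 4)        # (out_features, in_features)
scale = torch.ones(1, 4)     # magnitude vector, one entry per input column

out = apply_weight_decompose(scale, w)
# With a unit scale, every column of `out` is normalized to unit L2 norm.
assert torch.allclose(out.norm(dim=0), torch.ones(4), atol=1e-5)
```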
@@ -24,6 +37,8 @@ class ModelPatcher:
         self.current_device = current_device
         self.weight_inplace_update = weight_inplace_update
+        self.model_lowvram = False
+        self.patches_uuid = uuid.uuid4()

     def model_size(self):
         if self.size > 0:

@@ -38,10 +53,13 @@ class ModelPatcher:
         n.patches = {}
         for k in self.patches:
             n.patches[k] = self.patches[k][:]
+        n.patches_uuid = self.patches_uuid

         n.object_patches = self.object_patches.copy()
         n.model_options = copy.deepcopy(self.model_options)
         n.model_keys = self.model_keys
+        n.backup = self.backup
+        n.object_patches_backup = self.object_patches_backup
         return n

     def is_clone(self, other):

@@ -49,6 +67,19 @@ class ModelPatcher:
             return True
         return False

+    def clone_has_same_weights(self, clone):
+        if not self.is_clone(clone):
+            return False
+
+        if len(self.patches) == 0 and len(clone.patches) == 0:
+            return True
+
+        if self.patches_uuid == clone.patches_uuid:
+            if len(self.patches) != len(clone.patches):
+                logging.warning("WARNING: something went wrong, same patch uuid but different length of patches.")
+            else:
+                return True
+
     def memory_required(self, input_shape):
         return self.model.memory_required(input_shape=input_shape)

@@ -119,6 +150,15 @@ class ModelPatcher:
     def add_object_patch(self, name, obj):
         self.object_patches[name] = obj

+    def get_model_object(self, name):
+        if name in self.object_patches:
+            return self.object_patches[name]
+        else:
+            if name in self.object_patches_backup:
+                return self.object_patches_backup[name]
+            else:
+                return comfy.utils.get_attr(self.model, name)
+
     def model_patches_to(self, device):
         to = self.model_options["transformer_options"]
         if "patches" in to:

@@ -153,6 +193,7 @@ class ModelPatcher:
                 current_patches.append((strength_patch, patches[k], strength_model))
                 self.patches[k] = current_patches

+        self.patches_uuid = uuid.uuid4()
         return list(p)

     def get_key_patches(self, filter_prefix=None):

@@ -178,6 +219,27 @@ class ModelPatcher:
                 sd.pop(k)
         return sd

+    def patch_weight_to_device(self, key, device_to=None):
+        if key not in self.patches:
+            return
+
+        weight = comfy.utils.get_attr(self.model, key)
+
+        inplace_update = self.weight_inplace_update
+
+        if key not in self.backup:
+            self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
+
+        if device_to is not None:
+            temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
+        else:
+            temp_weight = weight.to(torch.float32, copy=True)
+        out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
+        if inplace_update:
+            comfy.utils.copy_to_param(self.model, key, out_weight)
+        else:
+            comfy.utils.set_attr_param(self.model, key, out_weight)
+
     def patch_model(self, device_to=None, patch_weights=True):
         for k in self.object_patches:
             old = comfy.utils.set_attr(self.model, k, self.object_patches[k])

@@ -191,23 +253,7 @@ class ModelPatcher:
                     logging.warning("could not patch. key doesn't exist in model: {}".format(key))
                     continue

-                weight = model_sd[key]
-
-                inplace_update = self.weight_inplace_update
-
-                if key not in self.backup:
-                    self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
-
-                if device_to is not None:
-                    temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
-                else:
-                    temp_weight = weight.to(torch.float32, copy=True)
-                out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
-                if inplace_update:
-                    comfy.utils.copy_to_param(self.model, key, out_weight)
-                else:
-                    comfy.utils.set_attr_param(self.model, key, out_weight)
-                del temp_weight
+                self.patch_weight_to_device(key, device_to)

         if device_to is not None:
             self.model.to(device_to)

@@ -215,6 +261,47 @@ class ModelPatcher:
         return self.model

+    def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0):
+        self.patch_model(device_to, patch_weights=False)
+
+        logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024)))
+        class LowVramPatch:
+            def __init__(self, key, model_patcher):
+                self.key = key
+                self.model_patcher = model_patcher
+            def __call__(self, weight):
+                return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
+
+        mem_counter = 0
+        for n, m in self.model.named_modules():
+            lowvram_weight = False
+            if hasattr(m, "comfy_cast_weights"):
+                module_mem = comfy.model_management.module_size(m)
+                if mem_counter + module_mem >= lowvram_model_memory:
+                    lowvram_weight = True
+
+            weight_key = "{}.weight".format(n)
+            bias_key = "{}.bias".format(n)
+
+            if lowvram_weight:
+                if weight_key in self.patches:
+                    m.weight_function = LowVramPatch(weight_key, self)
+                if bias_key in self.patches:
+                    m.bias_function = LowVramPatch(bias_key, self)
+
+                m.prev_comfy_cast_weights = m.comfy_cast_weights
+                m.comfy_cast_weights = True
+            else:
+                if hasattr(m, "weight"):
+                    self.patch_weight_to_device(weight_key, device_to)
+                    self.patch_weight_to_device(bias_key, device_to)
+                    m.to(device_to)
+                    mem_counter += comfy.model_management.module_size(m)
+                    logging.debug("lowvram: loaded module regularly {}".format(m))
+
+        self.model_lowvram = True
+        return self.model
+
     def calculate_weight(self, patches, weight, key):
         for p in patches:
             alpha = p[0]
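For context: patch_model_lowvram leaves modules past the memory budget on the offload device with their raw weights, installing LowVramPatch callables as the weight_function/bias_function hooks that comfy.ops.cast_bias_weight applies at forward time (see the comfy/ops.py hunks below). Patches such as LoRA are thus computed on the fly instead of being baked into the weights. The mechanism in miniature (LowVramPatch itself is local to patch_model_lowvram; this stand-in mirrors it):

```python
class OnTheFlyPatch:
    """Stand-in for the local LowVramPatch class (sketch)."""
    def __init__(self, key, patcher):
        self.key, self.patcher = key, patcher
    def __call__(self, weight):
        # Re-applies the stored patches to the freshly casted weight each call.
        return self.patcher.calculate_weight(self.patcher.patches[self.key], weight, self.key)

# module.weight_function = OnTheFlyPatch("diffusion_model.x.weight", patcher)  # hypothetical key
# cast_bias_weight() then runs the patch on the casted copy; the stored weight is untouched.
```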
@@ -243,6 +330,7 @@ class ModelPatcher:
             elif patch_type == "lora": #lora/locon
                 mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32)
                 mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32)
+                dora_scale = v[4]
                 if v[2] is not None:
                     alpha *= v[2] / mat2.shape[0]
                 if v[3] is not None:

@@ -252,6 +340,8 @@ class ModelPatcher:
                     mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
                 try:
                     weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
+                    if dora_scale is not None:
+                        weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
                 except Exception as e:
                     logging.error("ERROR {} {} {}".format(patch_type, key, e))
             elif patch_type == "lokr":

@@ -262,6 +352,7 @@ class ModelPatcher:
                 w2_a = v[5]
                 w2_b = v[6]
                 t2 = v[7]
+                dora_scale = v[8]
                 dim = None

                 if w1 is None:

@@ -291,6 +382,8 @@ class ModelPatcher:
                 try:
                     weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
+                    if dora_scale is not None:
+                        weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
                 except Exception as e:
                     logging.error("ERROR {} {} {}".format(patch_type, key, e))
             elif patch_type == "loha":

@@ -300,6 +393,7 @@ class ModelPatcher:
                     alpha *= v[2] / w1b.shape[0]
                 w2a = v[3]
                 w2b = v[4]
+                dora_scale = v[7]
                 if v[5] is not None: #cp decomposition
                     t1 = v[5]
                     t2 = v[6]

@@ -320,12 +414,16 @@ class ModelPatcher:
                 try:
                     weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
+                    if dora_scale is not None:
+                        weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
                 except Exception as e:
                     logging.error("ERROR {} {} {}".format(patch_type, key, e))
             elif patch_type == "glora":
                 if v[4] is not None:
                     alpha *= v[4] / v[0].shape[0]
+                dora_scale = v[5]
+
                 a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32)
                 a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32)
                 b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)

@@ -333,6 +431,8 @@ class ModelPatcher:
                 try:
                     weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype)
+                    if dora_scale is not None:
+                        weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
                 except Exception as e:
                     logging.error("ERROR {} {} {}".format(patch_type, key, e))
             else:

@@ -340,24 +440,35 @@ class ModelPatcher:
         return weight

-    def unpatch_model(self, device_to=None):
-        keys = list(self.backup.keys())
+    def unpatch_model(self, device_to=None, unpatch_weights=True):
+        if unpatch_weights:
+            if self.model_lowvram:
+                for m in self.model.modules():
+                    if hasattr(m, "prev_comfy_cast_weights"):
+                        m.comfy_cast_weights = m.prev_comfy_cast_weights
+                        del m.prev_comfy_cast_weights
+                    m.weight_function = None
+                    m.bias_function = None

-        if self.weight_inplace_update:
-            for k in keys:
-                comfy.utils.copy_to_param(self.model, k, self.backup[k])
-        else:
-            for k in keys:
-                comfy.utils.set_attr_param(self.model, k, self.backup[k])
+                self.model_lowvram = False

-        self.backup = {}
+            keys = list(self.backup.keys())

-        if device_to is not None:
-            self.model.to(device_to)
-            self.current_device = device_to
+            if self.weight_inplace_update:
+                for k in keys:
+                    comfy.utils.copy_to_param(self.model, k, self.backup[k])
+            else:
+                for k in keys:
+                    comfy.utils.set_attr_param(self.model, k, self.backup[k])
+
+            self.backup.clear()
+
+        if device_to is not None:
+            self.model.to(device_to)
+            self.current_device = device_to

         keys = list(self.object_patches_backup.keys())
         for k in keys:
             comfy.utils.set_attr(self.model, k, self.object_patches_backup[k])

-        self.object_patches_backup = {}
+        self.object_patches_backup.clear()
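For context: patches_uuid gives clones a cheap identity check. clone() copies the uuid while add_patches() generates a fresh one, so clone_has_same_weights() can compare a single value instead of diffing patch dicts. A sketch (lora_patches is a hypothetical patch dict whose keys exist in the model):

```python
a = patcher.clone()                 # shares patches_uuid with patcher
b = patcher.clone()
b.add_patches(lora_patches, 1.0)    # mutates b's patches and refreshes its uuid

assert patcher.clone_has_same_weights(a)        # same uuid, same patch count
assert not patcher.clone_has_same_weights(b)    # differing uuid -> falsy result
```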

comfy/model_sampling.py (3)

@@ -20,6 +20,9 @@ class EPS:
         noise += latent_image
         return noise

+    def inverse_noise_scaling(self, sigma, latent):
+        return latent
+
 class V_PREDICTION(EPS):
     def calculate_denoised(self, sigma, model_output, model_input):
         sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))

comfy/ops.py (26)

@@ -24,13 +24,20 @@ def cast_bias_weight(s, input):
     non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
     if s.bias is not None:
         bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
+        if s.bias_function is not None:
+            bias = s.bias_function(bias)
+
     weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
+    if s.weight_function is not None:
+        weight = s.weight_function(weight)
     return weight, bias

+class CastWeightBiasOp:
+    comfy_cast_weights = False
+    weight_function = None
+    bias_function = None
+
 class disable_weight_init:
-    class Linear(torch.nn.Linear):
-        comfy_cast_weights = False
+    class Linear(torch.nn.Linear, CastWeightBiasOp):
         def reset_parameters(self):
             return None

@@ -44,8 +51,7 @@ class disable_weight_init:
         else:
             return super().forward(*args, **kwargs)

-    class Conv2d(torch.nn.Conv2d):
-        comfy_cast_weights = False
+    class Conv2d(torch.nn.Conv2d, CastWeightBiasOp):
         def reset_parameters(self):
             return None

@@ -59,8 +65,7 @@ class disable_weight_init:
         else:
             return super().forward(*args, **kwargs)

-    class Conv3d(torch.nn.Conv3d):
-        comfy_cast_weights = False
+    class Conv3d(torch.nn.Conv3d, CastWeightBiasOp):
         def reset_parameters(self):
             return None

@@ -74,8 +79,7 @@ class disable_weight_init:
         else:
             return super().forward(*args, **kwargs)

-    class GroupNorm(torch.nn.GroupNorm):
-        comfy_cast_weights = False
+    class GroupNorm(torch.nn.GroupNorm, CastWeightBiasOp):
         def reset_parameters(self):
             return None

@@ -90,8 +94,7 @@ class disable_weight_init:
             return super().forward(*args, **kwargs)

-    class LayerNorm(torch.nn.LayerNorm):
-        comfy_cast_weights = False
+    class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp):
         def reset_parameters(self):
             return None

@@ -109,8 +112,7 @@ class disable_weight_init:
         else:
             return super().forward(*args, **kwargs)

-    class ConvTranspose2d(torch.nn.ConvTranspose2d):
-        comfy_cast_weights = False
+    class ConvTranspose2d(torch.nn.ConvTranspose2d, CastWeightBiasOp):
         def reset_parameters(self):
             return None
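cast_bias_weight now applies optional weight_function/bias_function hooks after casting, which is the attachment point the LowVramPatch objects from model_patcher.py use. A sketch of the hook contract on one of these ops:

```python
import torch
import comfy.ops

lin = comfy.ops.disable_weight_init.Linear(4, 4)  # weights intentionally left uninitialized
lin.comfy_cast_weights = True                     # route forward() through cast_bias_weight
lin.weight_function = lambda w: w * 2.0           # hypothetical on-the-fly patch

y = lin(torch.randn(1, 4))  # forward sees the doubled weight; lin.weight is untouched
```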

90
comfy/sample.py

@ -1,10 +1,9 @@
import torch import torch
import comfy.model_management import comfy.model_management
import comfy.samplers import comfy.samplers
import comfy.conds
import comfy.utils import comfy.utils
import math
import numpy as np import numpy as np
import logging
def prepare_noise(latent_image, seed, noise_inds=None): def prepare_noise(latent_image, seed, noise_inds=None):
""" """
@ -25,94 +24,21 @@ def prepare_noise(latent_image, seed, noise_inds=None):
noises = torch.cat(noises, axis=0) noises = torch.cat(noises, axis=0)
return noises return noises
def prepare_mask(noise_mask, shape, device):
"""ensures noise mask is of proper dimensions"""
noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear")
noise_mask = torch.cat([noise_mask] * shape[1], dim=1)
noise_mask = comfy.utils.repeat_to_batch_size(noise_mask, shape[0])
noise_mask = noise_mask.to(device)
return noise_mask
def get_models_from_cond(cond, model_type):
models = []
for c in cond:
if model_type in c:
models += [c[model_type]]
return models
def convert_cond(cond):
out = []
for c in cond:
temp = c[1].copy()
model_conds = temp.get("model_conds", {})
if c[0] is not None:
model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0]) #TODO: remove
temp["cross_attn"] = c[0]
temp["model_conds"] = model_conds
out.append(temp)
return out
def get_additional_models(positive, negative, dtype):
"""loads additional models in positive and negative conditioning"""
control_nets = set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control"))
inference_memory = 0
control_models = []
for m in control_nets:
control_models += m.get_models()
inference_memory += m.inference_memory_requirements(dtype)
gligen = get_models_from_cond(positive, "gligen") + get_models_from_cond(negative, "gligen")
gligen = [x[1] for x in gligen]
models = control_models + gligen
return models, inference_memory
def cleanup_additional_models(models):
"""cleanup additional models that were loaded"""
for m in models:
if hasattr(m, 'cleanup'):
m.cleanup()
def prepare_sampling(model, noise_shape, positive, negative, noise_mask): def prepare_sampling(model, noise_shape, positive, negative, noise_mask):
device = model.load_device logging.warning("Warning: comfy.sample.prepare_sampling isn't used anymore and can be removed")
positive = convert_cond(positive) return model, positive, negative, noise_mask, []
negative = convert_cond(negative)
if noise_mask is not None:
-    noise_mask = prepare_mask(noise_mask, noise_shape, device)
-    real_model = None
-    models, inference_memory = get_additional_models(positive, negative, model.model_dtype())
-    comfy.model_management.load_models_gpu([model] + models, model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory)
-    real_model = model.model
-    return real_model, positive, negative, noise_mask, models
+def cleanup_additional_models(models):
+    logging.warning("Warning: comfy.sample.cleanup_additional_models isn't used anymore and can be removed")

def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
-    real_model, positive_copy, negative_copy, noise_mask, models = prepare_sampling(model, noise.shape, positive, negative, noise_mask)
-    noise = noise.to(model.load_device)
-    latent_image = latent_image.to(model.load_device)
-    sampler = comfy.samplers.KSampler(real_model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
-    samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
+    sampler = comfy.samplers.KSampler(model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
+    samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
    samples = samples.to(comfy.model_management.intermediate_device())
-    cleanup_additional_models(models)
-    cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control")))
    return samples

def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
-    real_model, positive_copy, negative_copy, noise_mask, models = prepare_sampling(model, noise.shape, positive, negative, noise_mask)
-    noise = noise.to(model.load_device)
-    latent_image = latent_image.to(model.load_device)
-    sigmas = sigmas.to(model.load_device)
-    samples = comfy.samplers.sample(real_model, noise, positive_copy, negative_copy, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
+    samples = comfy.samplers.sample(model, noise, positive, negative, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
    samples = samples.to(comfy.model_management.intermediate_device())
-    cleanup_additional_models(models)
-    cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control")))
    return samples
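Both entry points above now take the ModelPatcher straight through and leave device placement and model loading to the code in comfy.samplers. A minimal sketch of the KSampler path; model_patcher, positive, negative, noise and latent stand in for objects produced by the usual loaders and are not part of this diff:

# Illustrative sketch only: calling the simplified comfy.sample.sample().
import comfy.sample

samples = comfy.sample.sample(model_patcher, noise, steps=20, cfg=8.0,
                              sampler_name="euler", scheduler="normal",
                              positive=positive, negative=negative,
                              latent_image=latent, seed=42)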

76
comfy/sampler_helpers.py

@@ -0,0 +1,76 @@
import torch
import comfy.model_management
import comfy.conds

def prepare_mask(noise_mask, shape, device):
    """ensures noise mask is of proper dimensions"""
    noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear")
    noise_mask = torch.cat([noise_mask] * shape[1], dim=1)
    noise_mask = comfy.utils.repeat_to_batch_size(noise_mask, shape[0])
    noise_mask = noise_mask.to(device)
    return noise_mask

def get_models_from_cond(cond, model_type):
    models = []
    for c in cond:
        if model_type in c:
            models += [c[model_type]]
    return models

def convert_cond(cond):
    out = []
    for c in cond:
        temp = c[1].copy()
        model_conds = temp.get("model_conds", {})
        if c[0] is not None:
            model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0]) #TODO: remove
            temp["cross_attn"] = c[0]
        temp["model_conds"] = model_conds
        out.append(temp)
    return out

def get_additional_models(conds, dtype):
    """loads additional models in conditioning"""
    cnets = []
    gligen = []

    for k in conds:
        cnets += get_models_from_cond(conds[k], "control")
        gligen += get_models_from_cond(conds[k], "gligen")

    control_nets = set(cnets)

    inference_memory = 0
    control_models = []
    for m in control_nets:
        control_models += m.get_models()
        inference_memory += m.inference_memory_requirements(dtype)

    gligen = [x[1] for x in gligen]
    models = control_models + gligen
    return models, inference_memory

def cleanup_additional_models(models):
    """cleanup additional models that were loaded"""
    for m in models:
        if hasattr(m, 'cleanup'):
            m.cleanup()

def prepare_sampling(model, noise_shape, conds):
    device = model.load_device
    real_model = None
    models, inference_memory = get_additional_models(conds, model.model_dtype())
    comfy.model_management.load_models_gpu([model] + models, model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory)
    real_model = model.model

    return real_model, conds, models

def cleanup_models(conds, models):
    cleanup_additional_models(models)

    control_cleanup = []
    for k in conds:
        control_cleanup += get_models_from_cond(conds[k], "control")

    cleanup_additional_models(set(control_cleanup))
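These helpers are designed to bracket a sampling call: prepare_sampling loads the ModelPatcher plus any ControlNet/GLIGEN models referenced by the conds, and cleanup_models releases them afterwards. A rough sketch of the intended call pattern, assuming everything outside the module's own names is supplied by the caller:

# Illustrative sketch: the prepare/cleanup bracket around a sampling call.
import comfy.sampler_helpers

def sample_with_extra_models(model_patcher, noise, conds, run_sampler):
    # Loads the diffusion model plus any ControlNets/GLIGEN referenced by the conds.
    inner_model, conds, loaded = comfy.sampler_helpers.prepare_sampling(model_patcher, noise.shape, conds)
    try:
        return run_sampler(inner_model, conds)
    finally:
        # Calls .cleanup() on the extra models that were loaded.
        comfy.sampler_helpers.cleanup_models(conds, loaded)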

273
comfy/samplers.py

@@ -5,6 +5,7 @@ import collections
from comfy import model_management
import math
import logging
+import comfy.sampler_helpers

def get_area_and_mult(conds, x_in, timestep_in):
    area = (x_in.shape[2], x_in.shape[3], 0, 0)
@@ -127,30 +128,23 @@ def cond_cat(c_list):
    return out

-def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
-    out_cond = torch.zeros_like(x_in)
-    out_count = torch.ones_like(x_in) * 1e-37
-    out_uncond = torch.zeros_like(x_in)
-    out_uncond_count = torch.ones_like(x_in) * 1e-37
-    COND = 0
-    UNCOND = 1
+def calc_cond_batch(model, conds, x_in, timestep, model_options):
+    out_conds = []
+    out_counts = []
    to_run = []
-    for x in cond:
-        p = get_area_and_mult(x, x_in, timestep)
-        if p is None:
-            continue
-        to_run += [(p, COND)]
-    if uncond is not None:
-        for x in uncond:
-            p = get_area_and_mult(x, x_in, timestep)
-            if p is None:
-                continue
-            to_run += [(p, UNCOND)]
+    for i in range(len(conds)):
+        out_conds.append(torch.zeros_like(x_in))
+        out_counts.append(torch.ones_like(x_in) * 1e-37)
+        cond = conds[i]
+        if cond is not None:
+            for x in cond:
+                p = get_area_and_mult(x, x_in, timestep)
+                if p is None:
+                    continue
+                to_run += [(p, i)]

    while len(to_run) > 0:
        first = to_run[0]
@@ -222,74 +216,66 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
            output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
        else:
            output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
-        del input_x

        for o in range(batch_chunks):
-            if cond_or_uncond[o] == COND:
-                out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
-                out_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
-            else:
-                out_uncond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
-                out_uncond_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
-        del mult
-
-    out_cond /= out_count
-    del out_count
-    out_uncond /= out_uncond_count
-    del out_uncond_count
-    return out_cond, out_uncond
+            cond_index = cond_or_uncond[o]
+            out_conds[cond_index][:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
+            out_counts[cond_index][:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
+
+    for i in range(len(out_conds)):
+        out_conds[i] /= out_counts[i]
+
+    return out_conds
+
+def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #TODO: remove
+    logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.")
+    return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options))
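For callers migrating off the deprecated wrapper the translation is mechanical, as the shim above shows; a before/after sketch with cond and uncond standing in for already-processed conditioning lists:

# Before: cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond, x, timestep, model_options)
# After: pass the conds as one list and unpack the outputs in the same order.
cond_pred, uncond_pred = calc_cond_batch(model, [cond, uncond], x, timestep, model_options)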
+def cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_options={}, cond=None, uncond=None):
+    if "sampler_cfg_function" in model_options:
+        args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep,
+                "cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options}
+        cfg_result = x - model_options["sampler_cfg_function"](args)
+    else:
+        cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale
+
+    for fn in model_options.get("sampler_post_cfg_function", []):
+        args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
+                "sigma": timestep, "model_options": model_options, "input": x}
+        cfg_result = fn(args)
+
+    return cfg_result
#The main sampling function shared by all the samplers
#Returns denoised
def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
    if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False:
        uncond_ = None
    else:
        uncond_ = uncond
-    cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options)
-    if "sampler_cfg_function" in model_options:
-        args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep,
-                "cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options}
-        cfg_result = x - model_options["sampler_cfg_function"](args)
-    else:
-        cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale
-    for fn in model_options.get("sampler_post_cfg_function", []):
-        args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
-                "sigma": timestep, "model_options": model_options, "input": x}
-        cfg_result = fn(args)
-    return cfg_result
+    conds = [cond, uncond_]
+    out = calc_cond_batch(model, conds, x, timestep, model_options)
+    return cfg_function(model, out[0], out[1], cond_scale, x, timestep, model_options=model_options, cond=cond, uncond=uncond_)
-class CFGNoisePredictor(torch.nn.Module):
-    def __init__(self, model):
-        super().__init__()
-        self.inner_model = model
-    def apply_model(self, x, timestep, cond, uncond, cond_scale, model_options={}, seed=None):
-        out = sampling_function(self.inner_model, x, timestep, uncond, cond, cond_scale, model_options=model_options, seed=seed)
-        return out
-    def forward(self, *args, **kwargs):
-        return self.apply_model(*args, **kwargs)
-
-class KSamplerX0Inpaint(torch.nn.Module):
+class KSamplerX0Inpaint:
    def __init__(self, model, sigmas):
-        super().__init__()
        self.inner_model = model
        self.sigmas = sigmas
-    def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, model_options={}, seed=None):
+    def __call__(self, x, sigma, denoise_mask, model_options={}, seed=None):
        if denoise_mask is not None:
            if "denoise_mask_function" in model_options:
                denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas})
            latent_mask = 1. - denoise_mask
            x = x * denoise_mask + self.inner_model.inner_model.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1)), self.noise, self.latent_image) * latent_mask
-        out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, model_options=model_options, seed=seed)
+        out = self.inner_model(x, sigma, model_options=model_options, seed=seed)
        if denoise_mask is not None:
            out = out * denoise_mask + self.latent_image * latent_mask
        return out
-def simple_scheduler(model, steps):
-    s = model.model_sampling
+def simple_scheduler(model_sampling, steps):
+    s = model_sampling
    sigs = []
    ss = len(s.sigmas) / steps
    for x in range(steps):
@@ -297,8 +283,8 @@ def simple_scheduler(model, steps):
    sigs += [0.0]
    return torch.FloatTensor(sigs)

-def ddim_scheduler(model, steps):
-    s = model.model_sampling
+def ddim_scheduler(model_sampling, steps):
+    s = model_sampling
    sigs = []
    ss = max(len(s.sigmas) // steps, 1)
    x = 1
@@ -309,8 +295,8 @@ def ddim_scheduler(model, steps):
    sigs += [0.0]
    return torch.FloatTensor(sigs)

-def normal_scheduler(model, steps, sgm=False, floor=False):
-    s = model.model_sampling
+def normal_scheduler(model_sampling, steps, sgm=False, floor=False):
+    s = model_sampling
    start = s.timestep(s.sigma_max)
    end = s.timestep(s.sigma_min)
@@ -546,6 +532,7 @@ class KSAMPLER(Sampler):
            k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)

        samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options)
+        samples = model_wrap.inner_model.model_sampling.inverse_noise_scaling(sigmas[-1], samples)
        return samples
@@ -559,72 +546,133 @@ def ksampler(sampler_name, extra_options={}, inpaint_options={}):
            return k_diffusion_sampling.sample_dpm_fast(model, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=callback, disable=disable)
        sampler_function = dpm_fast_function
    elif sampler_name == "dpm_adaptive":
-        def dpm_adaptive_function(model, noise, sigmas, extra_args, callback, disable):
+        def dpm_adaptive_function(model, noise, sigmas, extra_args, callback, disable, **extra_options):
            sigma_min = sigmas[-1]
            if sigma_min == 0:
                sigma_min = sigmas[-2]
-            return k_diffusion_sampling.sample_dpm_adaptive(model, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=callback, disable=disable)
+            return k_diffusion_sampling.sample_dpm_adaptive(model, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=callback, disable=disable, **extra_options)
        sampler_function = dpm_adaptive_function
    else:
        sampler_function = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))

    return KSAMPLER(sampler_function, extra_options, inpaint_options)
-def wrap_model(model):
-    model_denoise = CFGNoisePredictor(model)
-    return model_denoise
-
-def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
-    positive = positive[:]
-    negative = negative[:]
-
-    resolve_areas_and_cond_masks(positive, noise.shape[2], noise.shape[3], device)
-    resolve_areas_and_cond_masks(negative, noise.shape[2], noise.shape[3], device)
-
-    model_wrap = wrap_model(model)
-
-    calculate_start_end_timesteps(model, negative)
-    calculate_start_end_timesteps(model, positive)
-
-    if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image.
-        latent_image = model.process_latent_in(latent_image)
-
-    if hasattr(model, 'extra_conds'):
-        positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
-        negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
-
-    #make sure each cond area has an opposite one with the same area
-    for c in positive:
-        create_cond_with_same_area_if_none(negative, c)
-    for c in negative:
-        create_cond_with_same_area_if_none(positive, c)
-
-    pre_run_control(model, negative + positive)
-
-    apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
-    apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
-
-    extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed}
-
-    samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
-    return model.process_latent_out(samples.to(torch.float32))
+def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None):
+    for k in conds:
+        conds[k] = conds[k][:]
+        resolve_areas_and_cond_masks(conds[k], noise.shape[2], noise.shape[3], device)
+
+    for k in conds:
+        calculate_start_end_timesteps(model, conds[k])
+
+    if hasattr(model, 'extra_conds'):
+        for k in conds:
+            conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
+
+    #make sure each cond area has an opposite one with the same area
+    for k in conds:
+        for c in conds[k]:
+            for kk in conds:
+                if k != kk:
+                    create_cond_with_same_area_if_none(conds[kk], c)
+
+    for k in conds:
+        pre_run_control(model, conds[k])
+
+    if "positive" in conds:
+        positive = conds["positive"]
+        for k in conds:
+            if k != "positive":
+                apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), conds[k], 'control', lambda cond_cnets, x: cond_cnets[x])
+                apply_empty_x_to_equal_area(positive, conds[k], 'gligen', lambda cond_cnets, x: cond_cnets[x])
+
+    return conds
+
+class CFGGuider:
+    def __init__(self, model_patcher):
+        self.model_patcher = model_patcher
+        self.model_options = model_patcher.model_options
+        self.original_conds = {}
+        self.cfg = 1.0
+
+    def set_conds(self, positive, negative):
+        self.inner_set_conds({"positive": positive, "negative": negative})
+
+    def set_cfg(self, cfg):
+        self.cfg = cfg
+
+    def inner_set_conds(self, conds):
+        for k in conds:
+            self.original_conds[k] = comfy.sampler_helpers.convert_cond(conds[k])
+
+    def __call__(self, *args, **kwargs):
+        return self.predict_noise(*args, **kwargs)
+
+    def predict_noise(self, x, timestep, model_options={}, seed=None):
+        return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed)
+
+    def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed):
+        if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image.
+            latent_image = self.inner_model.process_latent_in(latent_image)
+
+        self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)
+
+        extra_args = {"model_options": self.model_options, "seed":seed}
+
+        samples = sampler.sample(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
+        return self.inner_model.process_latent_out(samples.to(torch.float32))
+
+    def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
+        if sigmas.shape[-1] == 0:
+            return latent_image
+
+        self.conds = {}
+        for k in self.original_conds:
+            self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))
+
+        self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds)
+        device = self.model_patcher.load_device
+
+        if denoise_mask is not None:
+            denoise_mask = comfy.sampler_helpers.prepare_mask(denoise_mask, noise.shape, device)
+
+        noise = noise.to(device)
+        latent_image = latent_image.to(device)
+        sigmas = sigmas.to(device)
+
+        output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
+
+        comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
+        del self.inner_model
+        del self.conds
+        del self.loaded_models
+        return output
+
+def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
+    cfg_guider = CFGGuider(model)
+    cfg_guider.set_conds(positive, negative)
+    cfg_guider.set_cfg(cfg)
+    return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"]
SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]
-def calculate_sigmas_scheduler(model, scheduler_name, steps):
+def calculate_sigmas(model_sampling, scheduler_name, steps):
    if scheduler_name == "karras":
-        sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
+        sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
    elif scheduler_name == "exponential":
-        sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
+        sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
    elif scheduler_name == "normal":
-        sigmas = normal_scheduler(model, steps)
+        sigmas = normal_scheduler(model_sampling, steps)
    elif scheduler_name == "simple":
-        sigmas = simple_scheduler(model, steps)
+        sigmas = simple_scheduler(model_sampling, steps)
    elif scheduler_name == "ddim_uniform":
-        sigmas = ddim_scheduler(model, steps)
+        sigmas = ddim_scheduler(model_sampling, steps)
    elif scheduler_name == "sgm_uniform":
-        sigmas = normal_scheduler(model, steps, sgm=True)
+        sigmas = normal_scheduler(model_sampling, steps, sgm=True)
    else:
        logging.error("error invalid scheduler {}".format(scheduler_name))
    return sigmas
@@ -666,7 +714,7 @@ class KSampler:
            steps += 1
            discard_penultimate_sigma = True

-        sigmas = calculate_sigmas_scheduler(self.model, self.scheduler, steps)
+        sigmas = calculate_sigmas(self.model.get_model_object("model_sampling"), self.scheduler, steps)

        if discard_penultimate_sigma:
            sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
@@ -677,9 +725,12 @@ class KSampler:
        if denoise is None or denoise > 0.9999:
            self.sigmas = self.calculate_sigmas(steps).to(self.device)
        else:
-            new_steps = int(steps/denoise)
-            sigmas = self.calculate_sigmas(new_steps).to(self.device)
-            self.sigmas = sigmas[-(steps + 1):]
+            if denoise <= 0.0:
+                self.sigmas = torch.FloatTensor([])
+            else:
+                new_steps = int(steps/denoise)
+                sigmas = self.calculate_sigmas(new_steps).to(self.device)
+                self.sigmas = sigmas[-(steps + 1):]

    def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
        if sigmas is None:

5
comfy/sd.py

@@ -600,7 +600,7 @@ def load_unet(unet_path):
        raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
    return model

-def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None):
+def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None, extra_keys={}):
    clip_sd = None
    load_models = [model]
    if clip is not None:
@@ -610,4 +610,7 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
    model_management.load_models_gpu(load_models)
    clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
    sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), clip_vision_sd)
+    for k in extra_keys:
+        sd[k] = extra_keys[k]
+
    comfy.utils.save_torch_file(sd, output_path, metadata=metadata)

83
comfy/supported_models.py

@@ -45,6 +45,11 @@ class SD15(supported_models_base.BASE):
        return state_dict

    def process_clip_state_dict_for_saving(self, state_dict):
+        pop_keys = ["clip_l.transformer.text_projection.weight", "clip_l.logit_scale"]
+        for p in pop_keys:
+            if p in state_dict:
+                state_dict.pop(p)
+
        replace_prefix = {"clip_l.": "cond_stage_model."}
        return utils.state_dict_prefix_replace(state_dict, replace_prefix)
@@ -65,8 +70,8 @@ class SD20(supported_models_base.BASE):
    def model_type(self, state_dict, prefix=""):
        if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction
            k = "{}output_blocks.11.1.transformer_blocks.0.norm1.bias".format(prefix)
-            out = state_dict[k]
-            if torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
+            out = state_dict.get(k, None)
+            if out is not None and torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
                return model_base.ModelType.V_PREDICTION
        return model_base.ModelType.EPS
@@ -169,6 +174,11 @@ class SDXL(supported_models_base.BASE):
            self.sampling_settings["sigma_max"] = 80.0
            self.sampling_settings["sigma_min"] = 0.002
            return model_base.ModelType.EDM
+        elif "edm_vpred.sigma_max" in state_dict:
+            self.sampling_settings["sigma_max"] = float(state_dict["edm_vpred.sigma_max"].item())
+            if "edm_vpred.sigma_min" in state_dict:
+                self.sampling_settings["sigma_min"] = float(state_dict["edm_vpred.sigma_min"].item())
+            return model_base.ModelType.V_PREDICTION_EDM
        elif "v_pred" in state_dict:
            return model_base.ModelType.V_PREDICTION
        else:
@@ -279,6 +289,41 @@ class SVD_img2vid(supported_models_base.BASE):
    def clip_target(self):
        return None
class SV3D_u(SVD_img2vid):
    unet_config = {
        "model_channels": 320,
        "in_channels": 8,
        "use_linear_in_transformer": True,
        "transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0],
        "context_dim": 1024,
        "adm_in_channels": 256,
        "use_temporal_attention": True,
        "use_temporal_resblock": True
    }

    vae_key_prefix = ["conditioner.embedders.1.encoder."]

    def get_model(self, state_dict, prefix="", device=None):
        out = model_base.SV3D_u(self, device=device)
        return out

class SV3D_p(SV3D_u):
    unet_config = {
        "model_channels": 320,
        "in_channels": 8,
        "use_linear_in_transformer": True,
        "transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0],
        "context_dim": 1024,
        "adm_in_channels": 1280,
        "use_temporal_attention": True,
        "use_temporal_resblock": True
    }

    def get_model(self, state_dict, prefix="", device=None):
        out = model_base.SV3D_p(self, device=device)
        return out
class Stable_Zero123(supported_models_base.BASE):
    unet_config = {
        "context_dim": 768,
@@ -294,6 +339,11 @@ class Stable_Zero123(supported_models_base.BASE):
        "num_head_channels": -1,
    }

+    required_keys = {
+        "cc_projection.weight": None,
+        "cc_projection.bias": None,
+    }
+
    clip_vision_prefix = "cond_stage_model.model.visual."

    latent_format = latent_formats.SD15
@@ -399,6 +449,33 @@ class Stable_Cascade_B(Stable_Cascade_C):
        out = model_base.StableCascade_B(self, device=device)
        return out
class SD15_instructpix2pix(SD15):
    unet_config = {
        "context_dim": 768,
        "model_channels": 320,
        "use_linear_in_transformer": False,
        "adm_in_channels": None,
        "use_temporal_attention": False,
        "in_channels": 8,
    }

    def get_model(self, state_dict, prefix="", device=None):
        return model_base.SD15_instructpix2pix(self, device=device)

class SDXL_instructpix2pix(SDXL):
    unet_config = {
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "transformer_depth": [0, 0, 2, 2, 10, 10],
        "context_dim": 2048,
        "adm_in_channels": 2816,
        "use_temporal_attention": False,
        "in_channels": 8,
    }

    def get_model(self, state_dict, prefix="", device=None):
        return model_base.SDXL_instructpix2pix(self, model_type=self.model_type(state_dict, prefix), device=device)
-models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B]
+models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p]
models += [SVD_img2vid]

11
comfy/supported_models_base.py

@@ -16,6 +16,8 @@ class BASE:
        "num_head_channels": 64,
    }

+    required_keys = {}
+
    clip_prefix = []
    clip_vision_prefix = None
    noise_aug_config = None
@@ -28,10 +30,14 @@ class BASE:
    manual_cast_dtype = None

    @classmethod
-    def matches(s, unet_config):
+    def matches(s, unet_config, state_dict=None):
        for k in s.unet_config:
            if k not in unet_config or s.unet_config[k] != unet_config[k]:
                return False
+        if state_dict is not None:
+            for k in s.required_keys:
+                if k not in state_dict:
+                    return False
        return True

    def model_type(self, state_dict, prefix=""):
@@ -41,7 +47,8 @@ class BASE:
        return self.unet_config["in_channels"] > 4

    def __init__(self, unet_config):
-        self.unet_config = unet_config
+        self.unet_config = unet_config.copy()
+        self.sampling_settings = self.sampling_settings.copy()
        self.latent_format = self.latent_format()
        for x in self.unet_extra_config:
            self.unet_config[x] = self.unet_extra_config[x]

309
comfy_extras/nodes_custom_sampler.py

@@ -4,6 +4,7 @@ from comfy.k_diffusion import sampling as k_diffusion_sampling
import latent_preview
import torch
import comfy.utils
+import node_helpers

class BasicScheduler:
@@ -24,10 +25,11 @@ class BasicScheduler:
    def get_sigmas(self, model, scheduler, steps, denoise):
        total_steps = steps
        if denoise < 1.0:
+            if denoise <= 0.0:
+                return (torch.FloatTensor([]),)
            total_steps = int(steps/denoise)

-        comfy.model_management.load_models_gpu([model])
-        sigmas = comfy.samplers.calculate_sigmas_scheduler(model.model, scheduler, total_steps).cpu()
+        sigmas = comfy.samplers.calculate_sigmas(model.get_model_object("model_sampling"), scheduler, total_steps).cpu()
        sigmas = sigmas[-(steps + 1):]
        return (sigmas, )
@@ -160,6 +162,9 @@ class FlipSigmas:
    FUNCTION = "get_sigmas"

    def get_sigmas(self, sigmas):
+        if len(sigmas) == 0:
+            return (sigmas,)
+
        sigmas = sigmas.flip(0)
        if sigmas[0] == 0:
            sigmas[0] = 0.0001
@@ -181,6 +186,28 @@ class KSamplerSelect:
        sampler = comfy.samplers.sampler_object(sampler_name)
        return (sampler, )
class SamplerDPMPP_3M_SDE:
    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
                     "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
                     "noise_device": (['gpu', 'cpu'], ),
                    }
               }
    RETURN_TYPES = ("SAMPLER",)
    CATEGORY = "sampling/custom_sampling/samplers"
    FUNCTION = "get_sampler"

    def get_sampler(self, eta, s_noise, noise_device):
        if noise_device == 'cpu':
            sampler_name = "dpmpp_3m_sde"
        else:
            sampler_name = "dpmpp_3m_sde_gpu"
        sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise})
        return (sampler, )
class SamplerDPMPP_2M_SDE:
    @classmethod
    def INPUT_TYPES(s):
@@ -245,6 +272,67 @@ class SamplerEulerAncestral:
        sampler = comfy.samplers.ksampler("euler_ancestral", {"eta": eta, "s_noise": s_noise})
        return (sampler, )
class SamplerLMS:
    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {"order": ("INT", {"default": 4, "min": 1, "max": 100}),
                    }
               }
    RETURN_TYPES = ("SAMPLER",)
    CATEGORY = "sampling/custom_sampling/samplers"
    FUNCTION = "get_sampler"

    def get_sampler(self, order):
        sampler = comfy.samplers.ksampler("lms", {"order": order})
        return (sampler, )

class SamplerDPMAdaptative:
    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {"order": ("INT", {"default": 3, "min": 2, "max": 3}),
                     "rtol": ("FLOAT", {"default": 0.05, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
                     "atol": ("FLOAT", {"default": 0.0078, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
                     "h_init": ("FLOAT", {"default": 0.05, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
                     "pcoeff": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
                     "icoeff": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
                     "dcoeff": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
                     "accept_safety": ("FLOAT", {"default": 0.81, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
                     "eta": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
                     "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
                    }
               }
    RETURN_TYPES = ("SAMPLER",)
    CATEGORY = "sampling/custom_sampling/samplers"
    FUNCTION = "get_sampler"

    def get_sampler(self, order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise):
        sampler = comfy.samplers.ksampler("dpm_adaptive", {"order": order, "rtol": rtol, "atol": atol, "h_init": h_init, "pcoeff": pcoeff,
                                                           "icoeff": icoeff, "dcoeff": dcoeff, "accept_safety": accept_safety, "eta": eta,
                                                           "s_noise":s_noise })
        return (sampler, )

class Noise_EmptyNoise:
    def __init__(self):
        self.seed = 0

    def generate_noise(self, input_latent):
        latent_image = input_latent["samples"]
        return torch.zeros(latent_image.shape, dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")

class Noise_RandomNoise:
    def __init__(self, seed):
        self.seed = seed

    def generate_noise(self, input_latent):
        latent_image = input_latent["samples"]
        batch_inds = input_latent["batch_index"] if "batch_index" in input_latent else None
        return comfy.sample.prepare_noise(latent_image, self.seed, batch_inds)
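Noise_EmptyNoise and Noise_RandomNoise define the whole interface a NOISE object needs: a seed attribute plus a generate_noise(input_latent) method returning a CPU tensor shaped like input_latent["samples"]. A hypothetical third implementation, only to show the shape of the protocol:

class Noise_ScaledRandomNoise:
    # Hypothetical NOISE object: seeded gaussian noise attenuated by a fixed factor.
    def __init__(self, seed, scale=0.5):
        self.seed = seed
        self.scale = scale

    def generate_noise(self, input_latent):
        latent_image = input_latent["samples"]
        batch_inds = input_latent.get("batch_index", None)
        return comfy.sample.prepare_noise(latent_image, self.seed, batch_inds) * self.scale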
class SamplerCustom:
    @classmethod
    def INPUT_TYPES(s):
@@ -272,10 +360,9 @@ class SamplerCustom:
        latent = latent_image
        latent_image = latent["samples"]
        if not add_noise:
-            noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
+            noise = Noise_EmptyNoise().generate_noise(latent)
        else:
-            batch_inds = latent["batch_index"] if "batch_index" in latent else None
-            noise = comfy.sample.prepare_noise(latent_image, noise_seed, batch_inds)
+            noise = Noise_RandomNoise(noise_seed).generate_noise(latent)

        noise_mask = None
        if "noise_mask" in latent:
@@ -296,6 +383,207 @@ class SamplerCustom:
            out_denoised = out
        return (out, out_denoised)
class Guider_Basic(comfy.samplers.CFGGuider):
    def set_conds(self, positive):
        self.inner_set_conds({"positive": positive})

class BasicGuider:
    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {"model": ("MODEL",),
                     "conditioning": ("CONDITIONING", ),
                    }
               }
    RETURN_TYPES = ("GUIDER",)
    FUNCTION = "get_guider"
    CATEGORY = "sampling/custom_sampling/guiders"

    def get_guider(self, model, conditioning):
        guider = Guider_Basic(model)
        guider.set_conds(conditioning)
        return (guider,)

class CFGGuider:
    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {"model": ("MODEL",),
                     "positive": ("CONDITIONING", ),
                     "negative": ("CONDITIONING", ),
                     "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
                    }
               }
    RETURN_TYPES = ("GUIDER",)
    FUNCTION = "get_guider"
    CATEGORY = "sampling/custom_sampling/guiders"

    def get_guider(self, model, positive, negative, cfg):
        guider = comfy.samplers.CFGGuider(model)
        guider.set_conds(positive, negative)
        guider.set_cfg(cfg)
        return (guider,)

class Guider_DualCFG(comfy.samplers.CFGGuider):
    def set_cfg(self, cfg1, cfg2):
        self.cfg1 = cfg1
        self.cfg2 = cfg2

    def set_conds(self, positive, middle, negative):
        middle = node_helpers.conditioning_set_values(middle, {"prompt_type": "negative"})
        self.inner_set_conds({"positive": positive, "middle": middle, "negative": negative})

    def predict_noise(self, x, timestep, model_options={}, seed=None):
        negative_cond = self.conds.get("negative", None)
        middle_cond = self.conds.get("middle", None)

        out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, middle_cond, self.conds.get("positive", None)], x, timestep, model_options)
        return comfy.samplers.cfg_function(self.inner_model, out[1], out[0], self.cfg2, x, timestep, model_options=model_options, cond=middle_cond, uncond=negative_cond) + (out[2] - out[1]) * self.cfg1
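When no sampler_cfg_function is installed, Guider_DualCFG's predict_noise reduces to negative + cfg2 * (middle - negative) + cfg1 * (positive - middle): a standard CFG step pulling the middle prediction away from the negative one, plus an independently weighted pull from middle toward positive.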
class DualCFGGuider:
    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {"model": ("MODEL",),
                     "cond1": ("CONDITIONING", ),
                     "cond2": ("CONDITIONING", ),
                     "negative": ("CONDITIONING", ),
                     "cfg_conds": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
                     "cfg_cond2_negative": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
                    }
               }
    RETURN_TYPES = ("GUIDER",)
    FUNCTION = "get_guider"
    CATEGORY = "sampling/custom_sampling/guiders"

    def get_guider(self, model, cond1, cond2, negative, cfg_conds, cfg_cond2_negative):
        guider = Guider_DualCFG(model)
        guider.set_conds(cond1, cond2, negative)
        guider.set_cfg(cfg_conds, cfg_cond2_negative)
        return (guider,)

class DisableNoise:
    @classmethod
    def INPUT_TYPES(s):
        return {"required":{
                    }
               }
    RETURN_TYPES = ("NOISE",)
    FUNCTION = "get_noise"
    CATEGORY = "sampling/custom_sampling/noise"

    def get_noise(self):
        return (Noise_EmptyNoise(),)

class RandomNoise(DisableNoise):
    @classmethod
    def INPUT_TYPES(s):
        return {"required":{
                    "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
                    }
               }

    def get_noise(self, noise_seed):
        return (Noise_RandomNoise(noise_seed),)

class SamplerCustomAdvanced:
    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {"noise": ("NOISE", ),
                     "guider": ("GUIDER", ),
                     "sampler": ("SAMPLER", ),
                     "sigmas": ("SIGMAS", ),
                     "latent_image": ("LATENT", ),
                    }
               }
    RETURN_TYPES = ("LATENT","LATENT")
    RETURN_NAMES = ("output", "denoised_output")
    FUNCTION = "sample"
    CATEGORY = "sampling/custom_sampling"

    def sample(self, noise, guider, sampler, sigmas, latent_image):
        latent = latent_image
        latent_image = latent["samples"]
        noise_mask = None
        if "noise_mask" in latent:
            noise_mask = latent["noise_mask"]

        x0_output = {}
        callback = latent_preview.prepare_callback(guider.model_patcher, sigmas.shape[-1] - 1, x0_output)

        disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
        samples = guider.sample(noise.generate_noise(latent), latent_image, sampler, sigmas, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise.seed)
        samples = samples.to(comfy.model_management.intermediate_device())

        out = latent.copy()
        out["samples"] = samples
        if "x0" in x0_output:
            out_denoised = latent.copy()
            out_denoised["samples"] = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu())
        else:
            out_denoised = out
        return (out, out_denoised)
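SamplerCustomAdvanced is a thin shell over the objects introduced above, so the same pipeline can be driven straight from Python. An illustrative wiring, with model_patcher, positive, negative and latent as stand-ins for objects built elsewhere:

# Illustrative sketch: guider + sampler + sigmas + noise without the node graph.
guider = comfy.samplers.CFGGuider(model_patcher)
guider.set_conds(positive, negative)
guider.set_cfg(8.0)

sampler = comfy.samplers.sampler_object("euler")
model_sampling = model_patcher.get_model_object("model_sampling")
sigmas = comfy.samplers.calculate_sigmas(model_sampling, "normal", 20)

noise = Noise_RandomNoise(42).generate_noise({"samples": latent})
samples = guider.sample(noise, latent, sampler, sigmas, seed=42)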
class AddNoise:
    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {"model": ("MODEL",),
                     "noise": ("NOISE", ),
                     "sigmas": ("SIGMAS", ),
                     "latent_image": ("LATENT", ),
                    }
               }
    RETURN_TYPES = ("LATENT",)
    FUNCTION = "add_noise"
    CATEGORY = "_for_testing/custom_sampling/noise"

    def add_noise(self, model, noise, sigmas, latent_image):
        if len(sigmas) == 0:
            return latent_image

        latent = latent_image
        latent_image = latent["samples"]

        noisy = noise.generate_noise(latent)

        model_sampling = model.get_model_object("model_sampling")
        process_latent_out = model.get_model_object("process_latent_out")
        process_latent_in = model.get_model_object("process_latent_in")

        if len(sigmas) > 1:
            scale = torch.abs(sigmas[0] - sigmas[-1])
        else:
            scale = sigmas[0]

        if torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image.
            latent_image = process_latent_in(latent_image)
        noisy = model_sampling.noise_scaling(scale, noisy, latent_image)
        noisy = process_latent_out(noisy)
        noisy = torch.nan_to_num(noisy, nan=0.0, posinf=0.0, neginf=0.0)

        out = latent.copy()
        out["samples"] = noisy
        return (out,)
NODE_CLASS_MAPPINGS = {
    "SamplerCustom": SamplerCustom,
    "BasicScheduler": BasicScheduler,
@@ -306,8 +594,19 @@ NODE_CLASS_MAPPINGS = {
    "SDTurboScheduler": SDTurboScheduler,
    "KSamplerSelect": KSamplerSelect,
    "SamplerEulerAncestral": SamplerEulerAncestral,
+    "SamplerLMS": SamplerLMS,
+    "SamplerDPMPP_3M_SDE": SamplerDPMPP_3M_SDE,
    "SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE,
    "SamplerDPMPP_SDE": SamplerDPMPP_SDE,
+    "SamplerDPMAdaptative": SamplerDPMAdaptative,
    "SplitSigmas": SplitSigmas,
    "FlipSigmas": FlipSigmas,
+    "CFGGuider": CFGGuider,
+    "DualCFGGuider": DualCFGGuider,
+    "BasicGuider": BasicGuider,
+    "RandomNoise": RandomNoise,
+    "DisableNoise": DisableNoise,
+    "AddNoise": AddNoise,
+    "SamplerCustomAdvanced": SamplerCustomAdvanced,
}

6
comfy_extras/nodes_images.py

@@ -37,7 +37,7 @@ class RepeatImageBatch:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "image": ("IMAGE",),
-                              "amount": ("INT", {"default": 1, "min": 1, "max": 64}),
+                              "amount": ("INT", {"default": 1, "min": 1, "max": 4096}),
                             }}
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "repeat"
@@ -52,8 +52,8 @@ class ImageFromBatch:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "image": ("IMAGE",),
-                              "batch_index": ("INT", {"default": 0, "min": 0, "max": 63}),
-                              "length": ("INT", {"default": 1, "min": 1, "max": 64}),
+                              "batch_index": ("INT", {"default": 0, "min": 0, "max": 4095}),
+                              "length": ("INT", {"default": 1, "min": 1, "max": 4096}),
                             }}
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "frombatch"

45
comfy_extras/nodes_ip2p.py

@@ -0,0 +1,45 @@
import torch

class InstructPixToPixConditioning:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"positive": ("CONDITIONING", ),
                             "negative": ("CONDITIONING", ),
                             "vae": ("VAE", ),
                             "pixels": ("IMAGE", ),
                            }}

    RETURN_TYPES = ("CONDITIONING","CONDITIONING","LATENT")
    RETURN_NAMES = ("positive", "negative", "latent")
    FUNCTION = "encode"

    CATEGORY = "conditioning/instructpix2pix"

    def encode(self, positive, negative, pixels, vae):
        x = (pixels.shape[1] // 8) * 8
        y = (pixels.shape[2] // 8) * 8

        if pixels.shape[1] != x or pixels.shape[2] != y:
            x_offset = (pixels.shape[1] % 8) // 2
            y_offset = (pixels.shape[2] % 8) // 2
            pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:]

        concat_latent = vae.encode(pixels)

        out_latent = {}
        out_latent["samples"] = torch.zeros_like(concat_latent)

        out = []
        for conditioning in [positive, negative]:
            c = []
            for t in conditioning:
                d = t[1].copy()
                d["concat_latent_image"] = concat_latent
                n = [t[0], d]
                c.append(n)
            out.append(c)
        return (out[0], out[1], out_latent)

NODE_CLASS_MAPPINGS = {
    "InstructPixToPixConditioning": InstructPixToPixConditioning,
}
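The crop in encode keeps only whole multiples of 8 pixels and centers the cut: for a 513x770 input, x and y come out as 512 and 768 with offsets 0 and 1, so one pixel is trimmed from each side of the 770-pixel axis before the VAE encode.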

11
comfy_extras/nodes_model_merging.py

@@ -2,7 +2,9 @@ import comfy.sd
import comfy.utils
import comfy.model_base
import comfy.model_management
+import comfy.model_sampling
+import torch

import folder_paths
import json
import os
@@ -189,6 +191,13 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi
    # "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h",
    # "v2-inpainting"

+    extra_keys = {}
+    model_sampling = model.get_model_object("model_sampling")
+    if isinstance(model_sampling, comfy.model_sampling.ModelSamplingContinuousEDM):
+        if isinstance(model_sampling, comfy.model_sampling.V_PREDICTION):
+            extra_keys["edm_vpred.sigma_max"] = torch.tensor(model_sampling.sigma_max).float()
+            extra_keys["edm_vpred.sigma_min"] = torch.tensor(model_sampling.sigma_min).float()
+
    if model.model.model_type == comfy.model_base.ModelType.EPS:
        metadata["modelspec.predict_key"] = "epsilon"
    elif model.model.model_type == comfy.model_base.ModelType.V_PREDICTION:
@@ -203,7 +212,7 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi
    output_checkpoint = f"{filename}_{counter:05}_.safetensors"
    output_checkpoint = os.path.join(full_output_folder, output_checkpoint)

-    comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata)
+    comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata, extra_keys=extra_keys)

class CheckpointSave:
    def __init__(self):

60
comfy_extras/nodes_model_merging_model_specific.py

@@ -0,0 +1,60 @@
import comfy_extras.nodes_model_merging

class ModelMergeSD1(comfy_extras.nodes_model_merging.ModelMergeBlocks):
    CATEGORY = "advanced/model_merging/model_specific"

    @classmethod
    def INPUT_TYPES(s):
        arg_dict = { "model1": ("MODEL",),
                     "model2": ("MODEL",)}

        argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})

        arg_dict["time_embed."] = argument
        arg_dict["label_emb."] = argument

        for i in range(12):
            arg_dict["input_blocks.{}.".format(i)] = argument

        for i in range(3):
            arg_dict["middle_block.{}.".format(i)] = argument

        for i in range(12):
            arg_dict["output_blocks.{}.".format(i)] = argument

        arg_dict["out."] = argument

        return {"required": arg_dict}

class ModelMergeSDXL(comfy_extras.nodes_model_merging.ModelMergeBlocks):
    CATEGORY = "advanced/model_merging/model_specific"

    @classmethod
    def INPUT_TYPES(s):
        arg_dict = { "model1": ("MODEL",),
                     "model2": ("MODEL",)}

        argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})

        arg_dict["time_embed."] = argument
        arg_dict["label_emb."] = argument

        for i in range(9):
            arg_dict["input_blocks.{}".format(i)] = argument

        for i in range(3):
            arg_dict["middle_block.{}".format(i)] = argument

        for i in range(9):
            arg_dict["output_blocks.{}".format(i)] = argument

        arg_dict["out."] = argument

        return {"required": arg_dict}

NODE_CLASS_MAPPINGS = {
    "ModelMergeSD1": ModelMergeSD1,
    "ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
    "ModelMergeSDXL": ModelMergeSDXL,
}

9
comfy_extras/nodes_perpneg.py

@@ -1,16 +1,17 @@
import torch
import comfy.model_management
-import comfy.sample
+import comfy.sampler_helpers
import comfy.samplers
import comfy.utils

+#TODO: This node should be removed and replaced with one that uses the new Guider/SamplerCustomAdvanced.
class PerpNeg:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"model": ("MODEL", ),
                             "empty_conditioning": ("CONDITIONING", ),
-                             "neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0}),
+                             "neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),
                            }}
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
@@ -19,7 +20,7 @@ class PerpNeg:
    def patch(self, model, empty_conditioning, neg_scale):
        m = model.clone()
-        nocond = comfy.sample.convert_cond(empty_conditioning)
+        nocond = comfy.sampler_helpers.convert_cond(empty_conditioning)

        def cfg_function(args):
            model = args["model"]
@@ -31,7 +32,7 @@ class PerpNeg:
            model_options = args["model_options"]

            nocond_processed = comfy.samplers.encode_model_conds(model.extra_conds, nocond, x, x.device, "negative")
-            (noise_pred_nocond, _) = comfy.samplers.calc_cond_uncond_batch(model, nocond_processed, None, x, sigma, model_options)
+            (noise_pred_nocond,) = comfy.samplers.calc_cond_batch(model, [nocond_processed], x, sigma, model_options)

            pos = noise_pred_pos - noise_pred_nocond
            neg = noise_pred_neg - noise_pred_nocond

4
comfy_extras/nodes_post_processing.py

@@ -204,13 +204,13 @@ class Sharpen:
                    "default": 1.0,
                    "min": 0.1,
                    "max": 10.0,
-                    "step": 0.1
+                    "step": 0.01
                }),
                "alpha": ("FLOAT", {
                    "default": 1.0,
                    "min": 0.0,
                    "max": 5.0,
-                    "step": 0.1
+                    "step": 0.01
                }),
            },
        }

2
comfy_extras/nodes_sag.py

@@ -150,7 +150,7 @@ class SelfAttentionGuidance:
            degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold)
            degraded_noised = degraded + x - uncond_pred
            # call into the UNet
-            (sag, _) = comfy.samplers.calc_cond_uncond_batch(model, uncond, None, degraded_noised, sigma, model_options)
+            (sag,) = comfy.samplers.calc_cond_batch(model, [uncond], degraded_noised, sigma, model_options)
            return cfg_result + (degraded - sag) * sag_scale

        m.set_model_sampler_post_cfg_function(post_cfg_function, disable_cfg1_optimization=True)

53
comfy_extras/nodes_stable3d.py

@@ -29,8 +29,8 @@ class StableZero123_Conditioning:
                              "width": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
                              "height": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
                              "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
-                              "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
-                              "azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
+                              "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
+                              "azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
                             }}
    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
    RETURN_NAMES = ("positive", "negative", "latent")
@@ -62,10 +62,10 @@ class StableZero123_Conditioning_Batched:
                              "width": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
                              "height": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
                              "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
-                              "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
-                              "azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
-                              "elevation_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
-                              "azimuth_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
+                              "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
+                              "azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
+                              "elevation_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
+                              "azimuth_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
                             }}
    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
    RETURN_NAMES = ("positive", "negative", "latent")
@@ -95,8 +95,49 @@ class StableZero123_Conditioning_Batched:
        latent = torch.zeros([batch_size, 4, height // 8, width // 8])
        return (positive, negative, {"samples":latent, "batch_index": [0] * batch_size})
class SV3D_Conditioning:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "clip_vision": ("CLIP_VISION",),
                              "init_image": ("IMAGE",),
                              "vae": ("VAE",),
                              "width": ("INT", {"default": 576, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
                              "height": ("INT", {"default": 576, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
                              "video_frames": ("INT", {"default": 21, "min": 1, "max": 4096}),
                              "elevation": ("FLOAT", {"default": 0.0, "min": -90.0, "max": 90.0, "step": 0.1, "round": False}),
                             }}
    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
    RETURN_NAMES = ("positive", "negative", "latent")

    FUNCTION = "encode"

    CATEGORY = "conditioning/3d_models"

    def encode(self, clip_vision, init_image, vae, width, height, video_frames, elevation):
        output = clip_vision.encode_image(init_image)
        pooled = output.image_embeds.unsqueeze(0)
        pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
        encode_pixels = pixels[:,:,:,:3]
        t = vae.encode(encode_pixels)

        azimuth = 0
        azimuth_increment = 360 / (max(video_frames, 2) - 1)

        elevations = []
        azimuths = []
        for i in range(video_frames):
            elevations.append(elevation)
            azimuths.append(azimuth)
            azimuth += azimuth_increment

        positive = [[pooled, {"concat_latent_image": t, "elevation": elevations, "azimuth": azimuths}]]
        negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t), "elevation": elevations, "azimuth": azimuths}]]
        latent = torch.zeros([video_frames, 4, height // 8, width // 8])
        return (positive, negative, {"samples":latent})
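With the default 21 frames, azimuth_increment is 360 / 20 = 18, so the conditioning sweeps azimuths 0, 18, ..., 360 degrees at constant elevation: one full orbit whose last frame lines up with the starting view.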
NODE_CLASS_MAPPINGS = {
    "StableZero123_Conditioning": StableZero123_Conditioning,
    "StableZero123_Conditioning_Batched": StableZero123_Conditioning_Batched,
+    "SV3D_Conditioning": SV3D_Conditioning,
}

2
comfy_extras/nodes_stable_cascade.py

@@ -74,7 +74,7 @@ class StableCascade_StageC_VAEEncode:
         s = comfy.utils.common_upscale(image.movedim(-1,1), out_width, out_height, "bicubic", "center").movedim(1,-1)
         c_latent = vae.encode(s[:,:,:,:3])
-        b_latent = torch.zeros([c_latent.shape[0], 4, height // 4, width // 4])
+        b_latent = torch.zeros([c_latent.shape[0], 4, (height // 8) * 2, (width // 8) * 2])
         return ({
             "samples": c_latent,
         }, {
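
One subtlety worth spelling out: for sizes that are not a multiple of 8, (height // 8) * 2 and height // 4 round differently, and the new form presumably keeps the stage B latent in step with the // 8 rounding used for other latent dimensions. A throwaway check (numbers invented):

    height = 1044
    print(height // 4)        # 261
    print((height // 8) * 2)  # 260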

28
comfy_extras/nodes_video_model.py

@@ -79,6 +79,33 @@ class VideoLinearCFGGuidance:
         m.set_model_sampler_cfg_function(linear_cfg)
         return (m, )
+class VideoTriangleCFGGuidance:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "model": ("MODEL",),
+                              "min_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}),
+                             }}
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "patch"
+
+    CATEGORY = "sampling/video_models"
+
+    def patch(self, model, min_cfg):
+        def linear_cfg(args):
+            cond = args["cond"]
+            uncond = args["uncond"]
+            cond_scale = args["cond_scale"]
+            period = 1.0
+            values = torch.linspace(0, 1, cond.shape[0], device=cond.device)
+            values = 2 * (values / period - torch.floor(values / period + 0.5)).abs()
+            scale = (values * (cond_scale - min_cfg) + min_cfg).reshape((cond.shape[0], 1, 1, 1))
+
+            return uncond + scale * (cond - uncond)
+
+        m = model.clone()
+        m.set_model_sampler_cfg_function(linear_cfg)
+        return (m, )
 class ImageOnlyCheckpointSave(comfy_extras.nodes_model_merging.CheckpointSave):
     CATEGORY = "_for_testing"

@@ -98,6 +125,7 @@ NODE_CLASS_MAPPINGS = {
     "ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader,
     "SVD_img2vid_Conditioning": SVD_img2vid_Conditioning,
     "VideoLinearCFGGuidance": VideoLinearCFGGuidance,
+    "VideoTriangleCFGGuidance": VideoTriangleCFGGuidance,
     "ImageOnlyCheckpointSave": ImageOnlyCheckpointSave,
 }
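
The new node's schedule is a triangle wave over the batch: the per-frame CFG starts at min_cfg, peaks at cond_scale on the middle frame, and falls back to min_cfg, in contrast to VideoLinearCFGGuidance's one-way ramp. A small sketch of the scale computation with made-up numbers:

    import torch

    frames, cond_scale, min_cfg = 5, 7.0, 1.0
    values = torch.linspace(0, 1, frames)
    values = 2 * (values - torch.floor(values + 0.5)).abs()  # triangle: 0 -> 1 -> 0
    scale = values * (cond_scale - min_cfg) + min_cfg
    print(scale.tolist())  # [1.0, 4.0, 7.0, 4.0, 1.0]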

2
cuda_malloc.py

@@ -47,7 +47,7 @@ blacklist = {"GeForce GTX TITAN X", "GeForce GTX 980", "GeForce GTX 970", "GeFor
     "Quadro K1200", "Quadro K2200", "Quadro M500", "Quadro M520", "Quadro M600", "Quadro M620", "Quadro M1000",
     "Quadro M1200", "Quadro M2000", "Quadro M2200", "Quadro M3000", "Quadro M4000", "Quadro M5000", "Quadro M5500", "Quadro M6000",
     "GeForce MX110", "GeForce MX130", "GeForce 830M", "GeForce 840M", "GeForce GTX 850M", "GeForce GTX 860M",
-    "GeForce GTX 1650", "GeForce GTX 1630"
+    "GeForce GTX 1650", "GeForce GTX 1630", "Tesla M4", "Tesla M6", "Tesla M10", "Tesla M40", "Tesla M60"
 }

 def cuda_malloc_supported():

6
custom_nodes/websocket_image_save.py.disabled → custom_nodes/websocket_image_save.py

@@ -10,10 +10,6 @@ import time
 #binary images on the websocket with a 8 byte header indicating the type
 #of binary message (first 4 bytes) and the image format (next 4 bytes).

-#The reason this node is disabled by default is because there is a small
-#issue when using it with the default ComfyUI web interface: When generating
-#batches only the last image will be shown in the UI.
-
 #Note that no metadata will be put in the images saved with this node.

 class SaveImageWebsocket:

@@ -28,7 +24,7 @@ class SaveImageWebsocket:
     OUTPUT_NODE = True

-    CATEGORY = "image"
+    CATEGORY = "api/image"

     def save_images(self, images):
         pbar = comfy.utils.ProgressBar(images.shape[0])
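
For client authors, a minimal sketch of consuming that framing, assuming both header fields are big-endian unsigned 32-bit integers (the comment above only fixes the byte layout, so the field encoding is an assumption here; the websockets example script later in this diff simply slices off the first 8 bytes):

    import struct

    def parse_binary_frame(frame: bytes):
        # 8-byte header: message type (first 4 bytes), then image format (next 4 bytes).
        message_type, image_format = struct.unpack(">II", frame[:8])
        return message_type, image_format, frame[8:]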

2
execution.py

@@ -368,7 +368,7 @@ class PromptExecutor:
                 d = self.outputs_ui.pop(x)
                 del d
-        comfy.model_management.cleanup_models()
+        comfy.model_management.cleanup_models(keep_clone_weights_loaded=True)
         self.add_message("execution_cached",
                         { "nodes": list(current_outputs) , "prompt_id": prompt_id},
                         broadcast=False)

17
folder_paths.py

@@ -1,5 +1,6 @@
 import os
 import time
+import logging

 supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth', '.safetensors'])

@@ -44,7 +45,7 @@ if not os.path.exists(input_directory):
     try:
         os.makedirs(input_directory)
     except:
-        print("Failed to create input directory")
+        logging.error("Failed to create input directory")
 def set_output_directory(output_dir):
     global output_directory

@@ -146,21 +147,23 @@ def recursive_search(directory, excluded_dir_names=None):
     try:
         dirs[directory] = os.path.getmtime(directory)
     except FileNotFoundError:
-        print(f"Warning: Unable to access {directory}. Skipping this path.")
+        logging.warning(f"Warning: Unable to access {directory}. Skipping this path.")

+    logging.debug("recursive file list on directory {}".format(directory))
     for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
         subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
         for file_name in filenames:
             relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
             result.append(relative_path)

         for d in subdirs:
             path = os.path.join(dirpath, d)
             try:
                 dirs[path] = os.path.getmtime(path)
             except FileNotFoundError:
-                print(f"Warning: Unable to access {path}. Skipping this path.")
+                logging.warning(f"Warning: Unable to access {path}. Skipping this path.")
                 continue

+    logging.debug("found {} files".format(len(result)))
     return result, dirs
 def filter_files_extensions(files, extensions):

@@ -248,8 +251,8 @@ def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height
         err = "**** ERROR: Saving image outside the output folder is not allowed." + \
               "\n full_output_folder: " + os.path.abspath(full_output_folder) + \
               "\n         output_dir: " + output_dir + \
               "\n         commonpath: " + os.path.commonpath((output_dir, os.path.abspath(full_output_folder)))
-        print(err)
+        logging.error(err)
         raise Exception(err)
     try:
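
The check that produces this error message is a path-containment test: the resolved save folder is only accepted when the output directory is the common prefix of both paths. A standalone sketch with invented paths:

    import os

    output_dir = "/srv/comfy/output"
    full_output_folder = "/srv/comfy/output/../secrets"
    resolved = os.path.abspath(full_output_folder)  # normalizes away the ".."
    if os.path.commonpath((output_dir, resolved)) != output_dir:
        print("refusing to save outside", output_dir)  # this branch is taken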

1
main.py

@@ -139,6 +139,7 @@ def prompt_worker(q, server):
         if need_gc:
             current_time = time.perf_counter()
             if (current_time - last_gc_collect) > gc_collect_interval:
+                comfy.model_management.cleanup_models()
                 gc.collect()
                 comfy.model_management.soft_empty_cache()
                 last_gc_collect = current_time

10
node_helpers.py

@@ -0,0 +1,10 @@
+def conditioning_set_values(conditioning, values={}):
+    c = []
+    for t in conditioning:
+        n = [t[0], t[1].copy()]
+        for k in values:
+            n[1][k] = values[k]
+        c.append(n)
+
+    return c
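
The helper copies each conditioning entry's options dict before merging in the new keys, so callers get patched copies while the input list stays untouched; the nodes.py changes below all lean on this. A usage sketch (the tensor shape and key values are placeholders):

    import torch
    import node_helpers

    conditioning = [[torch.zeros(1, 77, 768), {"strength": 1.0}]]
    patched = node_helpers.conditioning_set_values(conditioning, {"strength": 0.5})

    print(conditioning[0][1])  # {'strength': 1.0} -- original left as-is
    print(patched[0][1])       # {'strength': 0.5}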

62
nodes.py

@@ -34,6 +34,7 @@ import importlib

 import folder_paths
 import latent_preview
+import node_helpers

 def before_node_execution():
     comfy.model_management.throw_exception_if_processing_interrupted()

@@ -41,7 +42,7 @@ def before_node_execution():
 def interrupt_processing(value=True):
     comfy.model_management.interrupt_current_processing(value)

-MAX_RESOLUTION=8192
+MAX_RESOLUTION=16384

 class CLIPTextEncode:
     @classmethod
@@ -151,13 +152,9 @@ class ConditioningSetArea:
     CATEGORY = "conditioning"

     def append(self, conditioning, width, height, x, y, strength):
-        c = []
-        for t in conditioning:
-            n = [t[0], t[1].copy()]
-            n[1]['area'] = (height // 8, width // 8, y // 8, x // 8)
-            n[1]['strength'] = strength
-            n[1]['set_area_to_bounds'] = False
-            c.append(n)
+        c = node_helpers.conditioning_set_values(conditioning, {"area": (height // 8, width // 8, y // 8, x // 8),
+                                                                "strength": strength,
+                                                                "set_area_to_bounds": False})
         return (c, )

 class ConditioningSetAreaPercentage:
@@ -176,13 +173,9 @@ class ConditioningSetAreaPercentage:
     CATEGORY = "conditioning"

     def append(self, conditioning, width, height, x, y, strength):
-        c = []
-        for t in conditioning:
-            n = [t[0], t[1].copy()]
-            n[1]['area'] = ("percentage", height, width, y, x)
-            n[1]['strength'] = strength
-            n[1]['set_area_to_bounds'] = False
-            c.append(n)
+        c = node_helpers.conditioning_set_values(conditioning, {"area": ("percentage", height, width, y, x),
+                                                                "strength": strength,
+                                                                "set_area_to_bounds": False})
         return (c, )

 class ConditioningSetAreaStrength:
@@ -197,11 +190,7 @@ class ConditioningSetAreaStrength:
     CATEGORY = "conditioning"

     def append(self, conditioning, strength):
-        c = []
-        for t in conditioning:
-            n = [t[0], t[1].copy()]
-            n[1]['strength'] = strength
-            c.append(n)
+        c = node_helpers.conditioning_set_values(conditioning, {"strength": strength})
         return (c, )
@@ -219,19 +208,15 @@ class ConditioningSetMask:
     CATEGORY = "conditioning"

     def append(self, conditioning, mask, set_cond_area, strength):
-        c = []
         set_area_to_bounds = False
         if set_cond_area != "default":
             set_area_to_bounds = True
         if len(mask.shape) < 3:
             mask = mask.unsqueeze(0)
-        for t in conditioning:
-            n = [t[0], t[1].copy()]
-            _, h, w = mask.shape
-            n[1]['mask'] = mask
-            n[1]['set_area_to_bounds'] = set_area_to_bounds
-            n[1]['mask_strength'] = strength
-            c.append(n)
+
+        c = node_helpers.conditioning_set_values(conditioning, {"mask": mask,
+                                                                "set_area_to_bounds": set_area_to_bounds,
+                                                                "mask_strength": strength})
         return (c, )

 class ConditioningZeroOut:
@@ -266,13 +251,8 @@ class ConditioningSetTimestepRange:
     CATEGORY = "advanced/conditioning"

     def set_range(self, conditioning, start, end):
-        c = []
-        for t in conditioning:
-            d = t[1].copy()
-            d['start_percent'] = start
-            d['end_percent'] = end
-            n = [t[0], d]
-            c.append(n)
+        c = node_helpers.conditioning_set_values(conditioning, {"start_percent": start,
+                                                                "end_percent": end})
         return (c, )

 class VAEDecode:
@@ -413,13 +393,8 @@ class InpaintModelConditioning:
         out = []
         for conditioning in [positive, negative]:
-            c = []
-            for t in conditioning:
-                d = t[1].copy()
-                d["concat_latent_image"] = concat_latent
-                d["concat_mask"] = mask
-                n = [t[0], d]
-                c.append(n)
+            c = node_helpers.conditioning_set_values(conditioning, {"concat_latent_image": concat_latent,
+                                                                    "concat_mask": mask})
             out.append(c)

         return (out[0], out[1], out_latent)
@@ -1876,6 +1851,7 @@ def load_custom_node(module_path, ignore=set()):
     sp = os.path.splitext(module_path)
     module_name = sp[0]
     try:
+        logging.debug("Trying to load custom node {}".format(module_path))
         if os.path.isfile(module_path):
             module_spec = importlib.util.spec_from_file_location(module_name, module_path)
             module_dir = os.path.split(module_path)[0]

@@ -1964,6 +1940,8 @@ def init_custom_nodes():
         "nodes_morphology.py",
         "nodes_stable_cascade.py",
         "nodes_differential_diffusion.py",
+        "nodes_ip2p.py",
+        "nodes_model_merging_model_specific.py",
     ]

     import_failed = []
159
script_examples/websockets_api_example_ws_images.py

@@ -0,0 +1,159 @@
+#This is an example that uses the websockets api and the SaveImageWebsocket node to get images directly without
+#them being saved to disk
+
+import websocket #NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
+import uuid
+import json
+import urllib.request
+import urllib.parse
+
+server_address = "127.0.0.1:8188"
+client_id = str(uuid.uuid4())
+
+def queue_prompt(prompt):
+    p = {"prompt": prompt, "client_id": client_id}
+    data = json.dumps(p).encode('utf-8')
+    req = urllib.request.Request("http://{}/prompt".format(server_address), data=data)
+    return json.loads(urllib.request.urlopen(req).read())
+
+def get_image(filename, subfolder, folder_type):
+    data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
+    url_values = urllib.parse.urlencode(data)
+    with urllib.request.urlopen("http://{}/view?{}".format(server_address, url_values)) as response:
+        return response.read()
+
+def get_history(prompt_id):
+    with urllib.request.urlopen("http://{}/history/{}".format(server_address, prompt_id)) as response:
+        return json.loads(response.read())
+
+def get_images(ws, prompt):
+    prompt_id = queue_prompt(prompt)['prompt_id']
+    output_images = {}
+    current_node = ""
+    while True:
+        out = ws.recv()
+        if isinstance(out, str):
+            message = json.loads(out)
+            if message['type'] == 'executing':
+                data = message['data']
+                if data['prompt_id'] == prompt_id:
+                    if data['node'] is None:
+                        break #Execution is done
+                    else:
+                        current_node = data['node']
+        else:
+            if current_node == 'save_image_websocket_node':
+                images_output = output_images.get(current_node, [])
+                images_output.append(out[8:])
+                output_images[current_node] = images_output
+
+    return output_images
+
+prompt_text = """
+{
+    "3": {
+        "class_type": "KSampler",
+        "inputs": {
+            "cfg": 8,
+            "denoise": 1,
+            "latent_image": [
+                "5",
+                0
+            ],
+            "model": [
+                "4",
+                0
+            ],
+            "negative": [
+                "7",
+                0
+            ],
+            "positive": [
+                "6",
+                0
+            ],
+            "sampler_name": "euler",
+            "scheduler": "normal",
+            "seed": 8566257,
+            "steps": 20
+        }
+    },
+    "4": {
+        "class_type": "CheckpointLoaderSimple",
+        "inputs": {
+            "ckpt_name": "v1-5-pruned-emaonly.ckpt"
+        }
+    },
+    "5": {
+        "class_type": "EmptyLatentImage",
+        "inputs": {
+            "batch_size": 1,
+            "height": 512,
+            "width": 512
+        }
+    },
+    "6": {
+        "class_type": "CLIPTextEncode",
+        "inputs": {
+            "clip": [
+                "4",
+                1
+            ],
+            "text": "masterpiece best quality girl"
+        }
+    },
+    "7": {
+        "class_type": "CLIPTextEncode",
+        "inputs": {
+            "clip": [
+                "4",
+                1
+            ],
+            "text": "bad hands"
+        }
+    },
+    "8": {
+        "class_type": "VAEDecode",
+        "inputs": {
+            "samples": [
+                "3",
+                0
+            ],
+            "vae": [
+                "4",
+                2
+            ]
+        }
+    },
+    "save_image_websocket_node": {
+        "class_type": "SaveImageWebsocket",
+        "inputs": {
+            "images": [
+                "8",
+                0
+            ]
+        }
+    }
+}
+"""
+
+prompt = json.loads(prompt_text)
+#set the text prompt for our positive CLIPTextEncode
+prompt["6"]["inputs"]["text"] = "masterpiece best quality man"
+
+#set the seed for our KSampler node
+prompt["3"]["inputs"]["seed"] = 5
+
+ws = websocket.WebSocket()
+ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
+images = get_images(ws, prompt)
+
+#Commented out code to display the output images:
+
+# for node_id in images:
+#     for image_data in images[node_id]:
+#         from PIL import Image
+#         import io
+#         image = Image.open(io.BytesIO(image_data))
+#         image.show()

2
tests-ui/tests/groupNode.test.js

@@ -947,7 +947,7 @@ describe("group node", () => {
			expect(p1.widgets.value.widget.options?.step).toBe(80); // width/height step * 10
			expect(p2.widgets.value.widget.options?.min).toBe(16); // width/height min
-			expect(p2.widgets.value.widget.options?.max).toBe(8192); // width/height max
+			expect(p2.widgets.value.widget.options?.max).toBe(16384); // width/height max
			expect(p2.widgets.value.widget.options?.step).toBe(80); // width/height step * 10
			expect(p1.widgets.value.value).toBe(128);

4
web/extensions/core/colorPalette.js

@@ -20,6 +20,10 @@ const colorPalettes = {
			"MODEL": "#B39DDB", // light lavender-purple
			"STYLE_MODEL": "#C2FFAE", // light green-yellow
			"VAE": "#FF6E6E", // bright red
+			"NOISE": "#B0B0B0", // gray
+			"GUIDER": "#66FFFF", // cyan
+			"SAMPLER": "#ECB4B4", // very soft red
+			"SIGMAS": "#CDFFCD", // soft lime green
			"TAESD": "#DCC274", // cheesecake
		},
	"litegraph_base": {

2
web/lib/litegraph.core.js

@@ -7247,7 +7247,7 @@ LGraphNode.prototype.executeAction = function(action)
        //create links
        for (var i = 0; i < clipboard_info.links.length; ++i) {
            var link_info = clipboard_info.links[i];
-            var origin_node;
+            var origin_node = undefined;
            var origin_node_relative_id = link_info[0];
            if (origin_node_relative_id != null) {
                origin_node = nodes[origin_node_relative_id];

14
web/scripts/pnginfo.js

@@ -170,9 +170,12 @@ export async function importA1111(graph, parameters) {
		const opts = parameters
			.substr(p)
			.split("\n")[1]
-			.split(",")
+			.match(new RegExp("\\s*([^:]+:\\s*([^\"\\{].*?|\".*?\"|\\{.*?\\}))\\s*(,|$)", "g"))
			.reduce((p, n) => {
				const s = n.split(":");
+				if (s[1].endsWith(',')) {
+					s[1] = s[1].substr(0, s[1].length -1);
+				}
				p[s[0].trim().toLowerCase()] = s[1].trim();
				return p;
			}, {});
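
The regex swap makes the A1111 option splitter respect quoted strings and {...} blocks, so values that themselves contain commas or colons (hash lists, embedded JSON) no longer get cut apart by a naive split on ",". A rough Python rendering of the same idea, with an invented sample line:

    import re

    line = 'Steps: 20, Sampler: Euler, Lora hashes: "foo: abc123", ControlNet: {"module": "canny"}'
    pattern = r'\s*([^:]+:\s*([^"{].*?|".*?"|\{.*?\}))\s*(,|$)'
    opts = {}
    for m in re.finditer(pattern, line):
        key, _, value = m.group(1).partition(":")  # split on the first colon only
        opts[key.strip().lower()] = value.strip()

    print(opts["lora hashes"])  # "foo: abc123" -- the colon inside quotes survives
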
@@ -191,6 +194,7 @@ export async function importA1111(graph, parameters) {
		const vaeLoaderNode = LiteGraph.createNode("VAELoader");
		const saveNode = LiteGraph.createNode("SaveImage");
		let hrSamplerNode = null;
+		let hrSteps = null;

		const ceil64 = (v) => Math.ceil(v / 64) * 64;
@@ -290,6 +294,9 @@ export async function importA1111(graph, parameters) {
			model(v) {
				setWidgetValue(ckptNode, "ckpt_name", v, true);
			},
+			"vae"(v) {
+				setWidgetValue(vaeLoaderNode, "vae_name", v, true);
+			},
			"cfg scale"(v) {
				setWidgetValue(samplerNode, "cfg", +v);
			},
@@ -316,6 +323,7 @@ export async function importA1111(graph, parameters) {
			const h = ceil64(+wxh[1]);
			const hrUp = popOpt("hires upscale");
			const hrSz = popOpt("hires resize");
+			hrSteps = popOpt("hires steps");
			let hrMethod = popOpt("hires upscaler");

			setWidgetValue(imageNode, "width", w);
@@ -398,7 +406,7 @@ export async function importA1111(graph, parameters) {
		}

		if (hrSamplerNode) {
-			setWidgetValue(hrSamplerNode, "steps", getWidget(samplerNode, "steps").value);
+			setWidgetValue(hrSamplerNode, "steps", hrSteps? +hrSteps : getWidget(samplerNode, "steps").value);
			setWidgetValue(hrSamplerNode, "cfg", getWidget(samplerNode, "cfg").value);
			setWidgetValue(hrSamplerNode, "scheduler", getWidget(samplerNode, "scheduler").value);
			setWidgetValue(hrSamplerNode, "sampler_name", getWidget(samplerNode, "sampler_name").value);
@@ -415,7 +423,7 @@ export async function importA1111(graph, parameters) {
		graph.arrange();

-		for (const opt of ["model hash", "ensd"]) {
+		for (const opt of ["model hash", "ensd", "version", "vae hash", "ti hashes", "lora hashes", "hashes"]) {
			delete opts[opt];
		}
