From d9b1595f8552384dd08374d34c4d4127e0b1a4e6 Mon Sep 17 00:00:00 2001 From: BlenderNeko <126974546+BlenderNeko@users.noreply.github.com> Date: Mon, 24 Apr 2023 12:53:10 +0200 Subject: [PATCH] made sample functions more explicit --- comfy/sample.py | 55 +++++++++++++++++++++---------------------------- nodes.py | 7 +++++-- 2 files changed, 29 insertions(+), 33 deletions(-) diff --git a/comfy/sample.py b/comfy/sample.py index 981781b5..84eefcb7 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -2,30 +2,25 @@ import torch import comfy.model_management -def prepare_noise(latent, seed): - """creates random noise given a LATENT and a seed""" - latent_image = latent["samples"] - batch_index = 0 - if "batch_index" in latent: - batch_index = latent["batch_index"] - +def prepare_noise(latent_image, seed, skip=0): + """ + creates random noise given a latent image and a seed. + optional arg skip can be used to skip and discard x number of noise generations for a given seed + """ generator = torch.manual_seed(seed) - for i in range(batch_index): + for _ in range(skip): noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") return noise -def create_mask(latent, noise): - """creates a mask for a given LATENT and noise""" - noise_mask = None +def prepare_mask(noise_mask, noise): + """ensures noise mask is of proper dimensions""" device = comfy.model_management.get_torch_device() - if "noise_mask" in latent: - noise_mask = latent['noise_mask'] - noise_mask = torch.nn.functional.interpolate(noise_mask[None,None,], size=(noise.shape[2], noise.shape[3]), mode="bilinear") - noise_mask = noise_mask.round() - noise_mask = torch.cat([noise_mask] * noise.shape[1], dim=1) - noise_mask = torch.cat([noise_mask] * noise.shape[0]) - noise_mask = noise_mask.to(device) + noise_mask = torch.nn.functional.interpolate(noise_mask[None,None,], size=(noise.shape[2], noise.shape[3]), mode="bilinear") + noise_mask = noise_mask.round() + noise_mask = torch.cat([noise_mask] * noise.shape[1], dim=1) + noise_mask = torch.cat([noise_mask] * noise.shape[0]) + noise_mask = noise_mask.to(device) return noise_mask def broadcast_cond(cond, noise): @@ -40,22 +35,20 @@ def broadcast_cond(cond, noise): copy += [[t] + p[1:]] return copy -def load_c_nets(positive, negative): - """loads control nets in positive and negative conditioning""" - def get_models(cond): - models = [] - for c in cond: - if 'control' in c[1]: - models += [c[1]['control']] - if 'gligen' in c[1]: - models += [c[1]['gligen'][1]] - return models - - return get_models(positive) + get_models(negative) +def get_models_from_cond(cond, model_type): + models = [] + for c in cond: + if model_type in c[1]: + models += [c[1][model_type]] + return models def load_additional_models(positive, negative): """loads additional models in positive and negative conditioning""" - models = load_c_nets(positive, negative) + models = [] + models += get_models_from_cond(positive, "control") + models += get_models_from_cond(negative, "control") + models += get_models_from_cond(positive, "gligen") + models += get_models_from_cond(negative, "gligen") comfy.model_management.load_controlnet_gpu(models) return models diff --git a/nodes.py b/nodes.py index b8c6d350..f9bedc97 100644 --- a/nodes.py +++ b/nodes.py @@ -747,9 +747,12 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, if disable_noise: noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") else: - noise = comfy.sample.prepare_noise(latent, seed) + skip = latent["batch_index"] if "batch_index" in latent else 0 + noise = comfy.sample.prepare_noise(latent_image, seed, skip) - noise_mask = comfy.sample.create_mask(latent, noise) + noise_mask = None + if "noise_mask" in latent: + noise_mask = comfy.sample.prepare_mask(latent["noise_mask"], noise) real_model = None comfy.model_management.load_model_gpu(model)