From 6f7852bc47de2fa432672a1b93c1727c0824d78b Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 17 Apr 2023 17:24:58 -0400 Subject: [PATCH] Add a LatentFromBatch node to pick a single latent from a batch. Works before and after sampling. --- nodes.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/nodes.py b/nodes.py index c775da00..c745ce28 100644 --- a/nodes.py +++ b/nodes.py @@ -510,6 +510,24 @@ class EmptyLatentImage: return ({"samples":latent}, ) +class LatentFromBatch: + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples": ("LATENT",), + "batch_index": ("INT", {"default": 0, "min": 0, "max": 63}), + }} + RETURN_TYPES = ("LATENT",) + FUNCTION = "rotate" + + CATEGORY = "latent" + + def rotate(self, samples, batch_index): + s = samples.copy() + s_in = samples["samples"] + batch_index = min(s_in.shape[0] - 1, batch_index) + s["samples"] = s_in[batch_index:batch_index + 1].clone() + s["batch_index"] = batch_index + return (s,) class LatentUpscale: upscale_methods = ["nearest-exact", "bilinear", "area"] @@ -685,7 +703,13 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, if disable_noise: noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") else: - noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=torch.manual_seed(seed), device="cpu") + batch_index = 0 + if "batch_index" in latent: + batch_index = latent["batch_index"] + + generator = torch.manual_seed(seed) + for i in range(batch_index + 1): + noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") if "noise_mask" in latent: noise_mask = latent['noise_mask'] @@ -1073,6 +1097,7 @@ NODE_CLASS_MAPPINGS = { "VAELoader": VAELoader, "EmptyLatentImage": EmptyLatentImage, "LatentUpscale": LatentUpscale, + "LatentFromBatch": LatentFromBatch, "SaveImage": SaveImage, "PreviewImage": PreviewImage, "LoadImage": LoadImage,