diff --git a/nodes.py b/nodes.py index aaa4f87a..c190d633 100644 --- a/nodes.py +++ b/nodes.py @@ -98,7 +98,7 @@ class VAEDecode: CATEGORY = "latent" def decode(self, vae, samples): - return (vae.decode(samples), ) + return (vae.decode(samples["samples"]), ) class VAEEncode: def __init__(self, device="cpu"): @@ -117,7 +117,9 @@ class VAEEncode: y = (pixels.shape[2] // 64) * 64 if pixels.shape[1] != x or pixels.shape[2] != y: pixels = pixels[:,:x,:y,:] - return (vae.encode(pixels), ) + t = vae.encode(pixels[:,:,:,:3]) + + return ({"samples":t}, ) class CheckpointLoader: models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") @@ -212,7 +214,7 @@ class EmptyLatentImage: def generate(self, width, height, batch_size=1): latent = torch.zeros([batch_size, 4, height // 8, width // 8]) - return (latent, ) + return ({"samples":latent}, ) def common_upscale(samples, width, height, upscale_method, crop): if crop == "center": @@ -247,7 +249,8 @@ class LatentUpscale: CATEGORY = "latent" def upscale(self, samples, upscale_method, width, height, crop): - s = common_upscale(samples, width // 8, height // 8, upscale_method, crop) + s = samples.copy() + s["samples"] = common_upscale(samples["samples"], width // 8, height // 8, upscale_method, crop) return (s,) class LatentRotate: @@ -262,6 +265,7 @@ class LatentRotate: CATEGORY = "latent" def rotate(self, samples, rotation): + s = samples.copy() rotate_by = 0 if rotation.startswith("90"): rotate_by = 1 @@ -270,7 +274,7 @@ class LatentRotate: elif rotation.startswith("270"): rotate_by = 3 - s = torch.rot90(samples, k=rotate_by, dims=[3, 2]) + s["samples"] = torch.rot90(samples["samples"], k=rotate_by, dims=[3, 2]) return (s,) class LatentFlip: @@ -285,12 +289,11 @@ class LatentFlip: CATEGORY = "latent" def flip(self, samples, flip_method): + s = samples.copy() if flip_method.startswith("x"): - s = torch.flip(samples, dims=[2]) + s["samples"] = torch.flip(samples["samples"], dims=[2]) elif flip_method.startswith("y"): - s = torch.flip(samples, dims=[3]) - else: - s = samples + s["samples"] = torch.flip(samples["samples"], dims=[3]) return (s,) @@ -312,12 +315,15 @@ class LatentComposite: x = x // 8 y = y // 8 feather = feather // 8 - s = samples_to.clone() + samples_out = samples_to.copy() + s = samples_to["samples"].clone() + samples_to = samples_to["samples"] + samples_from = samples_from["samples"] if feather == 0: s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x] else: - s_from = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x] - mask = torch.ones_like(s_from) + samples_from = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x] + mask = torch.ones_like(samples_from) for t in range(feather): if y != 0: mask[:,:,t:1+t,:] *= ((1.0/feather) * (t + 1)) @@ -330,7 +336,8 @@ class LatentComposite: mask[:,:,:,mask.shape[3]- 1 - t: mask.shape[3]- t] *= ((1.0/feather) * (t + 1)) rev_mask = torch.ones_like(mask) - mask s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x] * mask + s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] * rev_mask - return (s,) + samples_out["samples"] = s + return (samples_out,) class LatentCrop: @classmethod @@ -347,6 +354,8 @@ class LatentCrop: CATEGORY = "latent" def crop(self, samples, width, height, x, y): + s = samples.copy() + samples = samples['samples'] x = x // 8 y = y // 8 @@ -370,15 +379,46 @@ class LatentCrop: #make sure size is always multiple of 64 x, to_x = enforce_image_dim(x, to_x, samples.shape[3]) y, to_y = enforce_image_dim(y, to_y, samples.shape[2]) - s = samples[:,:,y:to_y, x:to_x] + s['samples'] = samples[:,:,y:to_y, x:to_x] return (s,) -def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False): +class SetLatentNoiseMask: + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples": ("LATENT",), + "mask": ("MASK",), + }} + RETURN_TYPES = ("LATENT",) + FUNCTION = "set_mask" + + CATEGORY = "latent" + + def set_mask(self, samples, mask): + s = samples.copy() + s["noise_mask"] = mask + return (s,) + + +def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False): + latent_image = latent["samples"] + noise_mask = None + if disable_noise: noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") else: noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=torch.manual_seed(seed), device="cpu") + if "noise_mask" in latent: + noise_mask = latent['noise_mask'] + print(noise_mask.shape, noise.shape) + + noise_mask = torch.nn.functional.interpolate(noise_mask[None,None,], size=(noise.shape[2], noise.shape[3]), mode="bilinear") + noise_mask = noise_mask.floor() + noise_mask = torch.ones_like(noise_mask) - noise_mask + noise_mask = torch.cat([noise_mask] * noise.shape[1], dim=1) + noise_mask = torch.cat([noise_mask] * noise.shape[0]) + noise_mask = noise_mask.to(device) + real_model = None if device != "cpu": model_management.load_model_gpu(model) @@ -411,10 +451,11 @@ def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, po #other samplers pass - samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise) + samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask) samples = samples.cpu() - - return (samples, ) + out = latent.copy() + out["samples"] = samples + return (out, ) class KSampler: def __init__(self, device="cuda"): @@ -589,6 +630,7 @@ NODE_CLASS_MAPPINGS = { "ConditioningCombine": ConditioningCombine, "ConditioningSetArea": ConditioningSetArea, "KSamplerAdvanced": KSamplerAdvanced, + "SetLatentNoiseMask": SetLatentNoiseMask, "LatentComposite": LatentComposite, "LatentRotate": LatentRotate, "LatentFlip": LatentFlip,