From e214c917ae889b278a05fa6e8b8c42d2cc8818fa Mon Sep 17 00:00:00 2001
From: Jacob Segal
Date: Tue, 25 Apr 2023 00:15:25 -0700
Subject: [PATCH] Add Condition by Mask node

This PR adds support for a Condition by Mask node. This node allows
conditioning to be limited to a non-rectangular area.
---
 comfy/samplers.py | 90 ++++++++++++++++++++++++++++++++++++++---------
 nodes.py          | 28 +++++++++++++++
 2 files changed, 102 insertions(+), 16 deletions(-)

diff --git a/comfy/samplers.py b/comfy/samplers.py
index fc19ddcf..6fa754b9 100644
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -6,6 +6,7 @@ import contextlib
 from comfy import model_management
 from .ldm.models.diffusion.ddim import DDIMSampler
 from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
+from torchvision.ops import masks_to_boxes
 
 #The main sampling function shared by all the samplers
 #Returns predicted noise
@@ -23,21 +24,34 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
             adm_cond = cond[1]['adm_encoded']
 
         input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
-        mult = torch.ones_like(input_x) * strength
-
-        rr = 8
-        if area[2] != 0:
-            for t in range(rr):
-                mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1))
-        if (area[0] + area[2]) < x_in.shape[2]:
-            for t in range(rr):
-                mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1))
-        if area[3] != 0:
-            for t in range(rr):
-                mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1))
-        if (area[1] + area[3]) < x_in.shape[3]:
-            for t in range(rr):
-                mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
+        if 'mask' in cond[1]:
+            # The mask was already resized to the latent resolution before sampling began;
+            # crop it to the conditioning area and add a channel dim so it broadcasts over (B, C, H, W)
+            mask = cond[1]['mask']
+            assert(mask.shape[1] == x_in.shape[2])
+            assert(mask.shape[2] == x_in.shape[3])
+            mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]].unsqueeze(1)
+            if mask.shape[0] != input_x.shape[0]:
+                mask = mask.repeat(input_x.shape[0], 1, 1, 1)
+        else:
+            mask = torch.ones_like(input_x)
+        mult = mask * strength
+
+        if 'mask' not in cond[1]:
+            rr = 8
+            if area[2] != 0:
+                for t in range(rr):
+                    mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1))
+            if (area[0] + area[2]) < x_in.shape[2]:
+                for t in range(rr):
+                    mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1))
+            if area[3] != 0:
+                for t in range(rr):
+                    mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1))
+            if (area[1] + area[3]) < x_in.shape[3]:
+                for t in range(rr):
+                    mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
+
         conditionning = {}
         conditionning['c_crossattn'] = cond[0]
         if cond_concat_in is not None and len(cond_concat_in) > 0:
@@ -301,6 +315,47 @@ def blank_inpaint_image_like(latent_image):
     blank_image[:,3] *= 0.1380
     return blank_image
 
+def resolve_cond_masks(conditions, h, w, device):
+    # We need to decide on an area outside the sampling loop in order to properly generate opposite areas of equal sizes.
+    # While we're doing this, we can also resolve the mask device and scaling for performance reasons
+    for i in range(len(conditions)):
+        c = conditions[i]
+        if 'mask' in c[1]:
+            mask = c[1]['mask']
+            mask = mask.to(device=device)
+            modified = c[1].copy()
+            if len(mask.shape) == 2:
+                mask = mask.unsqueeze(0)
+            if mask.shape[1] != h or mask.shape[2] != w:
+                mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=(h, w), mode='bilinear', align_corners=False).squeeze(1)
+
+            if 'area' not in modified:
+                bounds = torch.max(torch.abs(mask),dim=0).values.unsqueeze(0)
+                if torch.max(bounds) == 0:
+                    # Handle the edge-case of an all-black mask (where masks_to_boxes would error)
+                    area = (0, 0, 0, 0)
+                else:
+                    box = masks_to_boxes(bounds)[0].type(torch.int)
+                    H, W, Y, X = (box[3] - box[1] + 1, box[2] - box[0] + 1, box[1], box[0])
+                    # Make sure the height and width are divisible by 8
+                    if X % 8 != 0:
+                        newx = X // 8 * 8
+                        W = W + (X - newx)
+                        X = newx
+                    if Y % 8 != 0:
+                        newy = Y // 8 * 8
+                        H = H + (Y - newy)
+                        Y = newy
+                    if H % 8 != 0:
+                        H = H + (8 - (H % 8))
+                    if W % 8 != 0:
+                        W = W + (8 - (W % 8))
+                    area = (int(H), int(W), int(Y), int(X))
+                modified['area'] = area
+
+            modified['mask'] = mask
+            conditions[i] = [c[0], modified]
+
 def create_cond_with_same_area_if_none(conds, c):
     if 'area' not in c[1]:
         return
@@ -461,7 +516,6 @@ class KSampler:
             sigmas = self.calculate_sigmas(new_steps).to(self.device)
             self.sigmas = sigmas[-(steps + 1):]
 
-
     def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None):
         if sigmas is None:
             sigmas = self.sigmas
@@ -484,6 +538,10 @@ class KSampler:
 
         positive = positive[:]
         negative = negative[:]
+
+        resolve_cond_masks(positive, noise.shape[2], noise.shape[3], self.device)
+        resolve_cond_masks(negative, noise.shape[2], noise.shape[3], self.device)
+
         #make sure each cond area has an opposite one with the same area
         for c in positive:
             create_cond_with_same_area_if_none(negative, c)
diff --git a/nodes.py b/nodes.py
index 0a9513be..be02f467 100644
--- a/nodes.py
+++ b/nodes.py
@@ -85,6 +85,32 @@ class ConditioningSetArea:
             c.append(n)
         return (c, )
 
+class ConditioningSetMask:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"conditioning": ("CONDITIONING", ),
+                             "mask": ("MASK", ),
+                             "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+                            }}
+    RETURN_TYPES = ("CONDITIONING",)
+    FUNCTION = "append"
+
+    CATEGORY = "conditioning"
+
+    def append(self, conditioning, mask, strength, min_sigma=0.0, max_sigma=99.0):
+        c = []
+        if len(mask.shape) < 3:
+            mask = mask.unsqueeze(0)
+        for t in conditioning:
+            n = [t[0], t[1].copy()]
+            _, h, w = mask.shape
+            n[1]['mask'] = mask
+            n[1]['strength'] = strength
+            n[1]['min_sigma'] = min_sigma
+            n[1]['max_sigma'] = max_sigma
+            c.append(n)
+        return (c, )
+
 class VAEDecode:
     def __init__(self, device="cpu"):
         self.device = device
@@ -1115,6 +1141,7 @@ NODE_CLASS_MAPPINGS = {
     "ImagePadForOutpaint": ImagePadForOutpaint,
     "ConditioningCombine": ConditioningCombine,
     "ConditioningSetArea": ConditioningSetArea,
+    "ConditioningSetMask": ConditioningSetMask,
     "KSamplerAdvanced": KSamplerAdvanced,
     "SetLatentNoiseMask": SetLatentNoiseMask,
     "LatentComposite": LatentComposite,
@@ -1164,6 +1191,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "CLIPSetLastLayer": "CLIP Set Last Layer",
     "ConditioningCombine": "Conditioning (Combine)",
     "ConditioningSetArea": "Conditioning (Set Area)",
+    "ConditioningSetMask": "Conditioning (Set Mask)",
     "ControlNetApply": "Apply ControlNet",
     # Latent
     "VAEEncodeForInpaint": "VAE Encode (for Inpainting)",