From e214c917ae889b278a05fa6e8b8c42d2cc8818fa Mon Sep 17 00:00:00 2001 From: Jacob Segal Date: Tue, 25 Apr 2023 00:15:25 -0700 Subject: [PATCH 1/2] Add Condition by Mask node This PR adds support for a Condition by Mask node. This node allows conditioning to be limited to a non-rectangle area. --- comfy/samplers.py | 90 ++++++++++++++++++++++++++++++++++++++--------- nodes.py | 28 +++++++++++++++ 2 files changed, 102 insertions(+), 16 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index fc19ddcf..6fa754b9 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -6,6 +6,7 @@ import contextlib from comfy import model_management from .ldm.models.diffusion.ddim import DDIMSampler from .ldm.modules.diffusionmodules.util import make_ddim_timesteps +from torchvision.ops import masks_to_boxes #The main sampling function shared by all the samplers #Returns predicted noise @@ -23,21 +24,34 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con adm_cond = cond[1]['adm_encoded'] input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] - mult = torch.ones_like(input_x) * strength - - rr = 8 - if area[2] != 0: - for t in range(rr): - mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1)) - if (area[0] + area[2]) < x_in.shape[2]: - for t in range(rr): - mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1)) - if area[3] != 0: - for t in range(rr): - mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1)) - if (area[1] + area[3]) < x_in.shape[3]: - for t in range(rr): - mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1)) + if 'mask' in cond[1]: + # Scale the mask to the size of the input + # The mask should have been resized as we began the sampling process + mask = cond[1]['mask'] + assert(mask.shape[1] == x_in.shape[2]) + assert(mask.shape[2] == x_in.shape[3]) + mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] + if mask.shape[0] != input_x.shape[0]: + mask = mask.repeat(input_x.shape[0], 1, 1) + else: + mask = torch.ones_like(input_x) + mult = mask * strength + + if 'mask' not in cond[1]: + rr = 8 + if area[2] != 0: + for t in range(rr): + mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1)) + if (area[0] + area[2]) < x_in.shape[2]: + for t in range(rr): + mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1)) + if area[3] != 0: + for t in range(rr): + mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1)) + if (area[1] + area[3]) < x_in.shape[3]: + for t in range(rr): + mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1)) + conditionning = {} conditionning['c_crossattn'] = cond[0] if cond_concat_in is not None and len(cond_concat_in) > 0: @@ -301,6 +315,47 @@ def blank_inpaint_image_like(latent_image): blank_image[:,3] *= 0.1380 return blank_image +def resolve_cond_masks(conditions, h, w, device): + # We need to decide on an area outside the sampling loop in order to properly generate opposite areas of equal sizes. + # While we're doing this, we can also resolve the mask device and scaling for performance reasons + for i in range(len(conditions)): + c = conditions[i] + if 'mask' in c[1]: + mask = c[1]['mask'] + mask = mask.to(device=device) + modified = c[1].copy() + if len(mask.shape) == 2: + mask = mask.unsqueeze(0) + if mask.shape[2] != h or mask.shape[3] != w: + mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=(h, w), mode='bilinear', align_corners=False).squeeze(1) + + if 'area' not in modified: + bounds = torch.max(torch.abs(mask),dim=0).values.unsqueeze(0) + if torch.max(bounds) == 0: + # Handle the edge-case of an all black mask (where masks_to_boxes would error) + area = (0, 0, 0, 0) + else: + box = masks_to_boxes(bounds)[0].type(torch.int) + H, W, Y, X = (box[3] - box[1] + 1, box[2] - box[0] + 1, box[1], box[0]) + # Make sure the height and width are divisible by 8 + if X % 8 != 0: + newx = X // 8 * 8 + W = W + (X - newx) + X = newx + if Y % 8 != 0: + newy = Y // 8 * 8 + H = H + (Y - newy) + Y = newy + if H % 8 != 0: + H = H + (8 - (H % 8)) + if W % 8 != 0: + W = W + (8 - (W % 8)) + area = (int(H), int(W), int(Y), (X)) + modified['area'] = area + + modified['mask'] = mask + conditions[i] = [c[0], modified] + def create_cond_with_same_area_if_none(conds, c): if 'area' not in c[1]: return @@ -461,7 +516,6 @@ class KSampler: sigmas = self.calculate_sigmas(new_steps).to(self.device) self.sigmas = sigmas[-(steps + 1):] - def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None): if sigmas is None: sigmas = self.sigmas @@ -484,6 +538,10 @@ class KSampler: positive = positive[:] negative = negative[:] + + resolve_cond_masks(positive, noise.shape[2], noise.shape[3], self.device) + resolve_cond_masks(negative, noise.shape[2], noise.shape[3], self.device) + #make sure each cond area has an opposite one with the same area for c in positive: create_cond_with_same_area_if_none(negative, c) diff --git a/nodes.py b/nodes.py index 0a9513be..be02f467 100644 --- a/nodes.py +++ b/nodes.py @@ -85,6 +85,32 @@ class ConditioningSetArea: c.append(n) return (c, ) +class ConditioningSetMask: + @classmethod + def INPUT_TYPES(s): + return {"required": {"conditioning": ("CONDITIONING", ), + "mask": ("MASK", ), + "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + }} + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "append" + + CATEGORY = "conditioning" + + def append(self, conditioning, mask, strength, min_sigma=0.0, max_sigma=99.0): + c = [] + if len(mask.shape) < 3: + mask = mask.unsqueeze(0) + for t in conditioning: + n = [t[0], t[1].copy()] + _, h, w = mask.shape + n[1]['mask'] = mask + n[1]['strength'] = strength + n[1]['min_sigma'] = min_sigma + n[1]['max_sigma'] = max_sigma + c.append(n) + return (c, ) + class VAEDecode: def __init__(self, device="cpu"): self.device = device @@ -1115,6 +1141,7 @@ NODE_CLASS_MAPPINGS = { "ImagePadForOutpaint": ImagePadForOutpaint, "ConditioningCombine": ConditioningCombine, "ConditioningSetArea": ConditioningSetArea, + "ConditioningSetMask": ConditioningSetMask, "KSamplerAdvanced": KSamplerAdvanced, "SetLatentNoiseMask": SetLatentNoiseMask, "LatentComposite": LatentComposite, @@ -1164,6 +1191,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { "CLIPSetLastLayer": "CLIP Set Last Layer", "ConditioningCombine": "Conditioning (Combine)", "ConditioningSetArea": "Conditioning (Set Area)", + "ConditioningSetMask": "Conditioning (Set Mask)", "ControlNetApply": "Apply ControlNet", # Latent "VAEEncodeForInpaint": "VAE Encode (for Inpainting)", From af02393c2a7134861df57e5843fc17498c65a795 Mon Sep 17 00:00:00 2001 From: Jacob Segal Date: Sat, 29 Apr 2023 00:16:58 -0700 Subject: [PATCH 2/2] Default to sampling entire image By default, when applying a mask to a condition, the entire image will still be used for sampling. The new "set_area_to_bounds" option on the node will allow the user to automatically limit conditioning to the bounds of the mask. I've also removed the dependency on torchvision for calculating bounding boxes. I've taken the opportunity to fix some frustrating details in the other version: 1. An all-0 mask will no longer cause an error 2. Indices are returned as integers instead of floats so they can be used to index into tensors. --- comfy/samplers.py | 42 ++++++++++++++++++++++++++++++++---------- nodes.py | 4 +++- 2 files changed, 35 insertions(+), 11 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index 6fa754b9..f8701c87 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -6,7 +6,6 @@ import contextlib from comfy import model_management from .ldm.models.diffusion.ddim import DDIMSampler from .ldm.modules.diffusionmodules.util import make_ddim_timesteps -from torchvision.ops import masks_to_boxes #The main sampling function shared by all the samplers #Returns predicted noise @@ -31,8 +30,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con assert(mask.shape[1] == x_in.shape[2]) assert(mask.shape[2] == x_in.shape[3]) mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] - if mask.shape[0] != input_x.shape[0]: - mask = mask.repeat(input_x.shape[0], 1, 1) + mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1) else: mask = torch.ones_like(input_x) mult = mask * strength @@ -315,6 +313,29 @@ def blank_inpaint_image_like(latent_image): blank_image[:,3] *= 0.1380 return blank_image +def get_mask_aabb(masks): + if masks.numel() == 0: + return torch.zeros((0, 4), device=masks.device, dtype=torch.int) + + b = masks.shape[0] + + bounding_boxes = torch.zeros((b, 4), device=masks.device, dtype=torch.int) + is_empty = torch.zeros((b), device=masks.device, dtype=torch.bool) + for i in range(b): + mask = masks[i] + if mask.numel() == 0: + continue + if torch.max(mask != 0) == False: + is_empty[i] = True + continue + y, x = torch.where(mask) + bounding_boxes[i, 0] = torch.min(x) + bounding_boxes[i, 1] = torch.min(y) + bounding_boxes[i, 2] = torch.max(x) + bounding_boxes[i, 3] = torch.max(y) + + return bounding_boxes, is_empty + def resolve_cond_masks(conditions, h, w, device): # We need to decide on an area outside the sampling loop in order to properly generate opposite areas of equal sizes. # While we're doing this, we can also resolve the mask device and scaling for performance reasons @@ -329,13 +350,14 @@ def resolve_cond_masks(conditions, h, w, device): if mask.shape[2] != h or mask.shape[3] != w: mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=(h, w), mode='bilinear', align_corners=False).squeeze(1) - if 'area' not in modified: + if modified.get("set_area_to_bounds", False): bounds = torch.max(torch.abs(mask),dim=0).values.unsqueeze(0) - if torch.max(bounds) == 0: - # Handle the edge-case of an all black mask (where masks_to_boxes would error) - area = (0, 0, 0, 0) + boxes, is_empty = get_mask_aabb(bounds) + if is_empty[0]: + # Use the minimum possible size for efficiency reasons. (Since the mask is all-0, this becomes a noop anyway) + modified['area'] = (8, 8, 0, 0) else: - box = masks_to_boxes(bounds)[0].type(torch.int) + box = boxes[0] H, W, Y, X = (box[3] - box[1] + 1, box[2] - box[0] + 1, box[1], box[0]) # Make sure the height and width are divisible by 8 if X % 8 != 0: @@ -350,8 +372,8 @@ def resolve_cond_masks(conditions, h, w, device): H = H + (8 - (H % 8)) if W % 8 != 0: W = W + (8 - (W % 8)) - area = (int(H), int(W), int(Y), (X)) - modified['area'] = area + area = (int(H), int(W), int(Y), int(X)) + modified['area'] = area modified['mask'] = mask conditions[i] = [c[0], modified] diff --git a/nodes.py b/nodes.py index be02f467..12fa7e5a 100644 --- a/nodes.py +++ b/nodes.py @@ -90,6 +90,7 @@ class ConditioningSetMask: def INPUT_TYPES(s): return {"required": {"conditioning": ("CONDITIONING", ), "mask": ("MASK", ), + "set_area_to_bounds": ([False, True],), "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), }} RETURN_TYPES = ("CONDITIONING",) @@ -97,7 +98,7 @@ class ConditioningSetMask: CATEGORY = "conditioning" - def append(self, conditioning, mask, strength, min_sigma=0.0, max_sigma=99.0): + def append(self, conditioning, mask, set_area_to_bounds, strength, min_sigma=0.0, max_sigma=99.0): c = [] if len(mask.shape) < 3: mask = mask.unsqueeze(0) @@ -105,6 +106,7 @@ class ConditioningSetMask: n = [t[0], t[1].copy()] _, h, w = mask.shape n[1]['mask'] = mask + n[1]['set_area_to_bounds'] = set_area_to_bounds n[1]['strength'] = strength n[1]['min_sigma'] = min_sigma n[1]['max_sigma'] = max_sigma