
Merge branch 'condition_by_mask_node' of https://github.com/guill/ComfyUI

pull/594/head
comfyanonymous, 2 years ago
commit 870fae62e7
2 changed files:
  comfy/samplers.py (112)
  nodes.py (30)

comfy/samplers.py (112)

@@ -23,21 +23,33 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
             adm_cond = cond[1]['adm_encoded']
 
         input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
-        mult = torch.ones_like(input_x) * strength
-        rr = 8
-        if area[2] != 0:
-            for t in range(rr):
-                mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1))
-        if (area[0] + area[2]) < x_in.shape[2]:
-            for t in range(rr):
-                mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1))
-        if area[3] != 0:
-            for t in range(rr):
-                mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1))
-        if (area[1] + area[3]) < x_in.shape[3]:
-            for t in range(rr):
-                mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
+        if 'mask' in cond[1]:
+            # Scale the mask to the size of the input
+            # The mask should have been resized as we began the sampling process
+            mask = cond[1]['mask']
+            assert(mask.shape[1] == x_in.shape[2])
+            assert(mask.shape[2] == x_in.shape[3])
+            mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
+            mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1)
+        else:
+            mask = torch.ones_like(input_x)
+        mult = mask * strength
+
+        if 'mask' not in cond[1]:
+            rr = 8
+            if area[2] != 0:
+                for t in range(rr):
+                    mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1))
+            if (area[0] + area[2]) < x_in.shape[2]:
+                for t in range(rr):
+                    mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1))
+            if area[3] != 0:
+                for t in range(rr):
+                    mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1))
+            if (area[1] + area[3]) < x_in.shape[3]:
+                for t in range(rr):
+                    mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
+
         conditionning = {}
         conditionning['c_crossattn'] = cond[0]
         if cond_concat_in is not None and len(cond_concat_in) > 0:
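The hunk above swaps the uniform `mult = torch.ones_like(input_x) * strength` weighting for a per-pixel mask and keeps the old 8-pixel edge feathering only when no mask is attached to the cond. A minimal standalone sketch of the new weighting path (the tensors and the `area`/`strength` values below are made-up stand-ins, not ComfyUI data):

```python
import torch

# Stand-in latent batch (B, C, H, W) and a (B, H, W) mask already at latent resolution.
x_in = torch.randn(2, 4, 64, 64)
mask = torch.zeros(2, 64, 64)
mask[:, 16:48, 16:48] = 1.0

area = (32, 32, 8, 8)      # (height, width, y, x) in latent pixels, as in the diff
strength = 0.8

input_x = x_in[:, :, area[2]:area[0] + area[2], area[3]:area[1] + area[3]]

# Crop the mask to the same area, then broadcast it over batch and channels,
# mirroring the unsqueeze(1).repeat(...) call in the new code.
m = mask[:, area[2]:area[0] + area[2], area[3]:area[1] + area[3]]
m = m.unsqueeze(1).repeat(input_x.shape[0] // m.shape[0], input_x.shape[1], 1, 1)

mult = m * strength        # per-pixel weight instead of a single scalar strength
print(input_x.shape, mult.shape)   # torch.Size([2, 4, 32, 32]) for both
```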
@@ -301,6 +313,71 @@ def blank_inpaint_image_like(latent_image):
     blank_image[:,3] *= 0.1380
     return blank_image
 
+def get_mask_aabb(masks):
+    if masks.numel() == 0:
+        return torch.zeros((0, 4), device=masks.device, dtype=torch.int)
+
+    b = masks.shape[0]
+
+    bounding_boxes = torch.zeros((b, 4), device=masks.device, dtype=torch.int)
+    is_empty = torch.zeros((b), device=masks.device, dtype=torch.bool)
+    for i in range(b):
+        mask = masks[i]
+        if mask.numel() == 0:
+            continue
+        if torch.max(mask != 0) == False:
+            is_empty[i] = True
+            continue
+        y, x = torch.where(mask)
+        bounding_boxes[i, 0] = torch.min(x)
+        bounding_boxes[i, 1] = torch.min(y)
+        bounding_boxes[i, 2] = torch.max(x)
+        bounding_boxes[i, 3] = torch.max(y)
+
+    return bounding_boxes, is_empty
+
+def resolve_cond_masks(conditions, h, w, device):
+    # We need to decide on an area outside the sampling loop in order to properly generate opposite areas of equal sizes.
+    # While we're doing this, we can also resolve the mask device and scaling for performance reasons
+    for i in range(len(conditions)):
+        c = conditions[i]
+        if 'mask' in c[1]:
+            mask = c[1]['mask']
+            mask = mask.to(device=device)
+            modified = c[1].copy()
+            if len(mask.shape) == 2:
+                mask = mask.unsqueeze(0)
+            if mask.shape[2] != h or mask.shape[3] != w:
+                mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=(h, w), mode='bilinear', align_corners=False).squeeze(1)
+
+            if modified.get("set_area_to_bounds", False):
+                bounds = torch.max(torch.abs(mask),dim=0).values.unsqueeze(0)
+                boxes, is_empty = get_mask_aabb(bounds)
+                if is_empty[0]:
+                    # Use the minimum possible size for efficiency reasons. (Since the mask is all-0, this becomes a noop anyway)
+                    modified['area'] = (8, 8, 0, 0)
+                else:
+                    box = boxes[0]
+                    H, W, Y, X = (box[3] - box[1] + 1, box[2] - box[0] + 1, box[1], box[0])
+                    # Make sure the height and width are divisible by 8
+                    if X % 8 != 0:
+                        newx = X // 8 * 8
+                        W = W + (X - newx)
+                        X = newx
+                    if Y % 8 != 0:
+                        newy = Y // 8 * 8
+                        H = H + (Y - newy)
+                        Y = newy
+                    if H % 8 != 0:
+                        H = H + (8 - (H % 8))
+                    if W % 8 != 0:
+                        W = W + (8 - (W % 8))
+                    area = (int(H), int(W), int(Y), int(X))
+                    modified['area'] = area
+
+            modified['mask'] = mask
+            conditions[i] = [c[0], modified]
+
 def create_cond_with_same_area_if_none(conds, c):
     if 'area' not in c[1]:
         return
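`get_mask_aabb` returns per-mask bounding boxes as `(x_min, y_min, x_max, y_max)` plus an is-empty flag, and `resolve_cond_masks` then snaps that box to 8-pixel boundaries before storing it as an `'area'`. A condensed, runnable illustration of that rounding (the mask rectangle is invented for the example, and the helper below is a shortened single-mask equivalent of the added function, not a copy):

```python
import torch

def mask_aabb(mask_2d):
    # Shortened single-mask equivalent of get_mask_aabb:
    # returns (x_min, y_min, x_max, y_max) of the nonzero region.
    y, x = torch.where(mask_2d != 0)
    return x.min(), y.min(), x.max(), y.max()

mask = torch.zeros(64, 64)
mask[10:30, 5:20] = 1.0      # nonzero region: y in [10, 29], x in [5, 19]

x0, y0, x1, y1 = mask_aabb(mask)
H, W, Y, X = (y1 - y0 + 1, x1 - x0 + 1, y0, x0)   # (20, 15, 10, 5)

# Same rounding as resolve_cond_masks: snap the origin down to a multiple of 8
# (growing the box to compensate), then pad height and width up to multiples of 8.
if X % 8 != 0:
    newx = X // 8 * 8
    W = W + (X - newx)
    X = newx
if Y % 8 != 0:
    newy = Y // 8 * 8
    H = H + (Y - newy)
    Y = newy
if H % 8 != 0:
    H = H + (8 - (H % 8))
if W % 8 != 0:
    W = W + (8 - (W % 8))

print((int(H), int(W), int(Y), int(X)))   # (24, 24, 8, 0), i.e. area = (height, width, y, x)
```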
@@ -461,7 +538,6 @@ class KSampler:
             sigmas = self.calculate_sigmas(new_steps).to(self.device)
             self.sigmas = sigmas[-(steps + 1):]
 
-
     def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None):
         if sigmas is None:
             sigmas = self.sigmas
@@ -484,6 +560,10 @@ class KSampler:
         positive = positive[:]
         negative = negative[:]
 
+        resolve_cond_masks(positive, noise.shape[2], noise.shape[3], self.device)
+        resolve_cond_masks(negative, noise.shape[2], noise.shape[3], self.device)
+
         #make sure each cond area has an opposite one with the same area
         for c in positive:
             create_cond_with_same_area_if_none(negative, c)
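The two `resolve_cond_masks` calls run once per `sample()` call, so the mask is moved to the sampling device and rescaled to the latent resolution before the denoising loop starts. A hedged sketch of what that preprocessing amounts to for a single conditioning entry (the `(1, 77, 768)` embedding shape and the 512x512 mask are just typical stand-in values):

```python
import torch
import torch.nn.functional as F

# A conditioning entry as KSampler.sample sees it: a [tensor, options] pair.
cond_tensor = torch.randn(1, 77, 768)
mask = torch.rand(1, 512, 512)                   # image-resolution mask attached by the node
positive = [[cond_tensor, {"mask": mask, "set_area_to_bounds": False, "strength": 1.0}]]

noise = torch.randn(1, 4, 64, 64)                # latent noise passed into sample()

# The core of what resolve_cond_masks(positive, noise.shape[2], noise.shape[3], device)
# does for this entry (ignoring device movement and the bounding-box branch):
# rescale the mask to latent resolution once, up front.
h, w = noise.shape[2], noise.shape[3]
resized = F.interpolate(mask.unsqueeze(1), size=(h, w), mode='bilinear',
                        align_corners=False).squeeze(1)
positive[0][1] = dict(positive[0][1], mask=resized)

print(positive[0][1]['mask'].shape)              # torch.Size([1, 64, 64])
```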

nodes.py (30)

@@ -85,6 +85,34 @@ class ConditioningSetArea:
             c.append(n)
         return (c, )
 
+class ConditioningSetMask:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"conditioning": ("CONDITIONING", ),
+                             "mask": ("MASK", ),
+                             "set_area_to_bounds": ([False, True],),
+                             "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+                            }}
+    RETURN_TYPES = ("CONDITIONING",)
+    FUNCTION = "append"
+
+    CATEGORY = "conditioning"
+
+    def append(self, conditioning, mask, set_area_to_bounds, strength, min_sigma=0.0, max_sigma=99.0):
+        c = []
+        if len(mask.shape) < 3:
+            mask = mask.unsqueeze(0)
+        for t in conditioning:
+            n = [t[0], t[1].copy()]
+            _, h, w = mask.shape
+            n[1]['mask'] = mask
+            n[1]['set_area_to_bounds'] = set_area_to_bounds
+            n[1]['strength'] = strength
+            n[1]['min_sigma'] = min_sigma
+            n[1]['max_sigma'] = max_sigma
+            c.append(n)
+        return (c, )
+
 class VAEDecode:
     def __init__(self, device="cpu"):
         self.device = device
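Since ComfyUI nodes are plain Python classes, the new node's behaviour can also be sketched outside the graph UI. The snippet below assumes it is run from a ComfyUI checkout so `nodes` is importable, and uses a dummy conditioning tensor in place of a real CLIPTextEncode output:

```python
import torch
from nodes import ConditioningSetMask   # assumes a ComfyUI checkout on sys.path

# Dummy conditioning in ComfyUI's list-of-[tensor, options] form.
conditioning = [[torch.randn(1, 77, 768), {}]]
mask = torch.zeros(512, 512)
mask[128:384, 128:384] = 1.0             # rectangular region to condition on

(masked,) = ConditioningSetMask().append(conditioning, mask,
                                         set_area_to_bounds=True, strength=0.75)

opts = masked[0][1]
print(opts['mask'].shape)                # torch.Size([1, 512, 512]); append() adds the batch dim
print(opts['set_area_to_bounds'])        # True
print(opts['strength'])                  # 0.75
```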
@@ -1115,6 +1143,7 @@ NODE_CLASS_MAPPINGS = {
     "ImagePadForOutpaint": ImagePadForOutpaint,
     "ConditioningCombine": ConditioningCombine,
     "ConditioningSetArea": ConditioningSetArea,
+    "ConditioningSetMask": ConditioningSetMask,
     "KSamplerAdvanced": KSamplerAdvanced,
     "SetLatentNoiseMask": SetLatentNoiseMask,
     "LatentComposite": LatentComposite,
@@ -1164,6 +1193,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "CLIPSetLastLayer": "CLIP Set Last Layer",
     "ConditioningCombine": "Conditioning (Combine)",
     "ConditioningSetArea": "Conditioning (Set Area)",
+    "ConditioningSetMask": "Conditioning (Set Mask)",
     "ControlNetApply": "Apply ControlNet",
     # Latent
     "VAEEncodeForInpaint": "VAE Encode (for Inpainting)",
