diff --git a/comfy/samplers.py b/comfy/samplers.py index 3aaf8ac4..34df116c 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -375,7 +375,7 @@ def resolve_cond_masks(conditions, h, w, device): modified = c[1].copy() if len(mask.shape) == 2: mask = mask.unsqueeze(0) - if mask.shape[2] != h or mask.shape[3] != w: + if mask.shape[1] != h or mask.shape[2] != w: mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=(h, w), mode='bilinear', align_corners=False).squeeze(1) if modified.get("set_area_to_bounds", False):