From 6d8834f08f7b5243cba925665f4a5a7296142383 Mon Sep 17 00:00:00 2001 From: Dmytro Mishkin Date: Mon, 4 Mar 2024 18:50:28 +0100 Subject: [PATCH] Add Morphology nodes from kornia (#2781) * import kornia * Added morphology nodexs * Add kornia to requirements * fix choices * options, also move to postprocessors * fix placing and step --- comfy_extras/nodes_canny.py | 270 +------------------------------ comfy_extras/nodes_morphology.py | 49 ++++++ nodes.py | 1 + requirements.txt | 1 + 4 files changed, 52 insertions(+), 269 deletions(-) create mode 100644 comfy_extras/nodes_morphology.py diff --git a/comfy_extras/nodes_canny.py b/comfy_extras/nodes_canny.py index 730dded5..8138b5f7 100644 --- a/comfy_extras/nodes_canny.py +++ b/comfy_extras/nodes_canny.py @@ -5,275 +5,7 @@ import torch import torch.nn.functional as F import comfy.model_management -def get_canny_nms_kernel(device=None, dtype=None): - """Utility function that returns 3x3 kernels for the Canny Non-maximal suppression.""" - return torch.tensor( - [ - [[[0.0, 0.0, 0.0], [0.0, 1.0, -1.0], [0.0, 0.0, 0.0]]], - [[[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]]], - [[[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, -1.0, 0.0]]], - [[[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [-1.0, 0.0, 0.0]]], - [[[0.0, 0.0, 0.0], [-1.0, 1.0, 0.0], [0.0, 0.0, 0.0]]], - [[[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]]], - [[[0.0, -1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]]], - [[[0.0, 0.0, -1.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]]], - ], - device=device, - dtype=dtype, - ) - - -def get_hysteresis_kernel(device=None, dtype=None): - """Utility function that returns the 3x3 kernels for the Canny hysteresis.""" - return torch.tensor( - [ - [[[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 0.0]]], - [[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]], - [[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 0.0]]], - [[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]], - [[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 0.0]]], - [[[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]], - [[[0.0, 1.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]], - [[[0.0, 0.0, 1.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]], - ], - device=device, - dtype=dtype, - ) - -def gaussian_blur_2d(img, kernel_size, sigma): - ksize_half = (kernel_size - 1) * 0.5 - - x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size) - - pdf = torch.exp(-0.5 * (x / sigma).pow(2)) - - x_kernel = pdf / pdf.sum() - x_kernel = x_kernel.to(device=img.device, dtype=img.dtype) - - kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :]) - kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1]) - - padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2] - - img = torch.nn.functional.pad(img, padding, mode="reflect") - img = torch.nn.functional.conv2d(img, kernel2d, groups=img.shape[-3]) - - return img - -def get_sobel_kernel2d(device=None, dtype=None): - kernel_x = torch.tensor([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]], device=device, dtype=dtype) - kernel_y = kernel_x.transpose(0, 1) - return torch.stack([kernel_x, kernel_y]) - -def spatial_gradient(input, normalized: bool = True): - r"""Compute the first order image derivative in both x and y using a Sobel operator. - .. image:: _static/img/spatial_gradient.png - Args: - input: input image tensor with shape :math:`(B, C, H, W)`. - mode: derivatives modality, can be: `sobel` or `diff`. - order: the order of the derivatives. - normalized: whether the output is normalized. - Return: - the derivatives of the input feature map. with shape :math:`(B, C, 2, H, W)`. - .. note:: - See a working example `here `__. - Examples: - >>> input = torch.rand(1, 3, 4, 4) - >>> output = spatial_gradient(input) # 1x3x2x4x4 - >>> output.shape - torch.Size([1, 3, 2, 4, 4]) - """ - # KORNIA_CHECK_IS_TENSOR(input) - # KORNIA_CHECK_SHAPE(input, ['B', 'C', 'H', 'W']) - - # allocate kernel - kernel = get_sobel_kernel2d(device=input.device, dtype=input.dtype) - if normalized: - kernel = normalize_kernel2d(kernel) - - # prepare kernel - b, c, h, w = input.shape - tmp_kernel = kernel[:, None, ...] - - # Pad with "replicate for spatial dims, but with zeros for channel - spatial_pad = [kernel.size(1) // 2, kernel.size(1) // 2, kernel.size(2) // 2, kernel.size(2) // 2] - out_channels: int = 2 - padded_inp = torch.nn.functional.pad(input.reshape(b * c, 1, h, w), spatial_pad, 'replicate') - out = F.conv2d(padded_inp, tmp_kernel, groups=1, padding=0, stride=1) - return out.reshape(b, c, out_channels, h, w) - -def rgb_to_grayscale(image, rgb_weights = None): - r"""Convert a RGB image to grayscale version of image. - - .. image:: _static/img/rgb_to_grayscale.png - - The image data is assumed to be in the range of (0, 1). - - Args: - image: RGB image to be converted to grayscale with shape :math:`(*,3,H,W)`. - rgb_weights: Weights that will be applied on each channel (RGB). - The sum of the weights should add up to one. - Returns: - grayscale version of the image with shape :math:`(*,1,H,W)`. - - .. note:: - See a working example `here `__. - - Example: - >>> input = torch.rand(2, 3, 4, 5) - >>> gray = rgb_to_grayscale(input) # 2x1x4x5 - """ - - if len(image.shape) < 3 or image.shape[-3] != 3: - raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}") - - if rgb_weights is None: - # 8 bit images - if image.dtype == torch.uint8: - rgb_weights = torch.tensor([76, 150, 29], device=image.device, dtype=torch.uint8) - # floating point images - elif image.dtype in (torch.float16, torch.float32, torch.float64): - rgb_weights = torch.tensor([0.299, 0.587, 0.114], device=image.device, dtype=image.dtype) - else: - raise TypeError(f"Unknown data type: {image.dtype}") - else: - # is tensor that we make sure is in the same device/dtype - rgb_weights = rgb_weights.to(image) - - # unpack the color image channels with RGB order - r: Tensor = image[..., 0:1, :, :] - g: Tensor = image[..., 1:2, :, :] - b: Tensor = image[..., 2:3, :, :] - - w_r, w_g, w_b = rgb_weights.unbind() - return w_r * r + w_g * g + w_b * b - -def canny( - input, - low_threshold = 0.1, - high_threshold = 0.2, - kernel_size = 5, - sigma = 1, - hysteresis = True, - eps = 1e-6, -): - r"""Find edges of the input image and filters them using the Canny algorithm. - .. image:: _static/img/canny.png - Args: - input: input image tensor with shape :math:`(B,C,H,W)`. - low_threshold: lower threshold for the hysteresis procedure. - high_threshold: upper threshold for the hysteresis procedure. - kernel_size: the size of the kernel for the gaussian blur. - sigma: the standard deviation of the kernel for the gaussian blur. - hysteresis: if True, applies the hysteresis edge tracking. - Otherwise, the edges are divided between weak (0.5) and strong (1) edges. - eps: regularization number to avoid NaN during backprop. - Returns: - - the canny edge magnitudes map, shape of :math:`(B,1,H,W)`. - - the canny edge detection filtered by thresholds and hysteresis, shape of :math:`(B,1,H,W)`. - .. note:: - See a working example `here `__. - Example: - >>> input = torch.rand(5, 3, 4, 4) - >>> magnitude, edges = canny(input) # 5x3x4x4 - >>> magnitude.shape - torch.Size([5, 1, 4, 4]) - >>> edges.shape - torch.Size([5, 1, 4, 4]) - """ - # KORNIA_CHECK_IS_TENSOR(input) - # KORNIA_CHECK_SHAPE(input, ['B', 'C', 'H', 'W']) - # KORNIA_CHECK( - # low_threshold <= high_threshold, - # "Invalid input thresholds. low_threshold should be smaller than the high_threshold. Got: " - # f"{low_threshold}>{high_threshold}", - # ) - # KORNIA_CHECK(0 < low_threshold < 1, f'Invalid low threshold. Should be in range (0, 1). Got: {low_threshold}') - # KORNIA_CHECK(0 < high_threshold < 1, f'Invalid high threshold. Should be in range (0, 1). Got: {high_threshold}') - - device = input.device - dtype = input.dtype - - # To Grayscale - if input.shape[1] == 3: - input = rgb_to_grayscale(input) - - # Gaussian filter - blurred: Tensor = gaussian_blur_2d(input, kernel_size, sigma) - - # Compute the gradients - gradients: Tensor = spatial_gradient(blurred, normalized=False) - - # Unpack the edges - gx: Tensor = gradients[:, :, 0] - gy: Tensor = gradients[:, :, 1] - - # Compute gradient magnitude and angle - magnitude: Tensor = torch.sqrt(gx * gx + gy * gy + eps) - angle: Tensor = torch.atan2(gy, gx) - - # Radians to Degrees - angle = 180.0 * angle / math.pi - - # Round angle to the nearest 45 degree - angle = torch.round(angle / 45) * 45 - - # Non-maximal suppression - nms_kernels: Tensor = get_canny_nms_kernel(device, dtype) - nms_magnitude: Tensor = F.conv2d(magnitude, nms_kernels, padding=nms_kernels.shape[-1] // 2) - - # Get the indices for both directions - positive_idx: Tensor = (angle / 45) % 8 - positive_idx = positive_idx.long() - - negative_idx: Tensor = ((angle / 45) + 4) % 8 - negative_idx = negative_idx.long() - - # Apply the non-maximum suppression to the different directions - channel_select_filtered_positive: Tensor = torch.gather(nms_magnitude, 1, positive_idx) - channel_select_filtered_negative: Tensor = torch.gather(nms_magnitude, 1, negative_idx) - - channel_select_filtered: Tensor = torch.stack( - [channel_select_filtered_positive, channel_select_filtered_negative], 1 - ) - - is_max: Tensor = channel_select_filtered.min(dim=1)[0] > 0.0 - - magnitude = magnitude * is_max - - # Threshold - edges: Tensor = F.threshold(magnitude, low_threshold, 0.0) - - low: Tensor = magnitude > low_threshold - high: Tensor = magnitude > high_threshold - - edges = low * 0.5 + high * 0.5 - edges = edges.to(dtype) - - # Hysteresis - if hysteresis: - edges_old: Tensor = -torch.ones(edges.shape, device=edges.device, dtype=dtype) - hysteresis_kernels: Tensor = get_hysteresis_kernel(device, dtype) - - while ((edges_old - edges).abs() != 0).any(): - weak: Tensor = (edges == 0.5).float() - strong: Tensor = (edges == 1).float() - - hysteresis_magnitude: Tensor = F.conv2d( - edges, hysteresis_kernels, padding=hysteresis_kernels.shape[-1] // 2 - ) - hysteresis_magnitude = (hysteresis_magnitude == 1).any(1, keepdim=True).to(dtype) - hysteresis_magnitude = hysteresis_magnitude * weak + strong - - edges_old = edges.clone() - edges = hysteresis_magnitude + (hysteresis_magnitude == 0) * weak * 0.5 - - edges = hysteresis_magnitude - - return magnitude, edges +from kornia.filters import canny class Canny: diff --git a/comfy_extras/nodes_morphology.py b/comfy_extras/nodes_morphology.py new file mode 100644 index 00000000..071521d8 --- /dev/null +++ b/comfy_extras/nodes_morphology.py @@ -0,0 +1,49 @@ +import torch +import comfy.model_management + +from kornia.morphology import dilation, erosion, opening, closing, gradient, top_hat, bottom_hat + + +class Morphology: + @classmethod + def INPUT_TYPES(s): + return {"required": {"image": ("IMAGE",), + "operation": (["erode", "dilate", "open", "close", "gradient", "bottom_hat", "top_hat"],), + "kernel_size": ("INT", {"default": 3, "min": 3, "max": 999, "step": 1}), + }} + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "process" + + CATEGORY = "image/postprocessing" + + def process(self, image, operation, kernel_size): + device = comfy.model_management.get_torch_device() + kernel = torch.ones(kernel_size, kernel_size, device=device) + image_k = image.to(device).movedim(-1, 1) + if operation == "erode": + output = erosion(image_k, kernel) + elif operation == "dilate": + output = dilation(image_k, kernel) + elif operation == "open": + output = opening(image_k, kernel) + elif operation == "close": + output = closing(image_k, kernel) + elif operation == "gradient": + output = gradient(image_k, kernel) + elif operation == "top_hat": + output = top_hat(image_k, kernel) + elif operation == "bottom_hat": + output = bottom_hat(image_k, kernel) + else: + raise ValueError(f"Invalid operation {operation} for morphology. Must be one of 'erode', 'dilate', 'open', 'close', 'gradient', 'tophat', 'bottomhat'") + img_out = output.to(comfy.model_management.intermediate_device()).movedim(1, -1) + return (img_out,) + +NODE_CLASS_MAPPINGS = { + "Morphology": Morphology, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "Morphology": "ImageMorphology", +} \ No newline at end of file diff --git a/nodes.py b/nodes.py index 2cfea1a7..d8a4e822 100644 --- a/nodes.py +++ b/nodes.py @@ -1960,6 +1960,7 @@ def init_custom_nodes(): "nodes_sdupscale.py", "nodes_photomaker.py", "nodes_cond.py", + "nodes_morphology.py", "nodes_stable_cascade.py", "nodes_differential_diffusion.py", ] diff --git a/requirements.txt b/requirements.txt index e804618e..e7d8c0e9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,3 +10,4 @@ Pillow scipy tqdm psutil +kornia>=0.7.1