From 4c7a9dbcb66d3a53764d4725f92f7c116bcb4821 Mon Sep 17 00:00:00 2001 From: EllangoK Date: Sun, 2 Apr 2023 18:44:27 -0400 Subject: [PATCH 1/3] adds Blend, Blur, Dither, Sharpen nodes --- comfy_extras/nodes_post_processing.py | 215 ++++++++++++++++++++++++++ nodes.py | 3 +- 2 files changed, 217 insertions(+), 1 deletion(-) create mode 100644 comfy_extras/nodes_post_processing.py diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py new file mode 100644 index 00000000..3f3bddd7 --- /dev/null +++ b/comfy_extras/nodes_post_processing.py @@ -0,0 +1,215 @@ +import torch +import torch.nn.functional as F + + +class Blend: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image1": ("IMAGE",), + "image2": ("IMAGE",), + "blend_factor": ("FLOAT", { + "default": 0.5, + "min": 0.0, + "max": 1.0, + "step": 0.01 + }), + "blend_mode": (["normal", "multiply", "screen", "overlay", "soft_light"],), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "blend_images" + + CATEGORY = "postprocessing" + + def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str): + blended_image = self.blend_mode(image1, image2, blend_mode) + blended_image = image1 * (1 - blend_factor) + blended_image * blend_factor + blended_image = torch.clamp(blended_image, 0, 1) + return (blended_image,) + + def blend_mode(self, img1, img2, mode): + if mode == "normal": + return img2 + elif mode == "multiply": + return img1 * img2 + elif mode == "screen": + return 1 - (1 - img1) * (1 - img2) + elif mode == "overlay": + return torch.where(img1 <= 0.5, 2 * img1 * img2, 1 - 2 * (1 - img1) * (1 - img2)) + elif mode == "soft_light": + return torch.where(img2 <= 0.5, img1 - (1 - 2 * img2) * img1 * (1 - img1), img1 + (2 * img2 - 1) * (self.g(img1) - img1)) + else: + raise ValueError(f"Unsupported blend mode: {mode}") + + def g(self, x): + return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x)) + +class Blur: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE",), + "blur_radius": ("INT", { + "default": 1, + "min": 1, + "max": 31, + "step": 1 + }), + "sigma": ("FLOAT", { + "default": 1.0, + "min": 0.1, + "max": 10.0, + "step": 0.1 + }), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "blur" + + CATEGORY = "postprocessing" + + def gaussian_kernel(self, kernel_size: int, sigma: float): + x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size), torch.linspace(-1, 1, kernel_size), indexing="ij") + d = torch.sqrt(x * x + y * y) + g = torch.exp(-(d * d) / (2.0 * sigma * sigma)) + return g / g.sum() + + def blur(self, image: torch.Tensor, blur_radius: int, sigma: float): + if blur_radius == 0: + return (image,) + + batch_size, height, width, channels = image.shape + + kernel_size = blur_radius * 2 + 1 + kernel = self.gaussian_kernel(kernel_size, sigma).repeat(channels, 1, 1).unsqueeze(1) + + image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C) + blurred = F.conv2d(image, kernel, padding=kernel_size // 2, groups=channels) + blurred = blurred.permute(0, 2, 3, 1) + + return (blurred,) + +class Dither: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE",), + "bits": ("INT", { + "default": 4, + "min": 1, + "max": 8, + "step": 1 + }), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "dither" + + CATEGORY = "postprocessing" + + def dither(self, image: 
torch.Tensor, bits: int): + batch_size, height, width, _ = image.shape + result = torch.zeros_like(image) + + for b in range(batch_size): + tensor_image = image[b] + img = (tensor_image * 255) + height, width, _ = img.shape + + scale = 255 / (2**bits - 1) + + for y in range(height): + for x in range(width): + old_pixel = img[y, x].clone() + new_pixel = torch.round(old_pixel / scale) * scale + img[y, x] = new_pixel + + quant_error = old_pixel - new_pixel + + if x + 1 < width: + img[y, x + 1] += quant_error * 7 / 16 + if y + 1 < height: + if x - 1 >= 0: + img[y + 1, x - 1] += quant_error * 3 / 16 + img[y + 1, x] += quant_error * 5 / 16 + if x + 1 < width: + img[y + 1, x + 1] += quant_error * 1 / 16 + + dithered = img / 255 + tensor = dithered.unsqueeze(0) + result[b] = tensor + + return (result,) + +class Sharpen: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE",), + "sharpen_radius": ("INT", { + "default": 1, + "min": 1, + "max": 31, + "step": 1 + }), + "alpha": ("FLOAT", { + "default": 1.0, + "min": 0.1, + "max": 5.0, + "step": 0.1 + }), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "sharpen" + + CATEGORY = "postprocessing" + + def sharpen(self, image: torch.Tensor, sharpen_radius: int, alpha: float): + if sharpen_radius == 0: + return (image,) + + batch_size, height, width, channels = image.shape + + kernel_size = sharpen_radius * 2 + 1 + kernel = torch.ones((kernel_size, kernel_size), dtype=torch.float32) * -1 + center = kernel_size // 2 + kernel[center, center] = kernel_size**2 + kernel *= alpha + kernel = kernel.repeat(channels, 1, 1).unsqueeze(1) + + tensor_image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C) + sharpened = F.conv2d(tensor_image, kernel, padding=center, groups=channels) + sharpened = sharpened.permute(0, 2, 3, 1) + + result = torch.clamp(sharpened, 0, 1) + + return (result,) + +NODE_CLASS_MAPPINGS = { + "Blend": Blend, + "Blur": Blur, + "Dither": Dither, + "Sharpen": Sharpen, +} diff --git a/nodes.py b/nodes.py index 963ff32a..a93f0410 100644 --- a/nodes.py +++ b/nodes.py @@ -1112,4 +1112,5 @@ def load_custom_nodes(): def init_custom_nodes(): load_custom_nodes() - load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py")) \ No newline at end of file + load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py")) + load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py")) From fa2febc0624678362cc758d316bb59afce9c8f06 Mon Sep 17 00:00:00 2001 From: EllangoK Date: Mon, 3 Apr 2023 09:52:04 -0400 Subject: [PATCH 2/3] blend supports any size, dither -> quantize --- comfy_extras/nodes_post_processing.py | 74 ++++++++++++++++----------- 1 file changed, 44 insertions(+), 30 deletions(-) diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index 3f3bddd7..322f3ca8 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -1,5 +1,7 @@ +import numpy as np import torch import torch.nn.functional as F +from PIL import Image class Blend: @@ -28,6 +30,9 @@ class Blend: CATEGORY = "postprocessing" def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str): + if image1.shape != image2.shape: + image2 = self.crop_and_resize(image2, image1.shape) + 
        blended_image = self.blend_mode(image1, image2, blend_mode)
         blended_image = image1 * (1 - blend_factor) + blended_image * blend_factor
         blended_image = torch.clamp(blended_image, 0, 1)
@@ -50,6 +55,29 @@ class Blend:
     def g(self, x):
         return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))
 
+    def crop_and_resize(self, img: torch.Tensor, target_shape: tuple):
+        batch_size, img_h, img_w, img_c = img.shape
+        _, target_h, target_w, _ = target_shape
+        img_aspect_ratio = img_w / img_h
+        target_aspect_ratio = target_w / target_h
+
+        # Crop center of the image to the target aspect ratio
+        if img_aspect_ratio > target_aspect_ratio:
+            new_width = int(img_h * target_aspect_ratio)
+            left = (img_w - new_width) // 2
+            img = img[:, :, left:left + new_width, :]
+        else:
+            new_height = int(img_w / target_aspect_ratio)
+            top = (img_h - new_height) // 2
+            img = img[:, top:top + new_height, :, :]
+
+        # Resize to target size
+        img = img.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
+        img = F.interpolate(img, size=(target_h, target_w), mode='bilinear', align_corners=False)
+        img = img.permute(0, 2, 3, 1)
+
+        return img
+
 class Blur:
     def __init__(self):
         pass
@@ -100,7 +128,7 @@ class Blur:
 
         return (blurred,)
 
-class Dither:
+class Quantize:
     def __init__(self):
         pass
 
@@ -109,51 +137,37 @@ class Dither:
         return {
             "required": {
                 "image": ("IMAGE",),
-                "bits": ("INT", {
-                    "default": 4,
+                "colors": ("INT", {
+                    "default": 256,
                     "min": 1,
-                    "max": 8,
+                    "max": 256,
                     "step": 1
                 }),
+                "dither": (["none", "floyd-steinberg"],),
            },
         }
 
     RETURN_TYPES = ("IMAGE",)
-    FUNCTION = "dither"
+    FUNCTION = "quantize"
 
     CATEGORY = "postprocessing"
 
-    def dither(self, image: torch.Tensor, bits: int):
+    def quantize(self, image: torch.Tensor, colors: int = 256, dither: str = "FLOYDSTEINBERG"):
         batch_size, height, width, _ = image.shape
         result = torch.zeros_like(image)
 
+        dither_option = Image.Dither.FLOYDSTEINBERG if dither == "floyd-steinberg" else Image.Dither.NONE
+
         for b in range(batch_size):
             tensor_image = image[b]
-            img = (tensor_image * 255)
-            height, width, _ = img.shape
-
-            scale = 255 / (2**bits - 1)
-
-            for y in range(height):
-                for x in range(width):
-                    old_pixel = img[y, x].clone()
-                    new_pixel = torch.round(old_pixel / scale) * scale
-                    img[y, x] = new_pixel
-
-                    quant_error = old_pixel - new_pixel
+            img = (tensor_image * 255).to(torch.uint8).numpy()
+            pil_image = Image.fromarray(img, mode='RGB')
 
-                    if x + 1 < width:
-                        img[y, x + 1] += quant_error * 7 / 16
-                    if y + 1 < height:
-                        if x - 1 >= 0:
-                            img[y + 1, x - 1] += quant_error * 3 / 16
-                        img[y + 1, x] += quant_error * 5 / 16
-                        if x + 1 < width:
-                            img[y + 1, x + 1] += quant_error * 1 / 16
+            palette = pil_image.quantize(colors=colors) # Required as described in https://github.com/python-pillow/Pillow/issues/5836
+            quantized_image = pil_image.quantize(colors=colors, palette=palette, dither=dither_option)
 
-            dithered = img / 255
-            tensor = dithered.unsqueeze(0)
-            result[b] = tensor
+            quantized_array = torch.tensor(np.array(quantized_image.convert("RGB"))).float() / 255
+            result[b] = quantized_array
 
         return (result,)
 
@@ -210,6 +224,6 @@ class Sharpen:
 NODE_CLASS_MAPPINGS = {
     "Blend": Blend,
     "Blur": Blur,
-    "Dither": Dither,
+    "Quantize": Quantize,
     "Sharpen": Sharpen,
 }

From 56196ab0f72c8f671bd85b425744f80f02c823ea Mon Sep 17 00:00:00 2001
From: EllangoK
Date: Tue, 4 Apr 2023 10:57:34 -0400
Subject: [PATCH 3/3] use common_upscale in blend

---
 comfy_extras/nodes_post_processing.py | 29 +++++----------------------
 1 file changed, 5 insertions(+), 24 deletions(-)

diff 
--git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index 322f3ca8..703deaab 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -3,6 +3,8 @@ import torch import torch.nn.functional as F from PIL import Image +import comfy.utils + class Blend: def __init__(self): @@ -31,7 +33,9 @@ class Blend: def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str): if image1.shape != image2.shape: - image2 = self.crop_and_resize(image2, image1.shape) + image2 = image2.permute(0, 3, 1, 2) + image2 = comfy.utils.common_upscale(image2, image1.shape[2], image1.shape[1], upscale_method='bicubic', crop='center') + image2 = image2.permute(0, 2, 3, 1) blended_image = self.blend_mode(image1, image2, blend_mode) blended_image = image1 * (1 - blend_factor) + blended_image * blend_factor @@ -55,29 +59,6 @@ class Blend: def g(self, x): return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x)) - def crop_and_resize(self, img: torch.Tensor, target_shape: tuple): - batch_size, img_h, img_w, img_c = img.shape - _, target_h, target_w, _ = target_shape - img_aspect_ratio = img_w / img_h - target_aspect_ratio = target_w / target_h - - # Crop center of the image to the target aspect ratio - if img_aspect_ratio > target_aspect_ratio: - new_width = int(img_h * target_aspect_ratio) - left = (img_w - new_width) // 2 - img = img[:, :, left:left + new_width, :] - else: - new_height = int(img_w / target_aspect_ratio) - top = (img_h - new_height) // 2 - img = img[:, top:top + new_height, :, :] - - # Resize to target size - img = img.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C) - img = F.interpolate(img, size=(target_h, target_w), mode='bilinear', align_corners=False) - img = img.permute(0, 2, 3, 1) - - return img - class Blur: def __init__(self): pass
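
The nodes added in this series can also be exercised directly, outside of a ComfyUI graph. The snippet below is a minimal standalone sketch, not part of the patches: it assumes it is run from the ComfyUI repository root (so that comfy.utils and comfy_extras/nodes_post_processing.py are importable) and that IMAGE tensors are float32 in (B, H, W, C) layout with values in [0, 1], which is what the code above expects. ComfyUI itself drives these classes through the graph executor via each class's FUNCTION attribute; the direct method calls here are for illustration only.

    import torch

    from comfy_extras.nodes_post_processing import Blend, Blur, Quantize, Sharpen

    # Two fake IMAGE batches: float32, (B, H, W, C), values in [0, 1].
    # image2 deliberately has a different resolution to exercise the
    # comfy.utils.common_upscale path added in the final patch.
    image1 = torch.rand(1, 64, 64, 3)
    image2 = torch.rand(1, 128, 96, 3)

    (blended,) = Blend().blend_images(image1, image2, blend_factor=0.5, blend_mode="multiply")
    (blurred,) = Blur().blur(blended, blur_radius=3, sigma=1.5)
    (quantized,) = Quantize().quantize(blurred, colors=16, dither="floyd-steinberg")
    (sharpened,) = Sharpen().sharpen(quantized, sharpen_radius=1, alpha=1.0)

    print(sharpened.shape)  # torch.Size([1, 64, 64, 3])

Each call returns a one-element tuple because node outputs always mirror RETURN_TYPES, which is why the results are unpacked with (name,) = ... above; the blend_mode and dither strings are the option names declared in the INPUT_TYPES of the respective classes.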