diff --git a/comfy_extras/nodes_compositing.py b/comfy_extras/nodes_compositing.py new file mode 100644 index 00000000..c4c58b64 --- /dev/null +++ b/comfy_extras/nodes_compositing.py @@ -0,0 +1,239 @@ +import numpy as np +import torch +import comfy.utils +from enum import Enum + + +class PorterDuffMode(Enum): + ADD = 0 + CLEAR = 1 + DARKEN = 2 + DST = 3 + DST_ATOP = 4 + DST_IN = 5 + DST_OUT = 6 + DST_OVER = 7 + LIGHTEN = 8 + MULTIPLY = 9 + OVERLAY = 10 + SCREEN = 11 + SRC = 12 + SRC_ATOP = 13 + SRC_IN = 14 + SRC_OUT = 15 + SRC_OVER = 16 + XOR = 17 + + +def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_image: torch.Tensor, dst_alpha: torch.Tensor, mode: PorterDuffMode): + if mode == PorterDuffMode.ADD: + out_alpha = torch.clamp(src_alpha + dst_alpha, 0, 1) + out_image = torch.clamp(src_image + dst_image, 0, 1) + elif mode == PorterDuffMode.CLEAR: + out_alpha = torch.zeros_like(dst_alpha) + out_image = torch.zeros_like(dst_image) + elif mode == PorterDuffMode.DARKEN: + out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha + out_image = (1 - dst_alpha) * src_image + (1 - src_alpha) * dst_image + torch.min(src_image, dst_image) + elif mode == PorterDuffMode.DST: + out_alpha = dst_alpha + out_image = dst_image + elif mode == PorterDuffMode.DST_ATOP: + out_alpha = src_alpha + out_image = src_alpha * dst_image + (1 - dst_alpha) * src_image + elif mode == PorterDuffMode.DST_IN: + out_alpha = src_alpha * dst_alpha + out_image = dst_image * src_alpha + elif mode == PorterDuffMode.DST_OUT: + out_alpha = (1 - src_alpha) * dst_alpha + out_image = (1 - src_alpha) * dst_image + elif mode == PorterDuffMode.DST_OVER: + out_alpha = dst_alpha + (1 - dst_alpha) * src_alpha + out_image = dst_image + (1 - dst_alpha) * src_image + elif mode == PorterDuffMode.LIGHTEN: + out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha + out_image = (1 - dst_alpha) * src_image + (1 - src_alpha) * dst_image + torch.max(src_image, dst_image) + elif mode == PorterDuffMode.MULTIPLY: + out_alpha = src_alpha * dst_alpha + out_image = src_image * dst_image + elif mode == PorterDuffMode.OVERLAY: + out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha + out_image = torch.where(2 * dst_image < dst_alpha, 2 * src_image * dst_image, + src_alpha * dst_alpha - 2 * (dst_alpha - src_image) * (src_alpha - dst_image)) + elif mode == PorterDuffMode.SCREEN: + out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha + out_image = src_image + dst_image - src_image * dst_image + elif mode == PorterDuffMode.SRC: + out_alpha = src_alpha + out_image = src_image + elif mode == PorterDuffMode.SRC_ATOP: + out_alpha = dst_alpha + out_image = dst_alpha * src_image + (1 - src_alpha) * dst_image + elif mode == PorterDuffMode.SRC_IN: + out_alpha = src_alpha * dst_alpha + out_image = src_image * dst_alpha + elif mode == PorterDuffMode.SRC_OUT: + out_alpha = (1 - dst_alpha) * src_alpha + out_image = (1 - dst_alpha) * src_image + elif mode == PorterDuffMode.SRC_OVER: + out_alpha = src_alpha + (1 - src_alpha) * dst_alpha + out_image = src_image + (1 - src_alpha) * dst_image + elif mode == PorterDuffMode.XOR: + out_alpha = (1 - dst_alpha) * src_alpha + (1 - src_alpha) * dst_alpha + out_image = (1 - dst_alpha) * src_image + (1 - src_alpha) * dst_image + else: + out_alpha = None + out_image = None + return out_image, out_alpha + + +class PorterDuffImageComposite: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "source": ("IMAGE",), + "source_alpha": ("ALPHA",), + "destination": ("IMAGE",), + "destination_alpha": ("ALPHA",), + "mode": ([mode.name for mode in PorterDuffMode], {"default": PorterDuffMode.DST.name}), + }, + } + + RETURN_TYPES = ("IMAGE", "ALPHA") + FUNCTION = "composite" + CATEGORY = "compositing" + + def composite(self, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode): + batch_size = min(len(source), len(source_alpha), len(destination), len(destination_alpha)) + out_images = [] + out_alphas = [] + + for i in range(batch_size): + src_image = source[i] + dst_image = destination[i] + + src_alpha = source_alpha[i].unsqueeze(2) + dst_alpha = destination_alpha[i].unsqueeze(2) + + if dst_alpha.shape != dst_image.shape: + upscale_input = dst_alpha[None,:,:,:].permute(0, 3, 1, 2) + upscale_output = comfy.utils.common_upscale(upscale_input, dst_image.shape[1], dst_image.shape[0], upscale_method='bicubic', crop='center') + dst_alpha = upscale_output.permute(0, 2, 3, 1).squeeze(0) + if src_image.shape != dst_image.shape: + upscale_input = src_image[None,:,:,:].permute(0, 3, 1, 2) + upscale_output = comfy.utils.common_upscale(upscale_input, dst_image.shape[1], dst_image.shape[0], upscale_method='bicubic', crop='center') + src_image = upscale_output.permute(0, 2, 3, 1).squeeze(0) + if src_alpha.shape != dst_alpha.shape: + upscale_input = src_alpha[None,:,:,:].permute(0, 3, 1, 2) + upscale_output = comfy.utils.common_upscale(upscale_input, dst_alpha.shape[1], dst_alpha.shape[0], upscale_method='bicubic', crop='center') + src_alpha = upscale_output.permute(0, 2, 3, 1).squeeze(0) + + out_image, out_alpha = porter_duff_composite(src_image, src_alpha, dst_image, dst_alpha, PorterDuffMode[mode]) + + out_images.append(out_image) + out_alphas.append(out_alpha.squeeze(2)) + + result = (torch.stack(out_images), torch.stack(out_alphas)) + return result + + +class SplitImageWithAlpha: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE",), + } + } + + CATEGORY = "compositing" + RETURN_TYPES = ("IMAGE", "ALPHA") + FUNCTION = "split_image_with_alpha" + + def split_image_with_alpha(self, image: torch.Tensor): + out_images = [i[:,:,:3] for i in image] + out_alphas = [i[:,:,3] for i in image] + result = (torch.stack(out_images), torch.stack(out_alphas)) + return result + + +class JoinImageWithAlpha: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE",), + "alpha": ("ALPHA",), + } + } + + CATEGORY = "compositing" + RETURN_TYPES = ("IMAGE",) + FUNCTION = "join_image_with_alpha" + + def join_image_with_alpha(self, image: torch.Tensor, alpha: torch.Tensor): + batch_size = min(len(image), len(alpha)) + out_images = [] + + for i in range(batch_size): + out_images.append(torch.cat((image[i], alpha[i].unsqueeze(2)), dim=2)) + + result = (torch.stack(out_images),) + return result + + +class ConvertAlphaToImage: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "alpha": ("ALPHA",), + } + } + + CATEGORY = "compositing" + RETURN_TYPES = ("IMAGE",) + FUNCTION = "alpha_to_image" + + def alpha_to_image(self, alpha): + result = alpha.reshape((-1, 1, alpha.shape[-2], alpha.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3) + return (result,) + + +class ConvertImageToAlpha: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE",), + "channel": (["red", "green", "blue", "alpha"],), + } + } + + CATEGORY = "compositing" + RETURN_TYPES = ("ALPHA",) + FUNCTION = "image_to_alpha" + + def image_to_alpha(self, image, channel): + channels = ["red", "green", "blue", "alpha"] + alpha = image[0, :, :, channels.index(channel)] + return (alpha,) + + +NODE_CLASS_MAPPINGS = { + "PorterDuffImageComposite": PorterDuffImageComposite, + "SplitImageWithAlpha": SplitImageWithAlpha, + "JoinImageWithAlpha": JoinImageWithAlpha, + "ConvertAlphaToImage": ConvertAlphaToImage, + "ConvertImageToAlpha": ConvertImageToAlpha, +} + + +NODE_DISPLAY_NAME_MAPPINGS = { + "PorterDuffImageComposite": "Porter-Duff Image Composite", + "SplitImageWithAlpha": "Split Image with Alpha", + "JoinImageWithAlpha": "Join Image with Alpha", + "ConvertAlphaToImage": "Convert Alpha to Image", + "ConvertImageToAlpha": "Convert Image to Alpha", +} diff --git a/nodes.py b/nodes.py index 919aac89..8be332f9 100644 --- a/nodes.py +++ b/nodes.py @@ -1372,6 +1372,31 @@ class LoadImage: return True +class LoadImageWithAlpha(LoadImage): + @classmethod + def INPUT_TYPES(s): + input_dir = folder_paths.get_input_directory() + files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] + return {"required": + {"image": (sorted(files), {"image_upload": True})}, + } + + CATEGORY = "compositing" + + RETURN_TYPES = ("IMAGE", "ALPHA") + + FUNCTION = "load_image" + def load_image(self, image): + image_path = folder_paths.get_annotated_filepath(image) + i = Image.open(image_path) + i = ImageOps.exif_transpose(i) + image = i.convert("RGBA") + alpha = np.array(image.getchannel("A")).astype(np.float32) / 255.0 + alpha = torch.from_numpy(alpha)[None,] + image = np.array(image).astype(np.float32) / 255.0 + image = torch.from_numpy(image)[None,] + return (image, alpha) + class LoadImageMask: _color_channels = ["alpha", "red", "green", "blue"] @classmethod @@ -1606,6 +1631,7 @@ NODE_CLASS_MAPPINGS = { "SaveImage": SaveImage, "PreviewImage": PreviewImage, "LoadImage": LoadImage, + "LoadImageWithAlpha": LoadImageWithAlpha, "LoadImageMask": LoadImageMask, "ImageScale": ImageScale, "ImageScaleBy": ImageScaleBy, @@ -1702,6 +1728,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { "SaveImage": "Save Image", "PreviewImage": "Preview Image", "LoadImage": "Load Image", + "LoadImageWithAlpha": "Load Image with Alpha", "LoadImageMask": "Load Image (as Mask)", "ImageScale": "Upscale Image", "ImageScaleBy": "Upscale Image By", @@ -1788,6 +1815,7 @@ def init_custom_nodes(): "nodes_upscale_model.py", "nodes_post_processing.py", "nodes_mask.py", + "nodes_compositing.py", "nodes_rebatch.py", "nodes_model_merging.py", "nodes_tomesd.py",