From 1ee7a52f3b996145df33f62c6cd48680966ffa55 Mon Sep 17 00:00:00 2001 From: Denys Smirnov Date: Mon, 6 May 2024 12:25:54 +0300 Subject: [PATCH] Fix alpha in PorterDuffImageComposite. There were two bugs in PorterDuffImageComposite. The first one is the fact that it uses the mask input directly as alpha, missing the conversion (`1-a`). The fix is similar to c16f5744. The second one is that all color composition formulas assume alpha premultiplied values, while the input is not premultiplied. This change fixes both of these issue. --- comfy_extras/nodes_compositing.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/comfy_extras/nodes_compositing.py b/comfy_extras/nodes_compositing.py index 181b36ed..56d1ff77 100644 --- a/comfy_extras/nodes_compositing.py +++ b/comfy_extras/nodes_compositing.py @@ -28,6 +28,14 @@ class PorterDuffMode(Enum): def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_image: torch.Tensor, dst_alpha: torch.Tensor, mode: PorterDuffMode): + # convert mask to alpha + src_alpha = 1 - src_alpha + dst_alpha = 1 - dst_alpha + # premultiply alpha + src_image = src_image * src_alpha + dst_image = dst_image * dst_alpha + + # composite ops below assume alpha-premultiplied images if mode == PorterDuffMode.ADD: out_alpha = torch.clamp(src_alpha + dst_alpha, 0, 1) out_image = torch.clamp(src_image + dst_image, 0, 1) @@ -35,7 +43,7 @@ def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_ out_alpha = torch.zeros_like(dst_alpha) out_image = torch.zeros_like(dst_image) elif mode == PorterDuffMode.DARKEN: - out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha + out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha out_image = (1 - dst_alpha) * src_image + (1 - src_alpha) * dst_image + torch.min(src_image, dst_image) elif mode == PorterDuffMode.DST: out_alpha = dst_alpha @@ -84,8 +92,13 @@ def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_ out_alpha = (1 - dst_alpha) * src_alpha + (1 - src_alpha) * dst_alpha out_image = (1 - dst_alpha) * src_image + (1 - src_alpha) * dst_image else: - out_alpha = None - out_image = None + return None, None + + # back to non-premultiplied alpha + out_image = torch.where(out_alpha > 0, out_image / out_alpha, torch.zeros_like(out_image)) + out_image = torch.clamp(out_image, 0, 1) + # convert alpha to mask + out_alpha = 1 - out_alpha return out_image, out_alpha