From 22edd3add541cdb956b50a3c6412ce2bb36c1090 Mon Sep 17 00:00:00 2001 From: shawnington <88048838+shawnington@users.noreply.github.com> Date: Sun, 12 May 2024 04:07:38 -0700 Subject: [PATCH] =?UTF-8?q?Fix=20to=20LoadImage=20Node=20for=20#3416=20HDR?= =?UTF-8?q?=20images=20loading=20additional=20smaller=E2=80=A6=20(#3454)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix to LoadImage Node for #3416 HDR images loading additional smaller images. Added a blocking if statement in the ImageSequence.Iterator that checks if subsequent images after the first match dimensionally, and prevent them from being appended to output_images if they do not match. This does not fix or change current behavior for PIL 10.2.0 where the images are loaded at the same size, but it does for 10.3.0 where they are loaded at their correct smaller sizes. * added list of excluded formats that should return 1 image added an explicit check for the image format so that additional formats can be added to the list that have problematic behavior. --- nodes.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/nodes.py b/nodes.py index 488afd57..37e7d734 100644 --- a/nodes.py +++ b/nodes.py @@ -1461,12 +1461,24 @@ class LoadImage: output_images = [] output_masks = [] + w, h = None, None + + excluded_formats = ['MPO'] + for i in ImageSequence.Iterator(img): i = node_helpers.pillow(ImageOps.exif_transpose, i) if i.mode == 'I': i = i.point(lambda i: i * (1 / 255)) image = i.convert("RGB") + + if len(output_images) == 0: + w = image.size[0] + h = image.size[1] + + if image.size[0] != w or image.size[1] != h: + continue + image = np.array(image).astype(np.float32) / 255.0 image = torch.from_numpy(image)[None,] if 'A' in i.getbands(): @@ -1477,7 +1489,7 @@ class LoadImage: output_images.append(image) output_masks.append(mask.unsqueeze(0)) - if len(output_images) > 1: + if len(output_images) > 1 and img.format not in excluded_formats: output_image = torch.cat(output_images, dim=0) output_mask = torch.cat(output_masks, dim=0) else: