diff --git a/nodes.py b/nodes.py index 488afd57..37e7d734 100644 --- a/nodes.py +++ b/nodes.py @@ -1461,12 +1461,24 @@ class LoadImage: output_images = [] output_masks = [] + w, h = None, None + + excluded_formats = ['MPO'] + for i in ImageSequence.Iterator(img): i = node_helpers.pillow(ImageOps.exif_transpose, i) if i.mode == 'I': i = i.point(lambda i: i * (1 / 255)) image = i.convert("RGB") + + if len(output_images) == 0: + w = image.size[0] + h = image.size[1] + + if image.size[0] != w or image.size[1] != h: + continue + image = np.array(image).astype(np.float32) / 255.0 image = torch.from_numpy(image)[None,] if 'A' in i.getbands(): @@ -1477,7 +1489,7 @@ class LoadImage: output_images.append(image) output_masks.append(mask.unsqueeze(0)) - if len(output_images) > 1: + if len(output_images) > 1 and img.format not in excluded_formats: output_image = torch.cat(output_images, dim=0) output_mask = torch.cat(output_masks, dim=0) else: