diff --git a/node_helpers.py b/node_helpers.py index 60f8fa41..43b9e829 100644 --- a/node_helpers.py +++ b/node_helpers.py @@ -1,4 +1,4 @@ -from PIL import Image, ImageFile, UnidentifiedImageError +from PIL import ImageFile, UnidentifiedImageError def conditioning_set_values(conditioning, values={}): c = [] @@ -10,16 +10,15 @@ def conditioning_set_values(conditioning, values={}): return c -def open_image(path): +def pillow(fn, arg): prev_value = None - try: - img = Image.open(path) - except (UnidentifiedImageError, ValueError): #PIL issues #4472 and #2445 + x = fn(arg) + except (OSError, UnidentifiedImageError, ValueError): #PIL issues #4472 and #2445, also fixes ComfyUI issue #3416 prev_value = ImageFile.LOAD_TRUNCATED_IMAGES ImageFile.LOAD_TRUNCATED_IMAGES = True - img = Image.open(path) + x = fn(arg) finally: if prev_value is not None: ImageFile.LOAD_TRUNCATED_IMAGES = prev_value - return img + return x diff --git a/nodes.py b/nodes.py index 4d3171b8..488afd57 100644 --- a/nodes.py +++ b/nodes.py @@ -1457,21 +1457,12 @@ class LoadImage: def load_image(self, image): image_path = folder_paths.get_annotated_filepath(image) - img = node_helpers.open_image(image_path) + img = node_helpers.pillow(Image.open, image_path) output_images = [] output_masks = [] for i in ImageSequence.Iterator(img): - prev_value = None - try: - i = ImageOps.exif_transpose(i) - except OSError: - prev_value = ImageFile.LOAD_TRUNCATED_IMAGES - ImageFile.LOAD_TRUNCATED_IMAGES = True - i = ImageOps.exif_transpose(i) - finally: - if prev_value is not None: - ImageFile.LOAD_TRUNCATED_IMAGES = prev_value + i = node_helpers.pillow(ImageOps.exif_transpose, i) if i.mode == 'I': i = i.point(lambda i: i * (1 / 255)) @@ -1527,8 +1518,8 @@ class LoadImageMask: FUNCTION = "load_image" def load_image(self, image, channel): image_path = folder_paths.get_annotated_filepath(image) - i = Image.open(image_path) - i = ImageOps.exif_transpose(i) + i = node_helpers.pillow(Image.open, image_path) + i = node_helpers.pillow(ImageOps.exif_transpose, i) if i.getbands() != ("R", "G", "B", "A"): if i.mode == 'I': i = i.point(lambda i: i * (1 / 255))