diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index a138b292..51bdb24f 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -2,6 +2,7 @@ import numpy as np import torch import torch.nn.functional as F from PIL import Image +import math import comfy.utils @@ -209,9 +210,36 @@ class Sharpen: return (result,) +class ImageScaleToTotalPixels: + upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic"] + crop_methods = ["disabled", "center"] + + @classmethod + def INPUT_TYPES(s): + return {"required": { "image": ("IMAGE",), "upscale_method": (s.upscale_methods,), + "megapixels": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 16.0, "step": 0.01}), + }} + RETURN_TYPES = ("IMAGE",) + FUNCTION = "upscale" + + CATEGORY = "image/upscaling" + + def upscale(self, image, upscale_method, megapixels): + samples = image.movedim(-1,1) + total = int(megapixels * 1024 * 1024) + + scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2])) + width = round(samples.shape[3] * scale_by) + height = round(samples.shape[2] * scale_by) + + s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled") + s = s.movedim(1,-1) + return (s,) + NODE_CLASS_MAPPINGS = { "ImageBlend": Blend, "ImageBlur": Blur, "ImageQuantize": Quantize, "ImageSharpen": Sharpen, + "ImageScaleToTotalPixels": ImageScaleToTotalPixels, }