From a3a713b6c581f4c0487c58c5a20eca2a5e8e6bde Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 6 Jun 2023 01:26:52 -0400 Subject: [PATCH] Refactor previews into one command line argument. Clean up a few things. --- README.md | 4 +- comfy/cli_args.py | 5 +- comfy/taesd/taesd.py | 4 +- comfy/utils.py | 3 - folder_paths.py | 2 +- latent_preview.py | 95 +++++++++++++++++++ ...esd_encoder_pth_and_taesd_decoder_pth_here | 0 nodes.py | 94 +----------------- 8 files changed, 107 insertions(+), 100 deletions(-) create mode 100644 latent_preview.py rename models/{taesd => vae_approx}/put_taesd_encoder_pth_and_taesd_decoder_pth_here (100%) diff --git a/README.md b/README.md index 6e0803ab..d998afe6 100644 --- a/README.md +++ b/README.md @@ -184,7 +184,9 @@ You can set this command line setting to disable the upcasting to fp32 in some c ## How to show high-quality previews? -The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_encoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_encoder.pth) and [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) models and place them in the `models/taesd` folder. Once they're installed, restart ComfyUI to enable high-quality previews. +Use ```--preview-method auto``` to enable previews. + +The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_encoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_encoder.pth) and [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) models and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI to enable high-quality previews. ## Support and dev channel diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 3e6b1daa..b56497de 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -45,11 +45,12 @@ parser.add_argument("--force-fp32", action="store_true", help="Force fp32 (If th parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.") class LatentPreviewMethod(enum.Enum): + NoPreviews = "none" Auto = "auto" Latent2RGB = "latent2rgb" TAESD = "taesd" -parser.add_argument("--disable-previews", action="store_true", help="Disable showing node previews.") -parser.add_argument("--default-preview-method", type=str, default=LatentPreviewMethod.Auto, metavar="PREVIEW_METHOD", help="Default preview method for sampler nodes.") + +parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction) attn_group = parser.add_mutually_exclusive_group() attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.") diff --git a/comfy/taesd/taesd.py b/comfy/taesd/taesd.py index e6406745..1549345a 100644 --- a/comfy/taesd/taesd.py +++ b/comfy/taesd/taesd.py @@ -50,9 +50,9 @@ class TAESD(nn.Module): self.encoder = Encoder() self.decoder = Decoder() if encoder_path is not None: - self.encoder.load_state_dict(torch.load(encoder_path, map_location="cpu")) + self.encoder.load_state_dict(torch.load(encoder_path, map_location="cpu", weights_only=True)) if decoder_path is not None: - self.decoder.load_state_dict(torch.load(decoder_path, map_location="cpu")) + self.decoder.load_state_dict(torch.load(decoder_path, map_location="cpu", weights_only=True)) @staticmethod def scale_latents(x): diff --git a/comfy/utils.py b/comfy/utils.py index 08944ade..291c62e4 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1,7 +1,6 @@ import torch import math import struct -import comfy.model_management def load_torch_file(ckpt, safe_load=False): if ckpt.lower().endswith(".safetensors"): @@ -167,8 +166,6 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am out_div = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device="cpu") for y in range(0, s.shape[2], tile_y - overlap): for x in range(0, s.shape[3], tile_x - overlap): - comfy.model_management.throw_exception_if_processing_interrupted() - s_in = s[:,:,y:y+tile_y,x:x+tile_x] ps = function(s_in).cpu() diff --git a/folder_paths.py b/folder_paths.py index 38729928..2ad1b171 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -18,7 +18,7 @@ folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision" folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions) folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions) folder_names_and_paths["diffusers"] = ([os.path.join(models_dir, "diffusers")], ["folder"]) -folder_names_and_paths["taesd"] = ([os.path.join(models_dir, "taesd")], supported_pt_extensions) +folder_names_and_paths["vae_approx"] = ([os.path.join(models_dir, "vae_approx")], supported_pt_extensions) folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions) folder_names_and_paths["gligen"] = ([os.path.join(models_dir, "gligen")], supported_pt_extensions) diff --git a/latent_preview.py b/latent_preview.py new file mode 100644 index 00000000..ef6c201b --- /dev/null +++ b/latent_preview.py @@ -0,0 +1,95 @@ +import torch +from PIL import Image, ImageOps +from io import BytesIO +import struct +import numpy as np + +from comfy.cli_args import args, LatentPreviewMethod +from comfy.taesd.taesd import TAESD +import folder_paths + +MAX_PREVIEW_RESOLUTION = 512 + +class LatentPreviewer: + def decode_latent_to_preview(self, x0): + pass + + def decode_latent_to_preview_image(self, preview_format, x0): + preview_image = self.decode_latent_to_preview(x0) + preview_image = ImageOps.contain(preview_image, (MAX_PREVIEW_RESOLUTION, MAX_PREVIEW_RESOLUTION), Image.ANTIALIAS) + + preview_type = 1 + if preview_format == "JPEG": + preview_type = 1 + elif preview_format == "PNG": + preview_type = 2 + + bytesIO = BytesIO() + header = struct.pack(">I", preview_type) + bytesIO.write(header) + preview_image.save(bytesIO, format=preview_format, quality=95) + preview_bytes = bytesIO.getvalue() + return preview_bytes + +class TAESDPreviewerImpl(LatentPreviewer): + def __init__(self, taesd): + self.taesd = taesd + + def decode_latent_to_preview(self, x0): + x_sample = self.taesd.decoder(x0)[0].detach() + # x_sample = self.taesd.unscale_latents(x_sample).div(4).add(0.5) # returns value in [-2, 2] + x_sample = x_sample.sub(0.5).mul(2) + + x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) + x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) + x_sample = x_sample.astype(np.uint8) + + preview_image = Image.fromarray(x_sample) + return preview_image + + +class Latent2RGBPreviewer(LatentPreviewer): + def __init__(self): + self.latent_rgb_factors = torch.tensor([ + # R G B + [0.298, 0.207, 0.208], # L1 + [0.187, 0.286, 0.173], # L2 + [-0.158, 0.189, 0.264], # L3 + [-0.184, -0.271, -0.473], # L4 + ], device="cpu") + + def decode_latent_to_preview(self, x0): + latent_image = x0[0].permute(1, 2, 0).cpu() @ self.latent_rgb_factors + + latents_ubyte = (((latent_image + 1) / 2) + .clamp(0, 1) # change scale from -1..1 to 0..1 + .mul(0xFF) # to 0..255 + .byte()).cpu() + + return Image.fromarray(latents_ubyte.numpy()) + + +def get_previewer(device): + previewer = None + method = args.preview_method + if method != LatentPreviewMethod.NoPreviews: + # TODO previewer methods + taesd_decoder_path = folder_paths.get_full_path("vae_approx", "taesd_decoder.pth") + + if method == LatentPreviewMethod.Auto: + method = LatentPreviewMethod.Latent2RGB + if taesd_decoder_path: + method = LatentPreviewMethod.TAESD + + if method == LatentPreviewMethod.TAESD: + if taesd_decoder_path: + taesd = TAESD(None, taesd_decoder_path).to(device) + previewer = TAESDPreviewerImpl(taesd) + else: + print("Warning: TAESD previews enabled, but could not find models/vae_approx/taesd_decoder.pth") + + if previewer is None: + previewer = Latent2RGBPreviewer() + return previewer + + diff --git a/models/taesd/put_taesd_encoder_pth_and_taesd_decoder_pth_here b/models/vae_approx/put_taesd_encoder_pth_and_taesd_decoder_pth_here similarity index 100% rename from models/taesd/put_taesd_encoder_pth_and_taesd_decoder_pth_here rename to models/vae_approx/put_taesd_encoder_pth_and_taesd_decoder_pth_here diff --git a/nodes.py b/nodes.py index 971b5c3b..b057504e 100644 --- a/nodes.py +++ b/nodes.py @@ -7,15 +7,12 @@ import hashlib import traceback import math import time -import struct -from io import BytesIO from PIL import Image, ImageOps from PIL.PngImagePlugin import PngInfo import numpy as np import safetensors.torch - sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy")) @@ -24,8 +21,6 @@ import comfy.samplers import comfy.sample import comfy.sd import comfy.utils -from comfy.cli_args import args, LatentPreviewMethod -from comfy.taesd.taesd import TAESD import comfy.clip_vision @@ -33,33 +28,7 @@ import comfy.model_management import importlib import folder_paths - - -class LatentPreviewer: - def decode_latent_to_preview(self, device, x0): - pass - - -class Latent2RGBPreviewer(LatentPreviewer): - def __init__(self): - self.latent_rgb_factors = torch.tensor([ - # R G B - [0.298, 0.207, 0.208], # L1 - [0.187, 0.286, 0.173], # L2 - [-0.158, 0.189, 0.264], # L3 - [-0.184, -0.271, -0.473], # L4 - ], device="cpu") - - def decode_latent_to_preview(self, device, x0): - latent_image = x0[0].permute(1, 2, 0).cpu() @ self.latent_rgb_factors - - latents_ubyte = (((latent_image + 1) / 2) - .clamp(0, 1) # change scale from -1..1 to 0..1 - .mul(0xFF) # to 0..255 - .byte()).cpu() - - return Image.fromarray(latents_ubyte.numpy()) - +import latent_preview def before_node_execution(): comfy.model_management.throw_exception_if_processing_interrupted() @@ -68,7 +37,6 @@ def interrupt_processing(value=True): comfy.model_management.interrupt_current_processing(value) MAX_RESOLUTION=8192 -MAX_PREVIEW_RESOLUTION = 512 class CLIPTextEncode: @classmethod @@ -279,22 +247,6 @@ class VAEEncodeForInpaint: return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, ) -class TAESDPreviewerImpl(LatentPreviewer): - def __init__(self, taesd): - self.taesd = taesd - - def decode_latent_to_preview(self, device, x0): - x_sample = self.taesd.decoder(x0.to(device))[0].detach() - # x_sample = self.taesd.unscale_latents(x_sample).div(4).add(0.5) # returns value in [-2, 2] - x_sample = x_sample.sub(0.5).mul(2) - - x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) - x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) - x_sample = x_sample.astype(np.uint8) - - preview_image = Image.fromarray(x_sample) - return preview_image - class SaveLatent: def __init__(self): self.output_dir = folder_paths.get_output_directory() @@ -978,25 +930,6 @@ class SetLatentNoiseMask: return (s,) -def decode_latent_to_preview_image(previewer, device, preview_format, x0): - preview_image = previewer.decode_latent_to_preview(device, x0) - preview_image = ImageOps.contain(preview_image, (MAX_PREVIEW_RESOLUTION, MAX_PREVIEW_RESOLUTION), Image.ANTIALIAS) - - preview_type = 1 - if preview_format == "JPEG": - preview_type = 1 - elif preview_format == "PNG": - preview_type = 2 - - bytesIO = BytesIO() - header = struct.pack(">I", preview_type) - bytesIO.write(header) - preview_image.save(bytesIO, format=preview_format) - preview_bytes = bytesIO.getvalue() - - return preview_bytes - - def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False): device = comfy.model_management.get_torch_device() latent_image = latent["samples"] @@ -1015,34 +948,13 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, if preview_format not in ["JPEG", "PNG"]: preview_format = "JPEG" - previewer = None - if not args.disable_previews: - # TODO previewer methods - taesd_encoder_path = folder_paths.get_full_path("taesd", "taesd_encoder.pth") - taesd_decoder_path = folder_paths.get_full_path("taesd", "taesd_decoder.pth") - - method = args.default_preview_method - - if method == LatentPreviewMethod.Auto: - method = LatentPreviewMethod.Latent2RGB - if taesd_encoder_path and taesd_encoder_path: - method = LatentPreviewMethod.TAESD - - if method == LatentPreviewMethod.TAESD: - if taesd_encoder_path and taesd_encoder_path: - taesd = TAESD(taesd_encoder_path, taesd_decoder_path).to(device) - previewer = TAESDPreviewerImpl(taesd) - else: - print("Warning: TAESD previews enabled, but could not find models/taesd/taesd_encoder.pth and models/taesd/taesd_decoder.pth") - - if previewer is None: - previewer = Latent2RGBPreviewer() + previewer = latent_preview.get_previewer(device) pbar = comfy.utils.ProgressBar(steps) def callback(step, x0, x, total_steps): preview_bytes = None if previewer: - preview_bytes = decode_latent_to_preview_image(previewer, device, preview_format, x0) + preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0) pbar.update_absolute(step + 1, total_steps, preview_bytes) samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,