Browse Source

Add support for TAESD decoder for SDXL.

pull/793/head
comfyanonymous 1 year ago
parent
commit
cef6aa62b2
  1. 2
      README.md
  2. 17
      comfy/latent_formats.py
  3. 18
      latent_preview.py
  4. 2
      nodes.py

2
README.md

@ -193,7 +193,7 @@ You can set this command line setting to disable the upcasting to fp32 in some c
Use ```--preview-method auto``` to enable previews.
The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_encoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_encoder.pth) and [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) models and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI to enable high-quality previews.
The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) (for SD1.x and SD2.x) and [taesdxl_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesdxl_decoder.pth) (for SDXL) models and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI to enable high-quality previews.
## Support and dev channel

17
comfy/latent_formats.py

@ -9,8 +9,23 @@ class LatentFormat:
class SD15(LatentFormat):
def __init__(self, scale_factor=0.18215):
self.scale_factor = scale_factor
self.latent_rgb_factors = [
# R G B
[0.298, 0.207, 0.208], # L1
[0.187, 0.286, 0.173], # L2
[-0.158, 0.189, 0.264], # L3
[-0.184, -0.271, -0.473], # L4
]
self.taesd_decoder_name = "taesd_decoder.pth"
class SDXL(LatentFormat):
def __init__(self):
self.scale_factor = 0.13025
self.latent_rgb_factors = [ #TODO: these are the factors for SD1.5, need to estimate new ones for SDXL
# R G B
[0.298, 0.207, 0.208], # L1
[0.187, 0.286, 0.173], # L2
[-0.158, 0.189, 0.264], # L3
[-0.184, -0.271, -0.473], # L4
]
self.taesd_decoder_name = "taesdxl_decoder.pth"

18
latent_preview.py

@ -49,14 +49,8 @@ class TAESDPreviewerImpl(LatentPreviewer):
class Latent2RGBPreviewer(LatentPreviewer):
def __init__(self):
self.latent_rgb_factors = torch.tensor([
# R G B
[0.298, 0.207, 0.208], # L1
[0.187, 0.286, 0.173], # L2
[-0.158, 0.189, 0.264], # L3
[-0.184, -0.271, -0.473], # L4
], device="cpu")
def __init__(self, latent_rgb_factors):
self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu")
def decode_latent_to_preview(self, x0):
latent_image = x0[0].permute(1, 2, 0).cpu() @ self.latent_rgb_factors
@ -69,12 +63,12 @@ class Latent2RGBPreviewer(LatentPreviewer):
return Image.fromarray(latents_ubyte.numpy())
def get_previewer(device):
def get_previewer(device, latent_format):
previewer = None
method = args.preview_method
if method != LatentPreviewMethod.NoPreviews:
# TODO previewer methods
taesd_decoder_path = folder_paths.get_full_path("vae_approx", "taesd_decoder.pth")
taesd_decoder_path = folder_paths.get_full_path("vae_approx", latent_format.taesd_decoder_name)
if method == LatentPreviewMethod.Auto:
method = LatentPreviewMethod.Latent2RGB
@ -86,10 +80,10 @@ def get_previewer(device):
taesd = TAESD(None, taesd_decoder_path).to(device)
previewer = TAESDPreviewerImpl(taesd)
else:
print("Warning: TAESD previews enabled, but could not find models/vae_approx/taesd_decoder.pth")
print("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name))
if previewer is None:
previewer = Latent2RGBPreviewer()
previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors)
return previewer

2
nodes.py

@ -954,7 +954,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
if preview_format not in ["JPEG", "PNG"]:
preview_format = "JPEG"
previewer = latent_preview.get_previewer(device)
previewer = latent_preview.get_previewer(device, model.model.latent_format)
pbar = comfy.utils.ProgressBar(steps)
def callback(step, x0, x, total_steps):

Loading…
Cancel
Save