Browse Source

Supports TAESD models in safetensors format

pull/1703/head
Yukimasa Funaoka 1 year ago
parent
commit
9eb621c95a
No known key found for this signature in database
GPG Key ID: E8F24A863BB84ACC
  1. 4
      comfy/latent_formats.py
  2. 12
      comfy/taesd/taesd.py
  3. 7
      latent_preview.py

4
comfy/latent_formats.py

@ -20,7 +20,7 @@ class SD15(LatentFormat):
[-0.2829, 0.1762, 0.2721],
[-0.2120, -0.2616, -0.7177]
]
self.taesd_decoder_name = "taesd_decoder.pth"
self.taesd_decoder_name = "taesd_decoder"
class SDXL(LatentFormat):
def __init__(self):
@ -32,4 +32,4 @@ class SDXL(LatentFormat):
[ 0.0568, 0.1687, -0.0755],
[-0.3112, -0.2359, -0.2076]
]
self.taesd_decoder_name = "taesdxl_decoder.pth"
self.taesd_decoder_name = "taesdxl_decoder"

12
comfy/taesd/taesd.py

@ -50,9 +50,17 @@ class TAESD(nn.Module):
self.encoder = Encoder()
self.decoder = Decoder()
if encoder_path is not None:
self.encoder.load_state_dict(torch.load(encoder_path, map_location="cpu", weights_only=True))
if encoder_path.lower().endswith(".safetensors"):
import safetensors.torch
self.encoder.load_state_dict(safetensors.torch.load_file(encoder_path, device="cpu"))
else:
self.encoder.load_state_dict(torch.load(encoder_path, map_location="cpu", weights_only=True))
if decoder_path is not None:
self.decoder.load_state_dict(torch.load(decoder_path, map_location="cpu", weights_only=True))
if decoder_path.lower().endswith(".safetensors"):
import safetensors.torch
self.decoder.load_state_dict(safetensors.torch.load_file(decoder_path, device="cpu"))
else:
self.decoder.load_state_dict(torch.load(decoder_path, map_location="cpu", weights_only=True))
@staticmethod
def scale_latents(x):

7
latent_preview.py

@ -56,7 +56,12 @@ def get_previewer(device, latent_format):
# TODO previewer methods
taesd_decoder_path = None
if latent_format.taesd_decoder_name is not None:
taesd_decoder_path = folder_paths.get_full_path("vae_approx", latent_format.taesd_decoder_name)
taesd_decoder_path = next(
(fn for fn in folder_paths.get_filename_list("vae_approx")
if fn.startswith(latent_format.taesd_decoder_name)),
""
)
taesd_decoder_path = folder_paths.get_full_path("vae_approx", taesd_decoder_path)
if method == LatentPreviewMethod.Auto:
method = LatentPreviewMethod.Latent2RGB

Loading…
Cancel
Save