|
|
|
@ -50,9 +50,17 @@ class TAESD(nn.Module):
|
|
|
|
|
self.encoder = Encoder() |
|
|
|
|
self.decoder = Decoder() |
|
|
|
|
if encoder_path is not None: |
|
|
|
|
self.encoder.load_state_dict(torch.load(encoder_path, map_location="cpu", weights_only=True)) |
|
|
|
|
if encoder_path.lower().endswith(".safetensors"): |
|
|
|
|
import safetensors.torch |
|
|
|
|
self.encoder.load_state_dict(safetensors.torch.load_file(encoder_path, device="cpu")) |
|
|
|
|
else: |
|
|
|
|
self.encoder.load_state_dict(torch.load(encoder_path, map_location="cpu", weights_only=True)) |
|
|
|
|
if decoder_path is not None: |
|
|
|
|
self.decoder.load_state_dict(torch.load(decoder_path, map_location="cpu", weights_only=True)) |
|
|
|
|
if decoder_path.lower().endswith(".safetensors"): |
|
|
|
|
import safetensors.torch |
|
|
|
|
self.decoder.load_state_dict(safetensors.torch.load_file(decoder_path, device="cpu")) |
|
|
|
|
else: |
|
|
|
|
self.decoder.load_state_dict(torch.load(decoder_path, map_location="cpu", weights_only=True)) |
|
|
|
|
|
|
|
|
|
@staticmethod |
|
|
|
|
def scale_latents(x): |
|
|
|
|