|
|
|
@ -151,7 +151,7 @@ class CLIP:
|
|
|
|
|
return self.patcher.get_key_patches() |
|
|
|
|
|
|
|
|
|
class VAE: |
|
|
|
|
def __init__(self, sd=None, device=None, config=None): |
|
|
|
|
def __init__(self, sd=None, device=None, config=None, dtype=None): |
|
|
|
|
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format |
|
|
|
|
sd = diffusers_convert.convert_vae_state_dict(sd) |
|
|
|
|
|
|
|
|
@ -188,7 +188,9 @@ class VAE:
|
|
|
|
|
device = model_management.vae_device() |
|
|
|
|
self.device = device |
|
|
|
|
offload_device = model_management.vae_offload_device() |
|
|
|
|
self.vae_dtype = model_management.vae_dtype() |
|
|
|
|
if dtype is None: |
|
|
|
|
dtype = model_management.vae_dtype() |
|
|
|
|
self.vae_dtype = dtype |
|
|
|
|
self.first_stage_model.to(self.vae_dtype) |
|
|
|
|
self.output_device = model_management.intermediate_device() |
|
|
|
|
|
|
|
|
|