|
|
|
@ -605,8 +605,9 @@ class VAE:
|
|
|
|
|
self.first_stage_model.load_state_dict(sd, strict=False) |
|
|
|
|
|
|
|
|
|
if device is None: |
|
|
|
|
device = model_management.get_torch_device() |
|
|
|
|
device = model_management.vae_device() |
|
|
|
|
self.device = device |
|
|
|
|
self.offload_device = model_management.vae_offload_device() |
|
|
|
|
|
|
|
|
|
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16): |
|
|
|
|
steps = samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap) |
|
|
|
@ -651,7 +652,7 @@ class VAE:
|
|
|
|
|
print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") |
|
|
|
|
pixel_samples = self.decode_tiled_(samples_in) |
|
|
|
|
|
|
|
|
|
self.first_stage_model = self.first_stage_model.cpu() |
|
|
|
|
self.first_stage_model = self.first_stage_model.to(self.offload_device) |
|
|
|
|
pixel_samples = pixel_samples.cpu().movedim(1,-1) |
|
|
|
|
return pixel_samples |
|
|
|
|
|
|
|
|
@ -659,7 +660,7 @@ class VAE:
|
|
|
|
|
model_management.unload_model() |
|
|
|
|
self.first_stage_model = self.first_stage_model.to(self.device) |
|
|
|
|
output = self.decode_tiled_(samples, tile_x, tile_y, overlap) |
|
|
|
|
self.first_stage_model = self.first_stage_model.cpu() |
|
|
|
|
self.first_stage_model = self.first_stage_model.to(self.offload_device) |
|
|
|
|
return output.movedim(1,-1) |
|
|
|
|
|
|
|
|
|
def encode(self, pixel_samples): |
|
|
|
@ -679,7 +680,7 @@ class VAE:
|
|
|
|
|
print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.") |
|
|
|
|
samples = self.encode_tiled_(pixel_samples) |
|
|
|
|
|
|
|
|
|
self.first_stage_model = self.first_stage_model.cpu() |
|
|
|
|
self.first_stage_model = self.first_stage_model.to(self.offload_device) |
|
|
|
|
return samples |
|
|
|
|
|
|
|
|
|
def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): |
|
|
|
@ -687,7 +688,7 @@ class VAE:
|
|
|
|
|
self.first_stage_model = self.first_stage_model.to(self.device) |
|
|
|
|
pixel_samples = pixel_samples.movedim(-1,1) |
|
|
|
|
samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap) |
|
|
|
|
self.first_stage_model = self.first_stage_model.cpu() |
|
|
|
|
self.first_stage_model = self.first_stage_model.to(self.offload_device) |
|
|
|
|
return samples |
|
|
|
|
|
|
|
|
|
def get_sd(self): |
|
|
|
|