|
|
@ -439,9 +439,14 @@ class VAE: |
|
|
|
model_management.unload_model() |
|
|
|
model_management.unload_model() |
|
|
|
self.first_stage_model = self.first_stage_model.to(self.device) |
|
|
|
self.first_stage_model = self.first_stage_model.to(self.device) |
|
|
|
try: |
|
|
|
try: |
|
|
|
samples = samples_in.to(self.device) |
|
|
|
free_memory = model_management.get_free_memory(self.device) |
|
|
|
pixel_samples = self.first_stage_model.decode(1. / self.scale_factor * samples) |
|
|
|
batch_number = int((free_memory * 0.7) / (2562 * samples_in.shape[2] * samples_in.shape[3] * 64)) |
|
|
|
pixel_samples = torch.clamp((pixel_samples + 1.0) / 2.0, min=0.0, max=1.0) |
|
|
|
batch_number = max(1, batch_number) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device="cpu") |
|
|
|
|
|
|
|
for x in range(0, samples_in.shape[0], batch_number): |
|
|
|
|
|
|
|
samples = samples_in[x:x+batch_number].to(self.device) |
|
|
|
|
|
|
|
pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(1. / self.scale_factor * samples) + 1.0) / 2.0, min=0.0, max=1.0).cpu() |
|
|
|
except model_management.OOM_EXCEPTION as e: |
|
|
|
except model_management.OOM_EXCEPTION as e: |
|
|
|
print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") |
|
|
|
print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") |
|
|
|
pixel_samples = self.decode_tiled_(samples_in) |
|
|
|
pixel_samples = self.decode_tiled_(samples_in) |
|
|
|