
Split VAE decode batches depending on free memory.

Commit: b2554bc4dd (pull/315/head)
Author: comfyanonymous, 2 years ago
comfy/sd.py (11 changed lines)

@@ -439,9 +439,14 @@ class VAE:
         model_management.unload_model()
         self.first_stage_model = self.first_stage_model.to(self.device)
         try:
-            samples = samples_in.to(self.device)
-            pixel_samples = self.first_stage_model.decode(1. / self.scale_factor * samples)
-            pixel_samples = torch.clamp((pixel_samples + 1.0) / 2.0, min=0.0, max=1.0)
+            free_memory = model_management.get_free_memory(self.device)
+            batch_number = int((free_memory * 0.7) / (2562 * samples_in.shape[2] * samples_in.shape[3] * 64))
+            batch_number = max(1, batch_number)
+
+            pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device="cpu")
+            for x in range(0, samples_in.shape[0], batch_number):
+                samples = samples_in[x:x+batch_number].to(self.device)
+                pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(1. / self.scale_factor * samples) + 1.0) / 2.0, min=0.0, max=1.0).cpu()
         except model_management.OOM_EXCEPTION as e:
             print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
             pixel_samples = self.decode_tiled_(samples_in)
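
Note (not part of the commit): the change estimates how many latent samples fit in VRAM from the per-sample cost heuristic 2562 * H * W * 64 bytes, budgets 70% of the reported free memory, and then decodes the input in slices of that size, collecting the results on the CPU. The helper below is a minimal standalone sketch of that pattern; the names split_decode, decode_fn, free_memory and device are illustrative assumptions and not part of the ComfyUI API, where the decode call is self.first_stage_model.decode(...) and free memory comes from model_management.get_free_memory(self.device).

import torch

def split_decode(samples_in, decode_fn, free_memory, device):
    # Per-sample VRAM cost estimate used by the commit (bytes per latent sample),
    # with only 70% of the reported free memory budgeted.
    per_sample = 2562 * samples_in.shape[2] * samples_in.shape[3] * 64
    batch_number = max(1, int((free_memory * 0.7) / per_sample))

    # Allocate the output on the CPU so decoded batches do not pile up in VRAM.
    pixel_samples = torch.empty(
        (samples_in.shape[0], 3, samples_in.shape[2] * 8, samples_in.shape[3] * 8),
        device="cpu",
    )
    # Decode the latents in slices of batch_number samples, moving each result off the GPU.
    for x in range(0, samples_in.shape[0], batch_number):
        batch = samples_in[x:x + batch_number].to(device)
        pixel_samples[x:x + batch_number] = decode_fn(batch).cpu()
    return pixel_samples

As a worked example, with roughly 8 GiB free and 64x64 latents (512x512 output images) the per-sample estimate is 2562 * 64 * 64 * 64 ≈ 0.67 GB, so batch_number = int(8 GiB * 0.7 / 0.67 GB) = 8, and a 32-image batch is decoded in four passes instead of one all-at-once decode.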
