|
|
|
@ -190,6 +190,7 @@ class VAE:
|
|
|
|
|
offload_device = model_management.vae_offload_device() |
|
|
|
|
self.vae_dtype = model_management.vae_dtype() |
|
|
|
|
self.first_stage_model.to(self.vae_dtype) |
|
|
|
|
self.output_device = model_management.intermediate_device() |
|
|
|
|
|
|
|
|
|
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device) |
|
|
|
|
|
|
|
|
@ -201,9 +202,9 @@ class VAE:
|
|
|
|
|
|
|
|
|
|
decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float() |
|
|
|
|
output = torch.clamp(( |
|
|
|
|
(comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, pbar = pbar) + |
|
|
|
|
comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, pbar = pbar) + |
|
|
|
|
comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8, pbar = pbar)) |
|
|
|
|
(comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, output_device=self.output_device, pbar = pbar) + |
|
|
|
|
comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, output_device=self.output_device, pbar = pbar) + |
|
|
|
|
comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8, output_device=self.output_device, pbar = pbar)) |
|
|
|
|
/ 3.0) / 2.0, min=0.0, max=1.0) |
|
|
|
|
return output |
|
|
|
|
|
|
|
|
@ -214,9 +215,9 @@ class VAE:
|
|
|
|
|
pbar = comfy.utils.ProgressBar(steps) |
|
|
|
|
|
|
|
|
|
encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float() |
|
|
|
|
samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) |
|
|
|
|
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) |
|
|
|
|
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) |
|
|
|
|
samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, output_device=self.output_device, pbar=pbar) |
|
|
|
|
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, output_device=self.output_device, pbar=pbar) |
|
|
|
|
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, output_device=self.output_device, pbar=pbar) |
|
|
|
|
samples /= 3.0 |
|
|
|
|
return samples |
|
|
|
|
|
|
|
|
@ -228,15 +229,15 @@ class VAE:
|
|
|
|
|
batch_number = int(free_memory / memory_used) |
|
|
|
|
batch_number = max(1, batch_number) |
|
|
|
|
|
|
|
|
|
pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device="cpu") |
|
|
|
|
pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device=self.output_device) |
|
|
|
|
for x in range(0, samples_in.shape[0], batch_number): |
|
|
|
|
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device) |
|
|
|
|
pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).cpu().float() + 1.0) / 2.0, min=0.0, max=1.0) |
|
|
|
|
pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).to(self.output_device).float() + 1.0) / 2.0, min=0.0, max=1.0) |
|
|
|
|
except model_management.OOM_EXCEPTION as e: |
|
|
|
|
print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") |
|
|
|
|
pixel_samples = self.decode_tiled_(samples_in) |
|
|
|
|
|
|
|
|
|
pixel_samples = pixel_samples.cpu().movedim(1,-1) |
|
|
|
|
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1) |
|
|
|
|
return pixel_samples |
|
|
|
|
|
|
|
|
|
def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16): |
|
|
|
@ -252,10 +253,10 @@ class VAE:
|
|
|
|
|
free_memory = model_management.get_free_memory(self.device) |
|
|
|
|
batch_number = int(free_memory / memory_used) |
|
|
|
|
batch_number = max(1, batch_number) |
|
|
|
|
samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu") |
|
|
|
|
samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device=self.output_device) |
|
|
|
|
for x in range(0, pixel_samples.shape[0], batch_number): |
|
|
|
|
pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device) |
|
|
|
|
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).cpu().float() |
|
|
|
|
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float() |
|
|
|
|
|
|
|
|
|
except model_management.OOM_EXCEPTION as e: |
|
|
|
|
print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.") |
|
|
|
|