|
|
|
@ -544,6 +544,19 @@ class VAE:
|
|
|
|
|
/ 3.0) / 2.0, min=0.0, max=1.0) |
|
|
|
|
return output |
|
|
|
|
|
|
|
|
|
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): |
|
|
|
|
steps = pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap) |
|
|
|
|
steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap) |
|
|
|
|
steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap) |
|
|
|
|
pbar = utils.ProgressBar(steps) |
|
|
|
|
|
|
|
|
|
encode_fn = lambda a: self.first_stage_model.encode(2. * a.to(self.device) - 1.).sample() * self.scale_factor |
|
|
|
|
samples = utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) |
|
|
|
|
samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) |
|
|
|
|
samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) |
|
|
|
|
samples /= 3.0 |
|
|
|
|
return samples |
|
|
|
|
|
|
|
|
|
def decode(self, samples_in): |
|
|
|
|
model_management.unload_model() |
|
|
|
|
self.first_stage_model = self.first_stage_model.to(self.device) |
|
|
|
@ -574,28 +587,26 @@ class VAE:
|
|
|
|
|
def encode(self, pixel_samples): |
|
|
|
|
model_management.unload_model() |
|
|
|
|
self.first_stage_model = self.first_stage_model.to(self.device) |
|
|
|
|
pixel_samples = pixel_samples.movedim(-1,1).to(self.device) |
|
|
|
|
samples = self.first_stage_model.encode(2. * pixel_samples - 1.).sample() * self.scale_factor |
|
|
|
|
pixel_samples = pixel_samples.movedim(-1,1) |
|
|
|
|
try: |
|
|
|
|
batch_number = 1 |
|
|
|
|
samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu") |
|
|
|
|
for x in range(0, pixel_samples.shape[0], batch_number): |
|
|
|
|
pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.device) |
|
|
|
|
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).sample().cpu() * self.scale_factor |
|
|
|
|
except model_management.OOM_EXCEPTION as e: |
|
|
|
|
print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.") |
|
|
|
|
samples = self.encode_tiled_(pixel_samples) |
|
|
|
|
|
|
|
|
|
self.first_stage_model = self.first_stage_model.cpu() |
|
|
|
|
samples = samples.cpu() |
|
|
|
|
return samples |
|
|
|
|
|
|
|
|
|
def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): |
|
|
|
|
model_management.unload_model() |
|
|
|
|
self.first_stage_model = self.first_stage_model.to(self.device) |
|
|
|
|
pixel_samples = pixel_samples.movedim(-1,1).to(self.device) |
|
|
|
|
|
|
|
|
|
steps = pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap) |
|
|
|
|
steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap) |
|
|
|
|
steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap) |
|
|
|
|
pbar = utils.ProgressBar(steps) |
|
|
|
|
|
|
|
|
|
samples = utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) |
|
|
|
|
samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) |
|
|
|
|
samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) |
|
|
|
|
samples /= 3.0 |
|
|
|
|
pixel_samples = pixel_samples.movedim(-1,1) |
|
|
|
|
samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap) |
|
|
|
|
self.first_stage_model = self.first_stage_model.cpu() |
|
|
|
|
samples = samples.cpu() |
|
|
|
|
return samples |
|
|
|
|
|
|
|
|
|
def broadcast_image_to(tensor, target_batch_size, batched_number): |
|
|
|
|