|
|
|
@ -321,7 +321,7 @@ class VAEEncodeForInpaint:
|
|
|
|
|
def encode(self, vae, pixels, mask, grow_mask_by=6): |
|
|
|
|
x = (pixels.shape[1] // vae.downscale_ratio) * vae.downscale_ratio |
|
|
|
|
y = (pixels.shape[2] // vae.downscale_ratio) * vae.downscale_ratio |
|
|
|
|
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear") |
|
|
|
|
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear").float() |
|
|
|
|
|
|
|
|
|
pixels = pixels.clone() |
|
|
|
|
if pixels.shape[1] != x or pixels.shape[2] != y: |
|
|
|
|