|
|
|
@ -171,24 +171,28 @@ class VAEEncodeForInpaint:
|
|
|
|
|
def encode(self, vae, pixels, mask): |
|
|
|
|
x = (pixels.shape[1] // 64) * 64 |
|
|
|
|
y = (pixels.shape[2] // 64) * 64 |
|
|
|
|
mask = torch.nn.functional.interpolate(mask[None,None,], size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")[0][0] |
|
|
|
|
if len(mask.shape) < 3: |
|
|
|
|
mask = mask.unsqueeze(0).unsqueeze(0) |
|
|
|
|
elif len(mask.shape) < 4: |
|
|
|
|
mask = mask.unsqueeze(1) |
|
|
|
|
mask = torch.nn.functional.interpolate(mask, size=(pixels.shape[1], pixels.shape[2]), mode="bilinear") |
|
|
|
|
|
|
|
|
|
pixels = pixels.clone() |
|
|
|
|
if pixels.shape[1] != x or pixels.shape[2] != y: |
|
|
|
|
pixels = pixels[:,:x,:y,:] |
|
|
|
|
mask = mask[:x,:y] |
|
|
|
|
mask = mask[:,:x,:y,:] |
|
|
|
|
|
|
|
|
|
#grow mask by a few pixels to keep things seamless in latent space |
|
|
|
|
kernel_tensor = torch.ones((1, 1, 6, 6)) |
|
|
|
|
mask_erosion = torch.clamp(torch.nn.functional.conv2d((mask.round())[None], kernel_tensor, padding=3), 0, 1) |
|
|
|
|
m = (1.0 - mask.round()) |
|
|
|
|
mask_erosion = torch.clamp(torch.nn.functional.conv2d(mask.round(), kernel_tensor, padding=3), 0, 1) |
|
|
|
|
m = (1.0 - mask.round()).squeeze(1) |
|
|
|
|
for i in range(3): |
|
|
|
|
pixels[:,:,:,i] -= 0.5 |
|
|
|
|
pixels[:,:,:,i] *= m |
|
|
|
|
pixels[:,:,:,i] += 0.5 |
|
|
|
|
t = vae.encode(pixels) |
|
|
|
|
|
|
|
|
|
return ({"samples":t, "noise_mask": (mask_erosion[0][:x,:y].round())}, ) |
|
|
|
|
return ({"samples":t, "noise_mask": (mask_erosion[:,:x,:y,:].round())}, ) |
|
|
|
|
|
|
|
|
|
class CheckpointLoader: |
|
|
|
|
@classmethod |
|
|
|
@ -759,10 +763,15 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
|
|
|
|
|
|
|
|
|
|
if "noise_mask" in latent: |
|
|
|
|
noise_mask = latent['noise_mask'] |
|
|
|
|
noise_mask = torch.nn.functional.interpolate(noise_mask[None,None,], size=(noise.shape[2], noise.shape[3]), mode="bilinear") |
|
|
|
|
if len(noise_mask.shape) < 3: |
|
|
|
|
noise_mask = noise_mask.unsqueeze(0).unsqueeze(0) |
|
|
|
|
elif len(noise_mask.shape) < 4: |
|
|
|
|
noise_mask = noise_mask.unsqueeze(1) |
|
|
|
|
noise_mask = torch.nn.functional.interpolate(noise_mask, size=(noise.shape[2], noise.shape[3]), mode="bilinear") |
|
|
|
|
noise_mask = noise_mask.round() |
|
|
|
|
noise_mask = torch.cat([noise_mask] * noise.shape[1], dim=1) |
|
|
|
|
noise_mask = torch.cat([noise_mask] * noise.shape[0]) |
|
|
|
|
if noise_mask.shape[0] < latent_image.shape[0]: |
|
|
|
|
noise_mask = noise_mask.repeat(latent_image.shape[0] // noise_mask.shape[0], 1, 1, 1) |
|
|
|
|
noise_mask = noise_mask.to(device) |
|
|
|
|
|
|
|
|
|
real_model = None |
|
|
|
|