Browse Source

Some fixes to the batch masks PR.

pull/570/head
comfyanonymous 2 years ago
parent
commit
aa57136dae
  1. 7
      comfy/sample.py
  2. 10
      nodes.py

7
comfy/sample.py

@ -1,7 +1,7 @@
import torch
import comfy.model_management
import comfy.samplers
import math
def prepare_noise(latent_image, seed, skip=0):
"""
@ -16,10 +16,11 @@ def prepare_noise(latent_image, seed, skip=0):
def prepare_mask(noise_mask, shape, device):
"""ensures noise mask is of proper dimensions"""
noise_mask = torch.nn.functional.interpolate(noise_mask[None,None,], size=(shape[2], shape[3]), mode="bilinear")
noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear")
noise_mask = noise_mask.round()
noise_mask = torch.cat([noise_mask] * shape[1], dim=1)
noise_mask = torch.cat([noise_mask] * shape[0])
if noise_mask.shape[0] < shape[0]:
noise_mask = noise_mask.repeat(math.ceil(shape[0] / noise_mask.shape[0]), 1, 1, 1)[:shape[0]]
noise_mask = noise_mask.to(device)
return noise_mask

10
nodes.py

@ -172,16 +172,12 @@ class VAEEncodeForInpaint:
def encode(self, vae, pixels, mask):
x = (pixels.shape[1] // 64) * 64
y = (pixels.shape[2] // 64) * 64
if len(mask.shape) < 3:
mask = mask.unsqueeze(0).unsqueeze(0)
elif len(mask.shape) < 4:
mask = mask.unsqueeze(1)
mask = torch.nn.functional.interpolate(mask, size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
pixels = pixels.clone()
if pixels.shape[1] != x or pixels.shape[2] != y:
pixels = pixels[:,:x,:y,:]
mask = mask[:,:x,:y,:]
mask = mask[:,:,:x,:y]
#grow mask by a few pixels to keep things seamless in latent space
kernel_tensor = torch.ones((1, 1, 6, 6))
@ -193,7 +189,7 @@ class VAEEncodeForInpaint:
pixels[:,:,:,i] += 0.5
t = vae.encode(pixels)
return ({"samples":t, "noise_mask": (mask_erosion[:,:x,:y,:].round())}, )
return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, )
class CheckpointLoader:
@classmethod

Loading…
Cancel
Save