@@ -6,6 +6,7 @@ import contextlib
 from comfy import model_management
 from .ldm.models.diffusion.ddim import DDIMSampler
 from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
+from torchvision.ops import masks_to_boxes
 
 #The main sampling function shared by all the samplers
 #Returns predicted noise
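
For orientation, the hunks below index each conditioning entry as a pair: cond[0] is the cross-attention tensor (it ends up as conditionning['c_crossattn']), and cond[1] is an options dict that may carry 'strength', 'area', 'adm_encoded' and, with this change, a 'mask'. The sketch below only illustrates that layout; the tensor sizes and the dummy encoding are assumptions for the example, not values from the repository.

    import torch

    # Hypothetical stand-in for a text encoding; real entries are produced
    # upstream of the sampler.
    clip_encoding = torch.zeros((1, 77, 768))

    example_cond = [
        clip_encoding,                     # cond[0]: cross-attention conditioning
        {
            "strength": 1.0,               # scales the per-pixel multiplier in the next hunk
            "area": (64, 64, 0, 0),        # (height, width, y, x) in latent space
            "mask": torch.ones((64, 64)),  # optional; resolve_cond_masks resizes it later
        },
    ]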

@@ -23,21 +24,34 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
                 adm_cond = cond[1]['adm_encoded']
 
             input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
-            mult = torch.ones_like(input_x) * strength
-            rr = 8
-            if area[2] != 0:
-                for t in range(rr):
-                    mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1))
-            if (area[0] + area[2]) < x_in.shape[2]:
-                for t in range(rr):
-                    mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1))
-            if area[3] != 0:
-                for t in range(rr):
-                    mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1))
-            if (area[1] + area[3]) < x_in.shape[3]:
-                for t in range(rr):
-                    mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
+            if 'mask' in cond[1]:
+                # Scale the mask to the size of the input
+                # The mask should have been resized as we began the sampling process
+                mask = cond[1]['mask']
+                assert(mask.shape[1] == x_in.shape[2])
+                assert(mask.shape[2] == x_in.shape[3])
+                mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
+                if mask.shape[0] != input_x.shape[0]:
+                    mask = mask.repeat(input_x.shape[0], 1, 1)
+            else:
+                mask = torch.ones_like(input_x)
+            mult = mask * strength
+
+            if 'mask' not in cond[1]:
+                rr = 8
+                if area[2] != 0:
+                    for t in range(rr):
+                        mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1))
+                if (area[0] + area[2]) < x_in.shape[2]:
+                    for t in range(rr):
+                        mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1))
+                if area[3] != 0:
+                    for t in range(rr):
+                        mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1))
+                if (area[1] + area[3]) < x_in.shape[3]:
+                    for t in range(rr):
+                        mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
 
             conditionning = {}
             conditionning['c_crossattn'] = cond[0]
             if cond_concat_in is not None and len(cond_concat_in) > 0:
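
The rr = 8 loops above feather the edges of an area that does not touch the latent border: the first and last eight rows and columns of mult are scaled by 1/8, 2/8, ... 8/8 so the area blends in linearly, while an explicit 'mask' now skips that ramp and mult = mask * strength is used directly. A minimal sketch of the ramp on a made-up 1x1x16x16 tensor:

    import torch

    mult = torch.ones((1, 1, 16, 16))  # assumed toy size; really the sliced latent for the area
    rr = 8
    for t in range(rr):
        # Row t gets weight (t + 1) / 8, i.e. 0.125, 0.25, ..., 1.0.
        mult[:, :, t:1 + t, :] *= (1.0 / rr) * (t + 1)

    print(mult[0, 0, :rr, 0])  # tensor([0.1250, 0.2500, ..., 1.0000])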

@@ -301,6 +315,47 @@ def blank_inpaint_image_like(latent_image):
     blank_image[:,3] *= 0.1380
     return blank_image
 
+def resolve_cond_masks(conditions, h, w, device):
+    # We need to decide on an area outside the sampling loop in order to properly generate opposite areas of equal sizes.
+    # While we're doing this, we can also resolve the mask device and scaling for performance reasons
+    for i in range(len(conditions)):
+        c = conditions[i]
+        if 'mask' in c[1]:
+            mask = c[1]['mask']
+            mask = mask.to(device=device)
+            modified = c[1].copy()
+            if len(mask.shape) == 2:
+                mask = mask.unsqueeze(0)
+            if mask.shape[1] != h or mask.shape[2] != w:
+                mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=(h, w), mode='bilinear', align_corners=False).squeeze(1)
+
+            if 'area' not in modified:
+                bounds = torch.max(torch.abs(mask),dim=0).values.unsqueeze(0)
+                if torch.max(bounds) == 0:
+                    # Handle the edge-case of an all black mask (where masks_to_boxes would error)
+                    area = (0, 0, 0, 0)
+                else:
+                    box = masks_to_boxes(bounds)[0].type(torch.int)
+                    H, W, Y, X = (box[3] - box[1] + 1, box[2] - box[0] + 1, box[1], box[0])
+                    # Make sure the height and width are divisible by 8
+                    if X % 8 != 0:
+                        newx = X // 8 * 8
+                        W = W + (X - newx)
+                        X = newx
+                    if Y % 8 != 0:
+                        newy = Y // 8 * 8
+                        H = H + (Y - newy)
+                        Y = newy
+                    if H % 8 != 0:
+                        H = H + (8 - (H % 8))
+                    if W % 8 != 0:
+                        W = W + (8 - (W % 8))
+                    area = (int(H), int(W), int(Y), int(X))
+                modified['area'] = area
+
+            modified['mask'] = mask
+            conditions[i] = [c[0], modified]
+
 def create_cond_with_same_area_if_none(conds, c):
     if 'area' not in c[1]:
         return
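
To make the bounding-box math in resolve_cond_masks concrete: torchvision's masks_to_boxes returns boxes as (x1, y1, x2, y2), which the code converts into the (height, width, y, x) area convention used elsewhere in the file and then grows until every component is a multiple of 8. A small self-contained run with a made-up mask:

    import torch
    from torchvision.ops import masks_to_boxes

    mask = torch.zeros((1, 64, 64))   # (batch, height, width) mask, values in [0, 1]
    mask[0, 10:30, 20:50] = 1.0

    box = masks_to_boxes(mask)[0].type(torch.int)    # tensor([20, 10, 49, 29])
    H, W, Y, X = box[3] - box[1] + 1, box[2] - box[0] + 1, box[1], box[0]
    print(int(H), int(W), int(Y), int(X))            # 20 30 10 20

    # After the divisibility fix-ups above, this box becomes
    # area = (24, 40, 8, 16): every component is a multiple of 8.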

@@ -461,7 +516,6 @@ class KSampler:
             sigmas = self.calculate_sigmas(new_steps).to(self.device)
             self.sigmas = sigmas[-(steps + 1):]
 
-
     def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None):
         if sigmas is None:
             sigmas = self.sigmas
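
The sigma handling above computes a schedule for new_steps and keeps only the last steps + 1 values, so sampling can start from a mid-schedule noise level rather than from pure noise (the usual partial-denoise behaviour). A toy illustration of the slicing with a stand-in linear schedule; the real values come from calculate_sigmas:

    import torch

    new_steps, steps = 20, 10
    full_sigmas = torch.linspace(14.6, 0.0, new_steps + 1)  # made-up schedule, not the model's
    truncated = full_sigmas[-(steps + 1):]
    print(truncated.shape[0], round(truncated[0].item(), 2))  # 11 7.3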

@@ -484,6 +538,10 @@ class KSampler:
         positive = positive[:]
         negative = negative[:]
 
+        resolve_cond_masks(positive, noise.shape[2], noise.shape[3], self.device)
+        resolve_cond_masks(negative, noise.shape[2], noise.shape[3], self.device)
+
         #make sure each cond area has an opposite one with the same area
         for c in positive:
             create_cond_with_same_area_if_none(negative, c)
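
Putting the pieces together: a caller only needs to attach a 'mask' to a conditioning entry, and resolve_cond_masks (invoked from sample above) moves it to the sampling device, resizes it to the latent resolution, and derives an 8-aligned 'area' when none was given. The sketch below assumes resolve_cond_masks is in scope and uses made-up sizes and a dummy encoding:

    import torch

    clip_encoding = torch.zeros((1, 77, 768))    # hypothetical placeholder encoding
    mask = torch.zeros((64, 64))
    mask[:, 32:] = 1.0                           # condition only the right half of a 64x64 latent

    positive = [[clip_encoding, {'mask': mask}]]
    resolve_cond_masks(positive, 64, 64, 'cpu')  # h, w are the latent height and width
    print(positive[0][1]['area'])                # (64, 32, 0, 32)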