|
|
|
@ -21,8 +21,8 @@ class CFGDenoiser(torch.nn.Module):
|
|
|
|
|
uncond = self.inner_model(x, sigma, cond=uncond) |
|
|
|
|
return uncond + (cond - uncond) * cond_scale |
|
|
|
|
|
|
|
|
|
def sampling_function(model_function, x, sigma, uncond, cond, cond_scale): |
|
|
|
|
def get_area_and_mult(cond, x_in): |
|
|
|
|
def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_concat=None): |
|
|
|
|
def get_area_and_mult(cond, x_in, cond_concat_in): |
|
|
|
|
area = (x_in.shape[2], x_in.shape[3], 0, 0) |
|
|
|
|
strength = 1.0 |
|
|
|
|
min_sigma = 0.0 |
|
|
|
@ -48,9 +48,43 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale):
|
|
|
|
|
if (area[1] + area[3]) < x_in.shape[3]: |
|
|
|
|
for t in range(rr): |
|
|
|
|
mult[:,:,:,area[1] + area[3] - 1 - t:area[1] + area[3] - t] *= ((1.0/rr) * (t + 1)) |
|
|
|
|
return (input_x, mult, cond[0], area) |
|
|
|
|
conditionning = {} |
|
|
|
|
conditionning['c_crossattn'] = cond[0] |
|
|
|
|
if cond_concat_in is not None and len(cond_concat_in) > 0: |
|
|
|
|
cropped = [] |
|
|
|
|
for x in cond_concat_in: |
|
|
|
|
cr = x[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] |
|
|
|
|
cropped.append(cr) |
|
|
|
|
conditionning['c_concat'] = torch.cat(cropped, dim=1) |
|
|
|
|
return (input_x, mult, conditionning, area) |
|
|
|
|
|
|
|
|
|
def cond_equal_size(c1, c2): |
|
|
|
|
if c1.keys() != c2.keys(): |
|
|
|
|
return False |
|
|
|
|
if 'c_crossattn' in c1: |
|
|
|
|
if c1['c_crossattn'].shape != c2['c_crossattn'].shape: |
|
|
|
|
return False |
|
|
|
|
if 'c_concat' in c1: |
|
|
|
|
if c1['c_concat'].shape != c2['c_concat'].shape: |
|
|
|
|
return False |
|
|
|
|
return True |
|
|
|
|
|
|
|
|
|
def cond_cat(c_list): |
|
|
|
|
c_crossattn = [] |
|
|
|
|
c_concat = [] |
|
|
|
|
for x in c_list: |
|
|
|
|
if 'c_crossattn' in x: |
|
|
|
|
c_crossattn.append(x['c_crossattn']) |
|
|
|
|
if 'c_concat' in x: |
|
|
|
|
c_concat.append(x['c_concat']) |
|
|
|
|
out = {} |
|
|
|
|
if len(c_crossattn) > 0: |
|
|
|
|
out['c_crossattn'] = [torch.cat(c_crossattn)] |
|
|
|
|
if len(c_concat) > 0: |
|
|
|
|
out['c_concat'] = [torch.cat(c_concat)] |
|
|
|
|
return out |
|
|
|
|
|
|
|
|
|
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, sigma, max_total_area): |
|
|
|
|
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, sigma, max_total_area, cond_concat_in): |
|
|
|
|
out_cond = torch.zeros_like(x_in) |
|
|
|
|
out_count = torch.ones_like(x_in)/100000.0 |
|
|
|
|
|
|
|
|
@ -62,13 +96,13 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale):
|
|
|
|
|
|
|
|
|
|
to_run = [] |
|
|
|
|
for x in cond: |
|
|
|
|
p = get_area_and_mult(x, x_in) |
|
|
|
|
p = get_area_and_mult(x, x_in, cond_concat_in) |
|
|
|
|
if p is None: |
|
|
|
|
continue |
|
|
|
|
|
|
|
|
|
to_run += [(p, COND)] |
|
|
|
|
for x in uncond: |
|
|
|
|
p = get_area_and_mult(x, x_in) |
|
|
|
|
p = get_area_and_mult(x, x_in, cond_concat_in) |
|
|
|
|
if p is None: |
|
|
|
|
continue |
|
|
|
|
|
|
|
|
@ -80,7 +114,7 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale):
|
|
|
|
|
to_batch_temp = [] |
|
|
|
|
for x in range(len(to_run)): |
|
|
|
|
if to_run[x][0][0].shape == first_shape: |
|
|
|
|
if to_run[x][0][2].shape == first[0][2].shape: |
|
|
|
|
if cond_equal_size(to_run[x][0][2], first[0][2]): |
|
|
|
|
to_batch_temp += [x] |
|
|
|
|
|
|
|
|
|
to_batch_temp.reverse() |
|
|
|
@ -108,7 +142,7 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale):
|
|
|
|
|
|
|
|
|
|
batch_chunks = len(cond_or_uncond) |
|
|
|
|
input_x = torch.cat(input_x) |
|
|
|
|
c = torch.cat(c) |
|
|
|
|
c = cond_cat(c) |
|
|
|
|
sigma_ = torch.cat([sigma] * batch_chunks) |
|
|
|
|
|
|
|
|
|
output = model_function(input_x, sigma_, cond=c).chunk(batch_chunks) |
|
|
|
@ -132,18 +166,18 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
max_total_area = model_management.maximum_batch_area() |
|
|
|
|
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, sigma, max_total_area) |
|
|
|
|
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, sigma, max_total_area, cond_concat) |
|
|
|
|
return uncond + (cond - uncond) * cond_scale |
|
|
|
|
|
|
|
|
|
class CFGDenoiserComplex(torch.nn.Module): |
|
|
|
|
def __init__(self, model): |
|
|
|
|
super().__init__() |
|
|
|
|
self.inner_model = model |
|
|
|
|
def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask): |
|
|
|
|
def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None): |
|
|
|
|
if denoise_mask is not None: |
|
|
|
|
latent_mask = 1. - denoise_mask |
|
|
|
|
x = x * denoise_mask + (self.latent_image + self.noise * sigma) * latent_mask |
|
|
|
|
out = sampling_function(self.inner_model, x, sigma, uncond, cond, cond_scale) |
|
|
|
|
out = sampling_function(self.inner_model, x, sigma, uncond, cond, cond_scale, cond_concat) |
|
|
|
|
if denoise_mask is not None: |
|
|
|
|
out *= denoise_mask |
|
|
|
|
|
|
|
|
@ -159,6 +193,17 @@ def simple_scheduler(model, steps):
|
|
|
|
|
sigs += [0.0] |
|
|
|
|
return torch.FloatTensor(sigs) |
|
|
|
|
|
|
|
|
|
def blank_inpaint_image_like(latent_image): |
|
|
|
|
blank_image = torch.ones_like(latent_image) |
|
|
|
|
# these are the values for "zero" in pixel space translated to latent space |
|
|
|
|
# the proper way to do this is to apply the mask to the image in pixel space and then send it through the VAE |
|
|
|
|
# unfortunately that gives zero flexibility so I did things like this instead which hopefully works |
|
|
|
|
blank_image[:,0] *= 0.8223 |
|
|
|
|
blank_image[:,1] *= -0.6876 |
|
|
|
|
blank_image[:,2] *= 0.6364 |
|
|
|
|
blank_image[:,3] *= 0.1380 |
|
|
|
|
return blank_image |
|
|
|
|
|
|
|
|
|
def create_cond_with_same_area_if_none(conds, c): |
|
|
|
|
if 'area' not in c[1]: |
|
|
|
|
return |
|
|
|
@ -276,11 +321,24 @@ class KSampler:
|
|
|
|
|
else: |
|
|
|
|
precision_scope = contextlib.nullcontext |
|
|
|
|
|
|
|
|
|
latent_mask = None |
|
|
|
|
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg} |
|
|
|
|
|
|
|
|
|
if hasattr(self.model, 'concat_keys'): |
|
|
|
|
cond_concat = [] |
|
|
|
|
for ck in self.model.concat_keys: |
|
|
|
|
if denoise_mask is not None: |
|
|
|
|
latent_mask = (torch.ones_like(denoise_mask) - denoise_mask) |
|
|
|
|
if ck == "mask": |
|
|
|
|
cond_concat.append(denoise_mask[:,:1]) |
|
|
|
|
elif ck == "masked_image": |
|
|
|
|
blank_image = blank_inpaint_image_like(latent_image) |
|
|
|
|
cond_concat.append(latent_image * (1.0 - denoise_mask) + denoise_mask * blank_image) |
|
|
|
|
else: |
|
|
|
|
if ck == "mask": |
|
|
|
|
cond_concat.append(torch.ones_like(noise)[:,:1]) |
|
|
|
|
elif ck == "masked_image": |
|
|
|
|
cond_concat.append(blank_inpaint_image_like(noise)) |
|
|
|
|
extra_args["cond_concat"] = cond_concat |
|
|
|
|
|
|
|
|
|
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg} |
|
|
|
|
with precision_scope(self.device): |
|
|
|
|
if self.sampler == "uni_pc": |
|
|
|
|
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, extra_args=extra_args, noise_mask=denoise_mask) |
|
|
|
|