Browse Source

Cleanup uni_pc inpainting.

This causes some small changes to the uni pc inpainting behavior but it
seems to improve results slightly.
pull/2881/head
comfyanonymous 9 months ago
parent
commit
10847dfafe
  1. 37
      comfy/extra_samplers/uni_pc.py
  2. 12
      comfy/samplers.py

37
comfy/extra_samplers/uni_pc.py

@ -358,9 +358,6 @@ class UniPC:
thresholding=False, thresholding=False,
max_val=1., max_val=1.,
variant='bh1', variant='bh1',
noise_mask=None,
masked_image=None,
noise=None,
): ):
"""Construct a UniPC. """Construct a UniPC.
@ -372,9 +369,6 @@ class UniPC:
self.predict_x0 = predict_x0 self.predict_x0 = predict_x0
self.thresholding = thresholding self.thresholding = thresholding
self.max_val = max_val self.max_val = max_val
self.noise_mask = noise_mask
self.masked_image = masked_image
self.noise = noise
def dynamic_thresholding_fn(self, x0, t=None): def dynamic_thresholding_fn(self, x0, t=None):
""" """
@ -391,10 +385,7 @@ class UniPC:
""" """
Return the noise prediction model. Return the noise prediction model.
""" """
if self.noise_mask is not None: return self.model(x, t)
return self.model(x, t) * self.noise_mask
else:
return self.model(x, t)
def data_prediction_fn(self, x, t): def data_prediction_fn(self, x, t):
""" """
@ -409,8 +400,6 @@ class UniPC:
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims) s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
x0 = torch.clamp(x0, -s, s) / s x0 = torch.clamp(x0, -s, s) / s
if self.noise_mask is not None:
x0 = x0 * self.noise_mask + (1. - self.noise_mask) * self.masked_image
return x0 return x0
def model_fn(self, x, t): def model_fn(self, x, t):
@ -723,8 +712,6 @@ class UniPC:
assert timesteps.shape[0] - 1 == steps assert timesteps.shape[0] - 1 == steps
# with torch.no_grad(): # with torch.no_grad():
for step_index in trange(steps, disable=disable_pbar): for step_index in trange(steps, disable=disable_pbar):
if self.noise_mask is not None:
x = x * self.noise_mask + (1. - self.noise_mask) * (self.masked_image * self.noise_schedule.marginal_alpha(timesteps[step_index]) + self.noise * self.noise_schedule.marginal_std(timesteps[step_index]))
if step_index == 0: if step_index == 0:
vec_t = timesteps[0].expand((x.shape[0])) vec_t = timesteps[0].expand((x.shape[0]))
model_prev_list = [self.model_fn(x, vec_t)] model_prev_list = [self.model_fn(x, vec_t)]
@ -766,7 +753,7 @@ class UniPC:
model_x = self.model_fn(x, vec_t) model_x = self.model_fn(x, vec_t)
model_prev_list[-1] = model_x model_prev_list[-1] = model_x
if callback is not None: if callback is not None:
callback(step_index, model_prev_list[-1], x, steps) callback({'x': x, 'i': step_index, 'denoised': model_prev_list[-1]})
else: else:
raise NotImplementedError() raise NotImplementedError()
# if denoise_to_zero: # if denoise_to_zero:
@ -858,7 +845,7 @@ def predict_eps_sigma(model, input, sigma_in, **kwargs):
return (input - model(input, sigma_in, **kwargs)) / sigma return (input - model(input, sigma_in, **kwargs)) / sigma
def sample_unipc(model, noise, image, sigmas, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'): def sample_unipc(model, noise, sigmas, extra_args=None, callback=None, disable=False, variant='bh1'):
timesteps = sigmas.clone() timesteps = sigmas.clone()
if sigmas[-1] == 0: if sigmas[-1] == 0:
timesteps = sigmas[:] timesteps = sigmas[:]
@ -867,16 +854,7 @@ def sample_unipc(model, noise, image, sigmas, max_denoise, extra_args=None, call
timesteps = sigmas.clone() timesteps = sigmas.clone()
ns = SigmaConvert() ns = SigmaConvert()
if image is not None: noise = noise / torch.sqrt(1.0 + timesteps[0] ** 2.0)
img = image * ns.marginal_alpha(timesteps[0])
if max_denoise:
noise_mult = 1.0
else:
noise_mult = ns.marginal_std(timesteps[0])
img += noise * noise_mult
else:
img = noise
model_type = "noise" model_type = "noise"
model_fn = model_wrapper( model_fn = model_wrapper(
@ -888,7 +866,10 @@ def sample_unipc(model, noise, image, sigmas, max_denoise, extra_args=None, call
) )
order = min(3, len(timesteps) - 2) order = min(3, len(timesteps) - 2)
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise, variant=variant) uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, variant=variant)
x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable) x = uni_pc.sample(noise, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable)
x /= ns.marginal_alpha(timesteps[-1]) x /= ns.marginal_alpha(timesteps[-1])
return x return x
def sample_unipc_bh2(model, noise, sigmas, extra_args=None, callback=None, disable=False):
return sample_unipc(model, noise, sigmas, extra_args, callback, disable, variant='bh2')

12
comfy/samplers.py

@ -513,14 +513,6 @@ class Sampler:
sigma = float(sigmas[0]) sigma = float(sigmas[0])
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
class UNIPC(Sampler):
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar)
class UNIPCBH2(Sampler):
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar)
KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "heunpp2","dpm_2", "dpm_2_ancestral", KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"] "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"]
@ -640,9 +632,9 @@ def calculate_sigmas_scheduler(model, scheduler_name, steps):
def sampler_object(name): def sampler_object(name):
if name == "uni_pc": if name == "uni_pc":
sampler = UNIPC() sampler = KSAMPLER(uni_pc.sample_unipc)
elif name == "uni_pc_bh2": elif name == "uni_pc_bh2":
sampler = UNIPCBH2() sampler = KSAMPLER(uni_pc.sample_unipc_bh2)
elif name == "ddim": elif name == "ddim":
sampler = ksampler("euler", inpaint_options={"random": True}) sampler = ksampler("euler", inpaint_options={"random": True})
else: else:

Loading…
Cancel
Save