
Fix uni_pc sampler math. This changes the images this sampler produces.

comfyanonymous committed 1 year ago · commit 4185324a1d (pull/1733/head)
1. comfy/extra_samplers/uni_pc.py  (51 changes)
2. comfy/k_diffusion/external.py  (4 changes)
3. comfy/samplers.py  (2 changes)

comfy/extra_samplers/uni_pc.py  (51 changes)

@@ -713,8 +713,8 @@ class UniPC:
         method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
         atol=0.0078, rtol=0.05, corrector=False, callback=None, disable_pbar=False
     ):
-        t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
-        t_T = self.noise_schedule.T if t_start is None else t_start
+        # t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
+        # t_T = self.noise_schedule.T if t_start is None else t_start
         device = x.device
         steps = len(timesteps) - 1
         if method == 'multistep':
@@ -769,8 +769,8 @@ class UniPC:
                         callback(step_index, model_prev_list[-1], x, steps)
         else:
             raise NotImplementedError()
-        if denoise_to_zero:
-            x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
+        # if denoise_to_zero:
+        #     x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
         return x
@@ -833,21 +833,33 @@ def expand_dims(v, dims):
     return v[(...,) + (None,)*(dims - 1)]
 
+class SigmaConvert:
+    schedule = ""
+    def marginal_log_mean_coeff(self, sigma):
+        return 0.5 * torch.log(1 / ((sigma * sigma) + 1))
+
+    def marginal_alpha(self, t):
+        return torch.exp(self.marginal_log_mean_coeff(t))
+
+    def marginal_std(self, t):
+        return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
+
+    def marginal_lambda(self, t):
+        """
+        Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
+        """
+        log_mean_coeff = self.marginal_log_mean_coeff(t)
+        log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
+        return log_mean_coeff - log_std
+
 def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'):
     to_zero = False
+    timesteps = sigmas.clone()
     if sigmas[-1] == 0:
-        timesteps = torch.nn.functional.interpolate(sigmas[None,None,:-1], size=(len(sigmas),), mode='linear')[0][0]
         to_zero = True
+        timesteps = sigmas[:]
+        timesteps[-1] = 0.001
     else:
         timesteps = sigmas.clone()
 
-    alphas_cumprod = model.inner_model.alphas_cumprod
-    for s in range(timesteps.shape[0]):
-        timesteps[s] = (model.sigma_to_discrete_timestep(timesteps[s]) / 1000) + (1 / len(alphas_cumprod))
-    ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
+    ns = SigmaConvert()
 
     if image is not None:
         img = image * ns.marginal_alpha(timesteps[0])
@@ -859,16 +871,10 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
     else:
         img = noise
 
-    if to_zero:
-        timesteps[-1] = (1 / len(alphas_cumprod))
-
     device = noise.device
 
     model_type = "noise"
 
     model_fn = model_wrapper(
-        model.predict_eps_discrete_timestep,
+        model.predict_eps_sigma,
         ns,
         model_type=model_type,
         guidance_type="uncond",
@@ -878,6 +884,5 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
     order = min(3, len(timesteps) - 1)
     uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise, variant=variant)
     x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable)
-    if not to_zero:
-        x /= ns.marginal_alpha(timesteps[-1])
+    x /= ns.marginal_alpha(timesteps[-1])
     return x
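A quick sanity check of the SigmaConvert math above (illustration only, not part of the commit): if the k-diffusion sigma is read as sigma_t / alpha_t in a variance-preserving parameterization x_t = alpha_t * x0 + sigma_t * eps, the formulas reduce to alpha_t = 1 / sqrt(sigma^2 + 1), sigma_t = sigma * alpha_t and lambda_t = -log(sigma). The helper name below is made up for the check.

    import torch

    def vp_coeffs_from_sigma(sigma):
        # Mirrors SigmaConvert: alpha_t = 1 / sqrt(sigma**2 + 1)
        log_alpha = 0.5 * torch.log(1 / (sigma * sigma + 1))                # marginal_log_mean_coeff
        alpha = torch.exp(log_alpha)                                        # marginal_alpha
        std = torch.sqrt(1. - torch.exp(2. * log_alpha))                    # marginal_std
        lam = log_alpha - 0.5 * torch.log(1. - torch.exp(2. * log_alpha))   # marginal_lambda
        return alpha, std, lam

    sigma = torch.tensor(2.0)
    alpha, std, lam = vp_coeffs_from_sigma(sigma)
    print(torch.allclose(std, sigma * alpha))      # True: sigma_t = sigma * alpha_t
    print(torch.allclose(lam, -torch.log(sigma)))  # True: lambda_t = -log(sigma)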

comfy/k_diffusion/external.py  (4 changes)

@@ -97,6 +97,10 @@ class DiscreteSchedule(nn.Module):
         input = input * ((utils.append_dims(sigma, input.ndim) ** 2 + 1.0) ** 0.5)
         return (input - self(input, sigma, **kwargs)) / utils.append_dims(sigma, input.ndim)
 
+    def predict_eps_sigma(self, input, sigma, **kwargs):
+        input = input * ((utils.append_dims(sigma, input.ndim) ** 2 + 1.0) ** 0.5)
+        return (input - self(input, sigma, **kwargs)) / utils.append_dims(sigma, input.ndim)
+
 
 class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
     """A wrapper for discrete schedule DDPM models that output eps (the predicted
     noise)."""
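For context, a rough standalone sketch of what the new predict_eps_sigma appears to compute, with a scalar sigma for simplicity (the real method broadcasts sigma via utils.append_dims): the incoming latent is taken to be in VP space, scaling by sqrt(sigma^2 + 1) undoes the alpha_t factor, and the denoiser's x0 prediction is converted into an eps prediction. Function names here are placeholders.

    import torch

    def predict_eps_sigma_sketch(denoiser, x_vp, sigma):
        # x_vp assumed to be alpha_t * (x0 + sigma * eps) with alpha_t = 1 / sqrt(sigma**2 + 1)
        x = x_vp * (sigma * sigma + 1.0) ** 0.5   # back to the k-diffusion latent x0 + sigma * eps
        x0 = denoiser(x, sigma)                   # k-diffusion denoiser predicts x0
        return (x - x0) / sigma                   # recover eps

    # Toy check with an "oracle" denoiser that returns the true x0.
    x0 = torch.randn(1, 4, 8, 8)
    eps = torch.randn_like(x0)
    sigma = torch.tensor(3.0)
    x_vp = (sigma * sigma + 1.0) ** -0.5 * (x0 + sigma * eps)
    print(torch.allclose(predict_eps_sigma_sketch(lambda x, s: x0, x_vp, sigma), eps, atol=1e-5))  # True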

comfy/samplers.py  (2 changes)

@@ -739,7 +739,7 @@ class KSampler:
         sigmas = None
         discard_penultimate_sigma = False
-        if self.sampler in ['dpm_2', 'dpm_2_ancestral']:
+        if self.sampler in ['dpm_2', 'dpm_2_ancestral', 'uni_pc', 'uni_pc_bh2']:
             steps += 1
             discard_penultimate_sigma = True
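uni_pc and uni_pc_bh2 now take the same path as dpm_2: one extra step is requested and the penultimate sigma is marked for discarding. The actual discard happens elsewhere in KSampler; the sketch below assumes it simply removes the second-to-last sigma, using made-up schedule values.

    import torch

    # Made-up schedule with steps + 1 = 6 sigmas (values are illustrative only).
    sigmas = torch.tensor([14.6146, 7.9216, 3.8915, 1.6129, 0.5374, 0.0])

    # Assumed discard_penultimate_sigma behavior: drop the second-to-last sigma so the
    # schedule still ends at 0 but with the originally requested number of steps.
    sigmas = torch.cat((sigmas[:-2], sigmas[-1:]))
    print(sigmas)  # tensor([14.6146, 7.9216, 3.8915, 1.6129, 0.0000])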
