|
|
|
@ -713,8 +713,8 @@ class UniPC:
|
|
|
|
|
method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver', |
|
|
|
|
atol=0.0078, rtol=0.05, corrector=False, callback=None, disable_pbar=False |
|
|
|
|
): |
|
|
|
|
t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end |
|
|
|
|
t_T = self.noise_schedule.T if t_start is None else t_start |
|
|
|
|
# t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end |
|
|
|
|
# t_T = self.noise_schedule.T if t_start is None else t_start |
|
|
|
|
device = x.device |
|
|
|
|
steps = len(timesteps) - 1 |
|
|
|
|
if method == 'multistep': |
|
|
|
@ -769,8 +769,8 @@ class UniPC:
|
|
|
|
|
callback(step_index, model_prev_list[-1], x, steps) |
|
|
|
|
else: |
|
|
|
|
raise NotImplementedError() |
|
|
|
|
if denoise_to_zero: |
|
|
|
|
x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0) |
|
|
|
|
# if denoise_to_zero: |
|
|
|
|
# x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0) |
|
|
|
|
return x |
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -833,21 +833,33 @@ def expand_dims(v, dims):
|
|
|
|
|
return v[(...,) + (None,)*(dims - 1)] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SigmaConvert: |
|
|
|
|
schedule = "" |
|
|
|
|
def marginal_log_mean_coeff(self, sigma): |
|
|
|
|
return 0.5 * torch.log(1 / ((sigma * sigma) + 1)) |
|
|
|
|
|
|
|
|
|
def marginal_alpha(self, t): |
|
|
|
|
return torch.exp(self.marginal_log_mean_coeff(t)) |
|
|
|
|
|
|
|
|
|
def marginal_std(self, t): |
|
|
|
|
return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t))) |
|
|
|
|
|
|
|
|
|
def marginal_lambda(self, t): |
|
|
|
|
""" |
|
|
|
|
Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. |
|
|
|
|
""" |
|
|
|
|
log_mean_coeff = self.marginal_log_mean_coeff(t) |
|
|
|
|
log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff)) |
|
|
|
|
return log_mean_coeff - log_std |
|
|
|
|
|
|
|
|
|
def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'): |
|
|
|
|
to_zero = False |
|
|
|
|
timesteps = sigmas.clone() |
|
|
|
|
if sigmas[-1] == 0: |
|
|
|
|
timesteps = torch.nn.functional.interpolate(sigmas[None,None,:-1], size=(len(sigmas),), mode='linear')[0][0] |
|
|
|
|
to_zero = True |
|
|
|
|
timesteps = sigmas[:] |
|
|
|
|
timesteps[-1] = 0.001 |
|
|
|
|
else: |
|
|
|
|
timesteps = sigmas.clone() |
|
|
|
|
|
|
|
|
|
alphas_cumprod = model.inner_model.alphas_cumprod |
|
|
|
|
|
|
|
|
|
for s in range(timesteps.shape[0]): |
|
|
|
|
timesteps[s] = (model.sigma_to_discrete_timestep(timesteps[s]) / 1000) + (1 / len(alphas_cumprod)) |
|
|
|
|
|
|
|
|
|
ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod) |
|
|
|
|
ns = SigmaConvert() |
|
|
|
|
|
|
|
|
|
if image is not None: |
|
|
|
|
img = image * ns.marginal_alpha(timesteps[0]) |
|
|
|
@ -859,16 +871,10 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
|
|
|
|
|
else: |
|
|
|
|
img = noise |
|
|
|
|
|
|
|
|
|
if to_zero: |
|
|
|
|
timesteps[-1] = (1 / len(alphas_cumprod)) |
|
|
|
|
|
|
|
|
|
device = noise.device |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model_type = "noise" |
|
|
|
|
|
|
|
|
|
model_fn = model_wrapper( |
|
|
|
|
model.predict_eps_discrete_timestep, |
|
|
|
|
model.predict_eps_sigma, |
|
|
|
|
ns, |
|
|
|
|
model_type=model_type, |
|
|
|
|
guidance_type="uncond", |
|
|
|
@ -878,6 +884,5 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
|
|
|
|
|
order = min(3, len(timesteps) - 1) |
|
|
|
|
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise, variant=variant) |
|
|
|
|
x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable) |
|
|
|
|
if not to_zero: |
|
|
|
|
x /= ns.marginal_alpha(timesteps[-1]) |
|
|
|
|
x /= ns.marginal_alpha(timesteps[-1]) |
|
|
|
|
return x |
|
|
|
|