|
|
@ -358,9 +358,6 @@ class UniPC: |
|
|
|
thresholding=False, |
|
|
|
thresholding=False, |
|
|
|
max_val=1., |
|
|
|
max_val=1., |
|
|
|
variant='bh1', |
|
|
|
variant='bh1', |
|
|
|
noise_mask=None, |
|
|
|
|
|
|
|
masked_image=None, |
|
|
|
|
|
|
|
noise=None, |
|
|
|
|
|
|
|
): |
|
|
|
): |
|
|
|
"""Construct a UniPC. |
|
|
|
"""Construct a UniPC. |
|
|
|
|
|
|
|
|
|
|
@ -372,9 +369,6 @@ class UniPC: |
|
|
|
self.predict_x0 = predict_x0 |
|
|
|
self.predict_x0 = predict_x0 |
|
|
|
self.thresholding = thresholding |
|
|
|
self.thresholding = thresholding |
|
|
|
self.max_val = max_val |
|
|
|
self.max_val = max_val |
|
|
|
self.noise_mask = noise_mask |
|
|
|
|
|
|
|
self.masked_image = masked_image |
|
|
|
|
|
|
|
self.noise = noise |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def dynamic_thresholding_fn(self, x0, t=None): |
|
|
|
def dynamic_thresholding_fn(self, x0, t=None): |
|
|
|
""" |
|
|
|
""" |
|
|
@ -391,10 +385,7 @@ class UniPC: |
|
|
|
""" |
|
|
|
""" |
|
|
|
Return the noise prediction model. |
|
|
|
Return the noise prediction model. |
|
|
|
""" |
|
|
|
""" |
|
|
|
if self.noise_mask is not None: |
|
|
|
return self.model(x, t) |
|
|
|
return self.model(x, t) * self.noise_mask |
|
|
|
|
|
|
|
else: |
|
|
|
|
|
|
|
return self.model(x, t) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def data_prediction_fn(self, x, t): |
|
|
|
def data_prediction_fn(self, x, t): |
|
|
|
""" |
|
|
|
""" |
|
|
@ -409,8 +400,6 @@ class UniPC: |
|
|
|
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) |
|
|
|
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) |
|
|
|
s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims) |
|
|
|
s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims) |
|
|
|
x0 = torch.clamp(x0, -s, s) / s |
|
|
|
x0 = torch.clamp(x0, -s, s) / s |
|
|
|
if self.noise_mask is not None: |
|
|
|
|
|
|
|
x0 = x0 * self.noise_mask + (1. - self.noise_mask) * self.masked_image |
|
|
|
|
|
|
|
return x0 |
|
|
|
return x0 |
|
|
|
|
|
|
|
|
|
|
|
def model_fn(self, x, t): |
|
|
|
def model_fn(self, x, t): |
|
|
@ -723,8 +712,6 @@ class UniPC: |
|
|
|
assert timesteps.shape[0] - 1 == steps |
|
|
|
assert timesteps.shape[0] - 1 == steps |
|
|
|
# with torch.no_grad(): |
|
|
|
# with torch.no_grad(): |
|
|
|
for step_index in trange(steps, disable=disable_pbar): |
|
|
|
for step_index in trange(steps, disable=disable_pbar): |
|
|
|
if self.noise_mask is not None: |
|
|
|
|
|
|
|
x = x * self.noise_mask + (1. - self.noise_mask) * (self.masked_image * self.noise_schedule.marginal_alpha(timesteps[step_index]) + self.noise * self.noise_schedule.marginal_std(timesteps[step_index])) |
|
|
|
|
|
|
|
if step_index == 0: |
|
|
|
if step_index == 0: |
|
|
|
vec_t = timesteps[0].expand((x.shape[0])) |
|
|
|
vec_t = timesteps[0].expand((x.shape[0])) |
|
|
|
model_prev_list = [self.model_fn(x, vec_t)] |
|
|
|
model_prev_list = [self.model_fn(x, vec_t)] |
|
|
@ -766,7 +753,7 @@ class UniPC: |
|
|
|
model_x = self.model_fn(x, vec_t) |
|
|
|
model_x = self.model_fn(x, vec_t) |
|
|
|
model_prev_list[-1] = model_x |
|
|
|
model_prev_list[-1] = model_x |
|
|
|
if callback is not None: |
|
|
|
if callback is not None: |
|
|
|
callback(step_index, model_prev_list[-1], x, steps) |
|
|
|
callback({'x': x, 'i': step_index, 'denoised': model_prev_list[-1]}) |
|
|
|
else: |
|
|
|
else: |
|
|
|
raise NotImplementedError() |
|
|
|
raise NotImplementedError() |
|
|
|
# if denoise_to_zero: |
|
|
|
# if denoise_to_zero: |
|
|
@ -858,7 +845,7 @@ def predict_eps_sigma(model, input, sigma_in, **kwargs): |
|
|
|
return (input - model(input, sigma_in, **kwargs)) / sigma |
|
|
|
return (input - model(input, sigma_in, **kwargs)) / sigma |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def sample_unipc(model, noise, image, sigmas, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'): |
|
|
|
def sample_unipc(model, noise, sigmas, extra_args=None, callback=None, disable=False, variant='bh1'): |
|
|
|
timesteps = sigmas.clone() |
|
|
|
timesteps = sigmas.clone() |
|
|
|
if sigmas[-1] == 0: |
|
|
|
if sigmas[-1] == 0: |
|
|
|
timesteps = sigmas[:] |
|
|
|
timesteps = sigmas[:] |
|
|
@ -867,16 +854,7 @@ def sample_unipc(model, noise, image, sigmas, max_denoise, extra_args=None, call |
|
|
|
timesteps = sigmas.clone() |
|
|
|
timesteps = sigmas.clone() |
|
|
|
ns = SigmaConvert() |
|
|
|
ns = SigmaConvert() |
|
|
|
|
|
|
|
|
|
|
|
if image is not None: |
|
|
|
noise = noise / torch.sqrt(1.0 + timesteps[0] ** 2.0) |
|
|
|
img = image * ns.marginal_alpha(timesteps[0]) |
|
|
|
|
|
|
|
if max_denoise: |
|
|
|
|
|
|
|
noise_mult = 1.0 |
|
|
|
|
|
|
|
else: |
|
|
|
|
|
|
|
noise_mult = ns.marginal_std(timesteps[0]) |
|
|
|
|
|
|
|
img += noise * noise_mult |
|
|
|
|
|
|
|
else: |
|
|
|
|
|
|
|
img = noise |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model_type = "noise" |
|
|
|
model_type = "noise" |
|
|
|
|
|
|
|
|
|
|
|
model_fn = model_wrapper( |
|
|
|
model_fn = model_wrapper( |
|
|
@ -888,7 +866,10 @@ def sample_unipc(model, noise, image, sigmas, max_denoise, extra_args=None, call |
|
|
|
) |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
order = min(3, len(timesteps) - 2) |
|
|
|
order = min(3, len(timesteps) - 2) |
|
|
|
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise, variant=variant) |
|
|
|
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, variant=variant) |
|
|
|
x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable) |
|
|
|
x = uni_pc.sample(noise, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable) |
|
|
|
x /= ns.marginal_alpha(timesteps[-1]) |
|
|
|
x /= ns.marginal_alpha(timesteps[-1]) |
|
|
|
return x |
|
|
|
return x |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def sample_unipc_bh2(model, noise, sigmas, extra_args=None, callback=None, disable=False): |
|
|
|
|
|
|
|
return sample_unipc(model, noise, sigmas, extra_args, callback, disable, variant='bh2') |