diff --git a/comfy/extra_samplers/uni_pc.py b/comfy/extra_samplers/uni_pc.py index 08bf0fc9..a30d1d03 100644 --- a/comfy/extra_samplers/uni_pc.py +++ b/comfy/extra_samplers/uni_pc.py @@ -358,9 +358,6 @@ class UniPC: thresholding=False, max_val=1., variant='bh1', - noise_mask=None, - masked_image=None, - noise=None, ): """Construct a UniPC. @@ -372,9 +369,6 @@ class UniPC: self.predict_x0 = predict_x0 self.thresholding = thresholding self.max_val = max_val - self.noise_mask = noise_mask - self.masked_image = masked_image - self.noise = noise def dynamic_thresholding_fn(self, x0, t=None): """ @@ -391,10 +385,7 @@ class UniPC: """ Return the noise prediction model. """ - if self.noise_mask is not None: - return self.model(x, t) * self.noise_mask - else: - return self.model(x, t) + return self.model(x, t) def data_prediction_fn(self, x, t): """ @@ -409,8 +400,6 @@ class UniPC: s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims) x0 = torch.clamp(x0, -s, s) / s - if self.noise_mask is not None: - x0 = x0 * self.noise_mask + (1. - self.noise_mask) * self.masked_image return x0 def model_fn(self, x, t): @@ -723,8 +712,6 @@ class UniPC: assert timesteps.shape[0] - 1 == steps # with torch.no_grad(): for step_index in trange(steps, disable=disable_pbar): - if self.noise_mask is not None: - x = x * self.noise_mask + (1. - self.noise_mask) * (self.masked_image * self.noise_schedule.marginal_alpha(timesteps[step_index]) + self.noise * self.noise_schedule.marginal_std(timesteps[step_index])) if step_index == 0: vec_t = timesteps[0].expand((x.shape[0])) model_prev_list = [self.model_fn(x, vec_t)] @@ -766,7 +753,7 @@ class UniPC: model_x = self.model_fn(x, vec_t) model_prev_list[-1] = model_x if callback is not None: - callback(step_index, model_prev_list[-1], x, steps) + callback({'x': x, 'i': step_index, 'denoised': model_prev_list[-1]}) else: raise NotImplementedError() # if denoise_to_zero: @@ -858,7 +845,7 @@ def predict_eps_sigma(model, input, sigma_in, **kwargs): return (input - model(input, sigma_in, **kwargs)) / sigma -def sample_unipc(model, noise, image, sigmas, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'): +def sample_unipc(model, noise, sigmas, extra_args=None, callback=None, disable=False, variant='bh1'): timesteps = sigmas.clone() if sigmas[-1] == 0: timesteps = sigmas[:] @@ -867,16 +854,7 @@ def sample_unipc(model, noise, image, sigmas, max_denoise, extra_args=None, call timesteps = sigmas.clone() ns = SigmaConvert() - if image is not None: - img = image * ns.marginal_alpha(timesteps[0]) - if max_denoise: - noise_mult = 1.0 - else: - noise_mult = ns.marginal_std(timesteps[0]) - img += noise * noise_mult - else: - img = noise - + noise = noise / torch.sqrt(1.0 + timesteps[0] ** 2.0) model_type = "noise" model_fn = model_wrapper( @@ -888,7 +866,10 @@ def sample_unipc(model, noise, image, sigmas, max_denoise, extra_args=None, call ) order = min(3, len(timesteps) - 2) - uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise, variant=variant) - x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable) + uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, variant=variant) + x = uni_pc.sample(noise, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable) x /= ns.marginal_alpha(timesteps[-1]) return x + +def sample_unipc_bh2(model, noise, sigmas, extra_args=None, callback=None, disable=False): + return sample_unipc(model, noise, sigmas, extra_args, callback, disable, variant='bh2') \ No newline at end of file diff --git a/comfy/samplers.py b/comfy/samplers.py index c795f208..9c45eb60 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -513,14 +513,6 @@ class Sampler: sigma = float(sigmas[0]) return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma -class UNIPC(Sampler): - def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): - return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar) - -class UNIPCBH2(Sampler): - def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): - return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar) - KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "heunpp2","dpm_2", "dpm_2_ancestral", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"] @@ -640,9 +632,9 @@ def calculate_sigmas_scheduler(model, scheduler_name, steps): def sampler_object(name): if name == "uni_pc": - sampler = UNIPC() + sampler = KSAMPLER(uni_pc.sample_unipc) elif name == "uni_pc_bh2": - sampler = UNIPCBH2() + sampler = KSAMPLER(uni_pc.sample_unipc_bh2) elif name == "ddim": sampler = ksampler("euler", inpaint_options={"random": True}) else: