|
|
|
@ -66,6 +66,9 @@ class BatchedBrownianTree:
|
|
|
|
|
"""A wrapper around torchsde.BrownianTree that enables batches of entropy.""" |
|
|
|
|
|
|
|
|
|
def __init__(self, x, t0, t1, seed=None, **kwargs): |
|
|
|
|
self.cpu_tree = True |
|
|
|
|
if "cpu" in kwargs: |
|
|
|
|
self.cpu_tree = kwargs.pop("cpu") |
|
|
|
|
t0, t1, self.sign = self.sort(t0, t1) |
|
|
|
|
w0 = kwargs.get('w0', torch.zeros_like(x)) |
|
|
|
|
if seed is None: |
|
|
|
@ -77,7 +80,10 @@ class BatchedBrownianTree:
|
|
|
|
|
except TypeError: |
|
|
|
|
seed = [seed] |
|
|
|
|
self.batched = False |
|
|
|
|
self.trees = [torchsde.BrownianTree(t0.cpu(), w0.cpu(), t1.cpu(), entropy=s, **kwargs) for s in seed] |
|
|
|
|
if self.cpu_tree: |
|
|
|
|
self.trees = [torchsde.BrownianTree(t0.cpu(), w0.cpu(), t1.cpu(), entropy=s, **kwargs) for s in seed] |
|
|
|
|
else: |
|
|
|
|
self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed] |
|
|
|
|
|
|
|
|
|
@staticmethod |
|
|
|
|
def sort(a, b): |
|
|
|
@ -85,7 +91,11 @@ class BatchedBrownianTree:
|
|
|
|
|
|
|
|
|
|
def __call__(self, t0, t1): |
|
|
|
|
t0, t1, sign = self.sort(t0, t1) |
|
|
|
|
w = torch.stack([tree(t0.cpu().float(), t1.cpu().float()).to(t0.dtype).to(t0.device) for tree in self.trees]) * (self.sign * sign) |
|
|
|
|
if self.cpu_tree: |
|
|
|
|
w = torch.stack([tree(t0.cpu().float(), t1.cpu().float()).to(t0.dtype).to(t0.device) for tree in self.trees]) * (self.sign * sign) |
|
|
|
|
else: |
|
|
|
|
w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign) |
|
|
|
|
|
|
|
|
|
return w if self.batched else w[0] |
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -104,10 +114,10 @@ class BrownianTreeNoiseSampler:
|
|
|
|
|
internal timestep. |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x): |
|
|
|
|
def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x, cpu=False): |
|
|
|
|
self.transform = transform |
|
|
|
|
t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max)) |
|
|
|
|
self.tree = BatchedBrownianTree(x, t0, t1, seed) |
|
|
|
|
self.tree = BatchedBrownianTree(x, t0, t1, seed, cpu=cpu) |
|
|
|
|
|
|
|
|
|
def __call__(self, sigma, sigma_next): |
|
|
|
|
t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next)) |
|
|
|
@ -544,7 +554,7 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N
|
|
|
|
|
"""DPM-Solver++ (stochastic).""" |
|
|
|
|
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() |
|
|
|
|
seed = extra_args.get("seed", None) |
|
|
|
|
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed) if noise_sampler is None else noise_sampler |
|
|
|
|
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler |
|
|
|
|
extra_args = {} if extra_args is None else extra_args |
|
|
|
|
s_in = x.new_ones([x.shape[0]]) |
|
|
|
|
sigma_fn = lambda t: t.neg().exp() |
|
|
|
@ -616,7 +626,7 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
|
|
|
|
|
|
|
|
|
seed = extra_args.get("seed", None) |
|
|
|
|
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() |
|
|
|
|
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed) if noise_sampler is None else noise_sampler |
|
|
|
|
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler |
|
|
|
|
extra_args = {} if extra_args is None else extra_args |
|
|
|
|
s_in = x.new_ones([x.shape[0]]) |
|
|
|
|
|
|
|
|
@ -651,3 +661,18 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
|
|
|
|
old_denoised = denoised |
|
|
|
|
h_last = h |
|
|
|
|
return x |
|
|
|
|
|
|
|
|
|
@torch.no_grad() |
|
|
|
|
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): |
|
|
|
|
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() |
|
|
|
|
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler |
|
|
|
|
return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad() |
|
|
|
|
def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2): |
|
|
|
|
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() |
|
|
|
|
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler |
|
|
|
|
return sample_dpmpp_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=r) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|