|
|
|
@ -8,16 +8,17 @@ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, mak
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DDIMSampler(object): |
|
|
|
|
def __init__(self, model, schedule="linear", **kwargs): |
|
|
|
|
def __init__(self, model, schedule="linear", device=torch.device("cuda"), **kwargs): |
|
|
|
|
super().__init__() |
|
|
|
|
self.model = model |
|
|
|
|
self.ddpm_num_timesteps = model.num_timesteps |
|
|
|
|
self.schedule = schedule |
|
|
|
|
self.device = device |
|
|
|
|
|
|
|
|
|
def register_buffer(self, name, attr): |
|
|
|
|
if type(attr) == torch.Tensor: |
|
|
|
|
if attr.device != torch.device("cuda"): |
|
|
|
|
attr = attr.to(torch.device("cuda")) |
|
|
|
|
if attr.device != self.device: |
|
|
|
|
attr = attr.to(self.device) |
|
|
|
|
setattr(self, name, attr) |
|
|
|
|
|
|
|
|
|
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): |
|
|
|
|