|
|
|
@ -134,7 +134,7 @@ class CompVisDenoiser(DiscreteEpsDDPMDenoiser):
|
|
|
|
|
"""A wrapper for CompVis diffusion models.""" |
|
|
|
|
|
|
|
|
|
def __init__(self, model, quantize=False, device='cpu'): |
|
|
|
|
super().__init__(model, model.alphas_cumprod, quantize=quantize) |
|
|
|
|
super().__init__(model, model.alphas_cumprod.float(), quantize=quantize) |
|
|
|
|
|
|
|
|
|
def get_eps(self, *args, **kwargs): |
|
|
|
|
return self.inner_model.apply_model(*args, **kwargs) |
|
|
|
@ -173,7 +173,7 @@ class CompVisVDenoiser(DiscreteVDDPMDenoiser):
|
|
|
|
|
"""A wrapper for CompVis diffusion models that output v.""" |
|
|
|
|
|
|
|
|
|
def __init__(self, model, quantize=False, device='cpu'): |
|
|
|
|
super().__init__(model, model.alphas_cumprod, quantize=quantize) |
|
|
|
|
super().__init__(model, model.alphas_cumprod.float(), quantize=quantize) |
|
|
|
|
|
|
|
|
|
def get_v(self, x, t, cond, **kwargs): |
|
|
|
|
return self.inner_model.apply_model(x, t, cond) |
|
|
|
|