From 1f6a467e92a2adb73d8fc312b6a3ae46c59d14c2 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 9 Feb 2023 13:47:36 -0500 Subject: [PATCH] Update ldm dir with latest upstream stable diffusion changes. --- comfy/ldm/models/diffusion/ddim.py | 7 ++++--- comfy/ldm/models/diffusion/ddpm.py | 8 +++++++- comfy/ldm/models/diffusion/dpm_solver/sampler.py | 7 ++++--- comfy/ldm/models/diffusion/plms.py | 7 ++++--- comfy/ldm/modules/diffusionmodules/openaimodel.py | 2 ++ 5 files changed, 21 insertions(+), 10 deletions(-) diff --git a/comfy/ldm/models/diffusion/ddim.py b/comfy/ldm/models/diffusion/ddim.py index 27ead0ea..c6cfd571 100644 --- a/comfy/ldm/models/diffusion/ddim.py +++ b/comfy/ldm/models/diffusion/ddim.py @@ -8,16 +8,17 @@ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, mak class DDIMSampler(object): - def __init__(self, model, schedule="linear", **kwargs): + def __init__(self, model, schedule="linear", device=torch.device("cuda"), **kwargs): super().__init__() self.model = model self.ddpm_num_timesteps = model.num_timesteps self.schedule = schedule + self.device = device def register_buffer(self, name, attr): if type(attr) == torch.Tensor: - if attr.device != torch.device("cuda"): - attr = attr.to(torch.device("cuda")) + if attr.device != self.device: + attr = attr.to(self.device) setattr(self, name, attr) def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): diff --git a/comfy/ldm/models/diffusion/ddpm.py b/comfy/ldm/models/diffusion/ddpm.py index e297d27e..074919d0 100644 --- a/comfy/ldm/models/diffusion/ddpm.py +++ b/comfy/ldm/models/diffusion/ddpm.py @@ -1331,7 +1331,13 @@ class DiffusionWrapper(torch.nn.Module): cc = torch.cat(c_crossattn, 1) else: cc = c_crossattn - out = self.diffusion_model(x, t, context=cc) + if hasattr(self, "scripted_diffusion_model"): + # TorchScript changes names of the arguments + # with argument cc defined as context=cc scripted model will produce + # an error: RuntimeError: forward() is missing value for argument 'argument_3'. + out = self.scripted_diffusion_model(x, t, cc) + else: + out = self.diffusion_model(x, t, context=cc) elif self.conditioning_key == 'hybrid': xc = torch.cat([x] + c_concat, dim=1) cc = torch.cat(c_crossattn, 1) diff --git a/comfy/ldm/models/diffusion/dpm_solver/sampler.py b/comfy/ldm/models/diffusion/dpm_solver/sampler.py index 7d137b8c..4270c618 100644 --- a/comfy/ldm/models/diffusion/dpm_solver/sampler.py +++ b/comfy/ldm/models/diffusion/dpm_solver/sampler.py @@ -11,16 +11,17 @@ MODEL_TYPES = { class DPMSolverSampler(object): - def __init__(self, model, **kwargs): + def __init__(self, model, device=torch.device("cuda"), **kwargs): super().__init__() self.model = model + self.device = device to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) def register_buffer(self, name, attr): if type(attr) == torch.Tensor: - if attr.device != torch.device("cuda"): - attr = attr.to(torch.device("cuda")) + if attr.device != self.device: + attr = attr.to(self.device) setattr(self, name, attr) @torch.no_grad() diff --git a/comfy/ldm/models/diffusion/plms.py b/comfy/ldm/models/diffusion/plms.py index 7002a365..9d31b399 100644 --- a/comfy/ldm/models/diffusion/plms.py +++ b/comfy/ldm/models/diffusion/plms.py @@ -10,16 +10,17 @@ from ldm.models.diffusion.sampling_util import norm_thresholding class PLMSSampler(object): - def __init__(self, model, schedule="linear", **kwargs): + def __init__(self, model, schedule="linear", device=torch.device("cuda"), **kwargs): super().__init__() self.model = model self.ddpm_num_timesteps = model.num_timesteps self.schedule = schedule + self.device = device def register_buffer(self, name, attr): if type(attr) == torch.Tensor: - if attr.device != torch.device("cuda"): - attr = attr.to(torch.device("cuda")) + if attr.device != self.device: + attr = attr.to(self.device) setattr(self, name, attr) def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 7df6b5ab..764a34b8 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -454,6 +454,7 @@ class UNetModel(nn.Module): num_classes=None, use_checkpoint=False, use_fp16=False, + use_bf16=False, num_heads=-1, num_head_channels=-1, num_heads_upsample=-1, @@ -518,6 +519,7 @@ class UNetModel(nn.Module): self.num_classes = num_classes self.use_checkpoint = use_checkpoint self.dtype = th.float16 if use_fp16 else th.float32 + self.dtype = th.bfloat16 if use_bf16 else self.dtype self.num_heads = num_heads self.num_head_channels = num_head_channels self.num_heads_upsample = num_heads_upsample