|
|
|
@ -1317,12 +1317,12 @@ class DiffusionWrapper(torch.nn.Module):
|
|
|
|
|
self.conditioning_key = conditioning_key |
|
|
|
|
assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm'] |
|
|
|
|
|
|
|
|
|
def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None, control=None): |
|
|
|
|
def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None, control=None, transformer_options={}): |
|
|
|
|
if self.conditioning_key is None: |
|
|
|
|
out = self.diffusion_model(x, t, control=control) |
|
|
|
|
out = self.diffusion_model(x, t, control=control, transformer_options=transformer_options) |
|
|
|
|
elif self.conditioning_key == 'concat': |
|
|
|
|
xc = torch.cat([x] + c_concat, dim=1) |
|
|
|
|
out = self.diffusion_model(xc, t, control=control) |
|
|
|
|
out = self.diffusion_model(xc, t, control=control, transformer_options=transformer_options) |
|
|
|
|
elif self.conditioning_key == 'crossattn': |
|
|
|
|
if not self.sequential_cross_attn: |
|
|
|
|
cc = torch.cat(c_crossattn, 1) |
|
|
|
@ -1332,25 +1332,25 @@ class DiffusionWrapper(torch.nn.Module):
|
|
|
|
|
# TorchScript changes names of the arguments |
|
|
|
|
# with argument cc defined as context=cc scripted model will produce |
|
|
|
|
# an error: RuntimeError: forward() is missing value for argument 'argument_3'. |
|
|
|
|
out = self.scripted_diffusion_model(x, t, cc, control=control) |
|
|
|
|
out = self.scripted_diffusion_model(x, t, cc, control=control, transformer_options=transformer_options) |
|
|
|
|
else: |
|
|
|
|
out = self.diffusion_model(x, t, context=cc, control=control) |
|
|
|
|
out = self.diffusion_model(x, t, context=cc, control=control, transformer_options=transformer_options) |
|
|
|
|
elif self.conditioning_key == 'hybrid': |
|
|
|
|
xc = torch.cat([x] + c_concat, dim=1) |
|
|
|
|
cc = torch.cat(c_crossattn, 1) |
|
|
|
|
out = self.diffusion_model(xc, t, context=cc, control=control) |
|
|
|
|
out = self.diffusion_model(xc, t, context=cc, control=control, transformer_options=transformer_options) |
|
|
|
|
elif self.conditioning_key == 'hybrid-adm': |
|
|
|
|
assert c_adm is not None |
|
|
|
|
xc = torch.cat([x] + c_concat, dim=1) |
|
|
|
|
cc = torch.cat(c_crossattn, 1) |
|
|
|
|
out = self.diffusion_model(xc, t, context=cc, y=c_adm, control=control) |
|
|
|
|
out = self.diffusion_model(xc, t, context=cc, y=c_adm, control=control, transformer_options=transformer_options) |
|
|
|
|
elif self.conditioning_key == 'crossattn-adm': |
|
|
|
|
assert c_adm is not None |
|
|
|
|
cc = torch.cat(c_crossattn, 1) |
|
|
|
|
out = self.diffusion_model(x, t, context=cc, y=c_adm, control=control) |
|
|
|
|
out = self.diffusion_model(x, t, context=cc, y=c_adm, control=control, transformer_options=transformer_options) |
|
|
|
|
elif self.conditioning_key == 'adm': |
|
|
|
|
cc = c_crossattn[0] |
|
|
|
|
out = self.diffusion_model(x, t, y=cc, control=control) |
|
|
|
|
out = self.diffusion_model(x, t, y=cc, control=control, transformer_options=transformer_options) |
|
|
|
|
else: |
|
|
|
|
raise NotImplementedError() |
|
|
|
|
|
|
|
|
|