diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index b596408d..9c2ea66b 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -56,7 +56,18 @@ class Upsample(nn.Module): padding=1) def forward(self, x): - x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + try: + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + except: #operation not implemented for bf16 + b, c, h, w = x.shape + out = torch.empty((b, c, h*2, w*2), dtype=x.dtype, layout=x.layout, device=x.device) + split = 8 + l = out.shape[1] // split + for i in range(0, out.shape[1], l): + out[:,i:i+l] = torch.nn.functional.interpolate(x[:,i:i+l].to(torch.float32), scale_factor=2.0, mode="nearest").to(x.dtype) + del x + x = out + if self.with_conv: x = self.conv(x) return x