Browse Source

Make --bf16-vae work on torch 2.0

pull/1353/head
comfyanonymous 1 year ago
parent
commit
d935ba50c4
  1. 13
      comfy/ldm/modules/diffusionmodules/model.py

13
comfy/ldm/modules/diffusionmodules/model.py

@ -56,7 +56,18 @@ class Upsample(nn.Module):
padding=1)
def forward(self, x):
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
try:
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
except: #operation not implemented for bf16
b, c, h, w = x.shape
out = torch.empty((b, c, h*2, w*2), dtype=x.dtype, layout=x.layout, device=x.device)
split = 8
l = out.shape[1] // split
for i in range(0, out.shape[1], l):
out[:,i:i+l] = torch.nn.functional.interpolate(x[:,i:i+l].to(torch.float32), scale_factor=2.0, mode="nearest").to(x.dtype)
del x
x = out
if self.with_conv:
x = self.conv(x)
return x

Loading…
Cancel
Save