|
|
|
@ -56,7 +56,18 @@ class Upsample(nn.Module):
|
|
|
|
|
padding=1) |
|
|
|
|
|
|
|
|
|
def forward(self, x): |
|
|
|
|
try: |
|
|
|
|
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") |
|
|
|
|
except: #operation not implemented for bf16 |
|
|
|
|
b, c, h, w = x.shape |
|
|
|
|
out = torch.empty((b, c, h*2, w*2), dtype=x.dtype, layout=x.layout, device=x.device) |
|
|
|
|
split = 8 |
|
|
|
|
l = out.shape[1] // split |
|
|
|
|
for i in range(0, out.shape[1], l): |
|
|
|
|
out[:,i:i+l] = torch.nn.functional.interpolate(x[:,i:i+l].to(torch.float32), scale_factor=2.0, mode="nearest").to(x.dtype) |
|
|
|
|
del x |
|
|
|
|
x = out |
|
|
|
|
|
|
|
|
|
if self.with_conv: |
|
|
|
|
x = self.conv(x) |
|
|
|
|
return x |
|
|
|
|