|
|
@ -150,7 +150,11 @@ class ResBlock(nn.Module): |
|
|
|
mods = self.gammas |
|
|
|
mods = self.gammas |
|
|
|
|
|
|
|
|
|
|
|
x_temp = self._norm(x, self.norm1) * (1 + mods[0]) + mods[1] |
|
|
|
x_temp = self._norm(x, self.norm1) * (1 + mods[0]) + mods[1] |
|
|
|
|
|
|
|
try: |
|
|
|
x = x + self.depthwise(x_temp) * mods[2] |
|
|
|
x = x + self.depthwise(x_temp) * mods[2] |
|
|
|
|
|
|
|
except: #operation not implemented for bf16 |
|
|
|
|
|
|
|
x_temp = self.depthwise[0](x_temp.float()).to(x.dtype) |
|
|
|
|
|
|
|
x = x + self.depthwise[1](x_temp) * mods[2] |
|
|
|
|
|
|
|
|
|
|
|
x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4] |
|
|
|
x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4] |
|
|
|
x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5] |
|
|
|
x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5] |
|
|
|