@@ -29,6 +29,39 @@ def conv_nd(dims, *args, **kwargs):
     else:
         raise ValueError(f"unsupported dimensions: {dims}")

+def cast_bias_weight(s, input):
+    bias = None
+    if s.bias is not None:
+        bias = s.bias.to(device=input.device, dtype=input.dtype)
+    weight = s.weight.to(device=input.device, dtype=input.dtype)
+    return weight, bias
+
+class manual_cast:
+    class Linear(Linear):
+        def forward(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return torch.nn.functional.linear(input, weight, bias)
+
+    class Conv2d(Conv2d):
+        def forward(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return self._conv_forward(input, weight, bias)
+
+    class Conv3d(Conv3d):
+        def forward(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return self._conv_forward(input, weight, bias)
+
+    class GroupNorm(GroupNorm):
+        def forward(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
+
+    class LayerNorm(LayerNorm):
+        def forward(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
+
 @contextmanager
 def use_comfy_ops(device=None, dtype=None): # Kind of an ugly hack but I can't think of a better way
     old_torch_nn_linear = torch.nn.Linear