|
|
@ -31,13 +31,13 @@ def cast_bias_weight(s, input): |
|
|
|
weight = s.weight_function(weight) |
|
|
|
weight = s.weight_function(weight) |
|
|
|
return weight, bias |
|
|
|
return weight, bias |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CastWeightBiasOp: |
|
|
|
class disable_weight_init: |
|
|
|
|
|
|
|
class Linear(torch.nn.Linear): |
|
|
|
|
|
|
|
comfy_cast_weights = False |
|
|
|
comfy_cast_weights = False |
|
|
|
weight_function = None |
|
|
|
weight_function = None |
|
|
|
bias_function = None |
|
|
|
bias_function = None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class disable_weight_init: |
|
|
|
|
|
|
|
class Linear(torch.nn.Linear, CastWeightBiasOp): |
|
|
|
def reset_parameters(self): |
|
|
|
def reset_parameters(self): |
|
|
|
return None |
|
|
|
return None |
|
|
|
|
|
|
|
|
|
|
@ -51,11 +51,7 @@ class disable_weight_init: |
|
|
|
else: |
|
|
|
else: |
|
|
|
return super().forward(*args, **kwargs) |
|
|
|
return super().forward(*args, **kwargs) |
|
|
|
|
|
|
|
|
|
|
|
class Conv2d(torch.nn.Conv2d): |
|
|
|
class Conv2d(torch.nn.Conv2d, CastWeightBiasOp): |
|
|
|
comfy_cast_weights = False |
|
|
|
|
|
|
|
weight_function = None |
|
|
|
|
|
|
|
bias_function = None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def reset_parameters(self): |
|
|
|
def reset_parameters(self): |
|
|
|
return None |
|
|
|
return None |
|
|
|
|
|
|
|
|
|
|
@ -69,11 +65,7 @@ class disable_weight_init: |
|
|
|
else: |
|
|
|
else: |
|
|
|
return super().forward(*args, **kwargs) |
|
|
|
return super().forward(*args, **kwargs) |
|
|
|
|
|
|
|
|
|
|
|
class Conv3d(torch.nn.Conv3d): |
|
|
|
class Conv3d(torch.nn.Conv3d, CastWeightBiasOp): |
|
|
|
comfy_cast_weights = False |
|
|
|
|
|
|
|
weight_function = None |
|
|
|
|
|
|
|
bias_function = None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def reset_parameters(self): |
|
|
|
def reset_parameters(self): |
|
|
|
return None |
|
|
|
return None |
|
|
|
|
|
|
|
|
|
|
@ -87,11 +79,7 @@ class disable_weight_init: |
|
|
|
else: |
|
|
|
else: |
|
|
|
return super().forward(*args, **kwargs) |
|
|
|
return super().forward(*args, **kwargs) |
|
|
|
|
|
|
|
|
|
|
|
class GroupNorm(torch.nn.GroupNorm): |
|
|
|
class GroupNorm(torch.nn.GroupNorm, CastWeightBiasOp): |
|
|
|
comfy_cast_weights = False |
|
|
|
|
|
|
|
weight_function = None |
|
|
|
|
|
|
|
bias_function = None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def reset_parameters(self): |
|
|
|
def reset_parameters(self): |
|
|
|
return None |
|
|
|
return None |
|
|
|
|
|
|
|
|
|
|
@ -106,11 +94,7 @@ class disable_weight_init: |
|
|
|
return super().forward(*args, **kwargs) |
|
|
|
return super().forward(*args, **kwargs) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LayerNorm(torch.nn.LayerNorm): |
|
|
|
class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp): |
|
|
|
comfy_cast_weights = False |
|
|
|
|
|
|
|
weight_function = None |
|
|
|
|
|
|
|
bias_function = None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def reset_parameters(self): |
|
|
|
def reset_parameters(self): |
|
|
|
return None |
|
|
|
return None |
|
|
|
|
|
|
|
|
|
|
@ -128,11 +112,7 @@ class disable_weight_init: |
|
|
|
else: |
|
|
|
else: |
|
|
|
return super().forward(*args, **kwargs) |
|
|
|
return super().forward(*args, **kwargs) |
|
|
|
|
|
|
|
|
|
|
|
class ConvTranspose2d(torch.nn.ConvTranspose2d): |
|
|
|
class ConvTranspose2d(torch.nn.ConvTranspose2d, CastWeightBiasOp): |
|
|
|
comfy_cast_weights = False |
|
|
|
|
|
|
|
weight_function = None |
|
|
|
|
|
|
|
bias_function = None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def reset_parameters(self): |
|
|
|
def reset_parameters(self): |
|
|
|
return None |
|
|
|
return None |
|
|
|
|
|
|
|
|
|
|
|