|
|
|
@ -1,4 +1,5 @@
|
|
|
|
|
import torch |
|
|
|
|
from contextlib import contextmanager |
|
|
|
|
|
|
|
|
|
class Linear(torch.nn.Module): |
|
|
|
|
def __init__(self, in_features: int, out_features: int, bias: bool = True, |
|
|
|
@ -19,3 +20,13 @@ class Linear(torch.nn.Module):
|
|
|
|
|
class Conv2d(torch.nn.Conv2d): |
|
|
|
|
def reset_parameters(self): |
|
|
|
|
return None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@contextmanager |
|
|
|
|
def use_comfy_ops(): # Kind of an ugly hack but I can't think of a better way |
|
|
|
|
old_torch_nn_linear = torch.nn.Linear |
|
|
|
|
torch.nn.Linear = Linear |
|
|
|
|
try: |
|
|
|
|
yield |
|
|
|
|
finally: |
|
|
|
|
torch.nn.Linear = old_torch_nn_linear |
|
|
|
|