|
|
@ -7,9 +7,10 @@ import torch |
|
|
|
import torch.nn as nn |
|
|
|
import torch.nn as nn |
|
|
|
|
|
|
|
|
|
|
|
import comfy.utils |
|
|
|
import comfy.utils |
|
|
|
|
|
|
|
import comfy.ops |
|
|
|
|
|
|
|
|
|
|
|
def conv(n_in, n_out, **kwargs): |
|
|
|
def conv(n_in, n_out, **kwargs): |
|
|
|
return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs) |
|
|
|
return comfy.ops.disable_weight_init.Conv2d(n_in, n_out, 3, padding=1, **kwargs) |
|
|
|
|
|
|
|
|
|
|
|
class Clamp(nn.Module): |
|
|
|
class Clamp(nn.Module): |
|
|
|
def forward(self, x): |
|
|
|
def forward(self, x): |
|
|
@ -19,7 +20,7 @@ class Block(nn.Module): |
|
|
|
def __init__(self, n_in, n_out): |
|
|
|
def __init__(self, n_in, n_out): |
|
|
|
super().__init__() |
|
|
|
super().__init__() |
|
|
|
self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out)) |
|
|
|
self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out)) |
|
|
|
self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity() |
|
|
|
self.skip = comfy.ops.disable_weight_init.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity() |
|
|
|
self.fuse = nn.ReLU() |
|
|
|
self.fuse = nn.ReLU() |
|
|
|
def forward(self, x): |
|
|
|
def forward(self, x): |
|
|
|
return self.fuse(self.conv(x) + self.skip(x)) |
|
|
|
return self.fuse(self.conv(x) + self.skip(x)) |
|
|
|