diff --git a/comfy/taesd/taesd.py b/comfy/taesd/taesd.py index 46f3097a..8f96c54e 100644 --- a/comfy/taesd/taesd.py +++ b/comfy/taesd/taesd.py @@ -7,9 +7,10 @@ import torch import torch.nn as nn import comfy.utils +import comfy.ops def conv(n_in, n_out, **kwargs): - return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs) + return comfy.ops.disable_weight_init.Conv2d(n_in, n_out, 3, padding=1, **kwargs) class Clamp(nn.Module): def forward(self, x): @@ -19,7 +20,7 @@ class Block(nn.Module): def __init__(self, n_in, n_out): super().__init__() self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out)) - self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity() + self.skip = comfy.ops.disable_weight_init.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity() self.fuse = nn.ReLU() def forward(self, x): return self.fuse(self.conv(x) + self.skip(x))