comfyanonymous
9 months ago
1 changed files with 96 additions and 0 deletions
@ -0,0 +1,96 @@
|
||||
""" |
||||
This file is part of ComfyUI. |
||||
Copyright (C) 2024 Stability AI |
||||
|
||||
This program is free software: you can redistribute it and/or modify |
||||
it under the terms of the GNU General Public License as published by |
||||
the Free Software Foundation, either version 3 of the License, or |
||||
(at your option) any later version. |
||||
|
||||
This program is distributed in the hope that it will be useful, |
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of |
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
||||
GNU General Public License for more details. |
||||
|
||||
You should have received a copy of the GNU General Public License |
||||
along with this program. If not, see <https://www.gnu.org/licenses/>. |
||||
""" |
||||
import torch |
||||
import torchvision |
||||
from torch import nn |
||||
|
||||
|
||||
# EfficientNet |
||||
class EfficientNetEncoder(nn.Module): |
||||
def __init__(self, c_latent=16): |
||||
super().__init__() |
||||
self.backbone = torchvision.models.efficientnet_v2_s().features.eval() |
||||
self.mapper = nn.Sequential( |
||||
nn.Conv2d(1280, c_latent, kernel_size=1, bias=False), |
||||
nn.BatchNorm2d(c_latent, affine=False), # then normalize them to have mean 0 and std 1 |
||||
) |
||||
self.mean = nn.Parameter(torch.tensor([0.485, 0.456, 0.406])) |
||||
self.std = nn.Parameter(torch.tensor([0.229, 0.224, 0.225])) |
||||
|
||||
def forward(self, x): |
||||
x = x * 0.5 + 0.5 |
||||
x = (x - self.mean.view([3,1,1])) / self.std.view([3,1,1]) |
||||
o = self.mapper(self.backbone(x)) |
||||
print(o.shape) |
||||
return o |
||||
|
||||
|
||||
# Fast Decoder for Stage C latents. E.g. 16 x 24 x 24 -> 3 x 192 x 192 |
||||
class Previewer(nn.Module): |
||||
def __init__(self, c_in=16, c_hidden=512, c_out=3): |
||||
super().__init__() |
||||
self.blocks = nn.Sequential( |
||||
nn.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels |
||||
nn.GELU(), |
||||
nn.BatchNorm2d(c_hidden), |
||||
|
||||
nn.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1), |
||||
nn.GELU(), |
||||
nn.BatchNorm2d(c_hidden), |
||||
|
||||
nn.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32 |
||||
nn.GELU(), |
||||
nn.BatchNorm2d(c_hidden // 2), |
||||
|
||||
nn.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1), |
||||
nn.GELU(), |
||||
nn.BatchNorm2d(c_hidden // 2), |
||||
|
||||
nn.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64 |
||||
nn.GELU(), |
||||
nn.BatchNorm2d(c_hidden // 4), |
||||
|
||||
nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1), |
||||
nn.GELU(), |
||||
nn.BatchNorm2d(c_hidden // 4), |
||||
|
||||
nn.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128 |
||||
nn.GELU(), |
||||
nn.BatchNorm2d(c_hidden // 4), |
||||
|
||||
nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1), |
||||
nn.GELU(), |
||||
nn.BatchNorm2d(c_hidden // 4), |
||||
|
||||
nn.Conv2d(c_hidden // 4, c_out, kernel_size=1), |
||||
) |
||||
|
||||
def forward(self, x): |
||||
return (self.blocks(x) - 0.5) * 2.0 |
||||
|
||||
class StageC_coder(nn.Module): |
||||
def __init__(self): |
||||
super().__init__() |
||||
self.previewer = Previewer() |
||||
self.encoder = EfficientNetEncoder() |
||||
|
||||
def encode(self, x): |
||||
return self.encoder(x) |
||||
|
||||
def decode(self, x): |
||||
return self.previewer(x) |
Loading…
Reference in new issue