comfyanonymous
9 months ago
2 changed files with 272 additions and 6 deletions
@ -0,0 +1,254 @@
|
||||
""" |
||||
This file is part of ComfyUI. |
||||
Copyright (C) 2024 Stability AI |
||||
|
||||
This program is free software: you can redistribute it and/or modify |
||||
it under the terms of the GNU General Public License as published by |
||||
the Free Software Foundation, either version 3 of the License, or |
||||
(at your option) any later version. |
||||
|
||||
This program is distributed in the hope that it will be useful, |
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of |
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
||||
GNU General Public License for more details. |
||||
|
||||
You should have received a copy of the GNU General Public License |
||||
along with this program. If not, see <https://www.gnu.org/licenses/>. |
||||
""" |
||||
|
||||
import torch |
||||
from torch import nn |
||||
from torch.autograd import Function |
||||
|
||||
class vector_quantize(Function): |
||||
@staticmethod |
||||
def forward(ctx, x, codebook): |
||||
with torch.no_grad(): |
||||
codebook_sqr = torch.sum(codebook ** 2, dim=1) |
||||
x_sqr = torch.sum(x ** 2, dim=1, keepdim=True) |
||||
|
||||
dist = torch.addmm(codebook_sqr + x_sqr, x, codebook.t(), alpha=-2.0, beta=1.0) |
||||
_, indices = dist.min(dim=1) |
||||
|
||||
ctx.save_for_backward(indices, codebook) |
||||
ctx.mark_non_differentiable(indices) |
||||
|
||||
nn = torch.index_select(codebook, 0, indices) |
||||
return nn, indices |
||||
|
||||
@staticmethod |
||||
def backward(ctx, grad_output, grad_indices): |
||||
grad_inputs, grad_codebook = None, None |
||||
|
||||
if ctx.needs_input_grad[0]: |
||||
grad_inputs = grad_output.clone() |
||||
if ctx.needs_input_grad[1]: |
||||
# Gradient wrt. the codebook |
||||
indices, codebook = ctx.saved_tensors |
||||
|
||||
grad_codebook = torch.zeros_like(codebook) |
||||
grad_codebook.index_add_(0, indices, grad_output) |
||||
|
||||
return (grad_inputs, grad_codebook) |
||||
|
||||
|
||||
class VectorQuantize(nn.Module): |
||||
def __init__(self, embedding_size, k, ema_decay=0.99, ema_loss=False): |
||||
""" |
||||
Takes an input of variable size (as long as the last dimension matches the embedding size). |
||||
Returns one tensor containing the nearest neigbour embeddings to each of the inputs, |
||||
with the same size as the input, vq and commitment components for the loss as a touple |
||||
in the second output and the indices of the quantized vectors in the third: |
||||
quantized, (vq_loss, commit_loss), indices |
||||
""" |
||||
super(VectorQuantize, self).__init__() |
||||
|
||||
self.codebook = nn.Embedding(k, embedding_size) |
||||
self.codebook.weight.data.uniform_(-1./k, 1./k) |
||||
self.vq = vector_quantize.apply |
||||
|
||||
self.ema_decay = ema_decay |
||||
self.ema_loss = ema_loss |
||||
if ema_loss: |
||||
self.register_buffer('ema_element_count', torch.ones(k)) |
||||
self.register_buffer('ema_weight_sum', torch.zeros_like(self.codebook.weight)) |
||||
|
||||
def _laplace_smoothing(self, x, epsilon): |
||||
n = torch.sum(x) |
||||
return ((x + epsilon) / (n + x.size(0) * epsilon) * n) |
||||
|
||||
def _updateEMA(self, z_e_x, indices): |
||||
mask = nn.functional.one_hot(indices, self.ema_element_count.size(0)).float() |
||||
elem_count = mask.sum(dim=0) |
||||
weight_sum = torch.mm(mask.t(), z_e_x) |
||||
|
||||
self.ema_element_count = (self.ema_decay * self.ema_element_count) + ((1-self.ema_decay) * elem_count) |
||||
self.ema_element_count = self._laplace_smoothing(self.ema_element_count, 1e-5) |
||||
self.ema_weight_sum = (self.ema_decay * self.ema_weight_sum) + ((1-self.ema_decay) * weight_sum) |
||||
|
||||
self.codebook.weight.data = self.ema_weight_sum / self.ema_element_count.unsqueeze(-1) |
||||
|
||||
def idx2vq(self, idx, dim=-1): |
||||
q_idx = self.codebook(idx) |
||||
if dim != -1: |
||||
q_idx = q_idx.movedim(-1, dim) |
||||
return q_idx |
||||
|
||||
def forward(self, x, get_losses=True, dim=-1): |
||||
if dim != -1: |
||||
x = x.movedim(dim, -1) |
||||
z_e_x = x.contiguous().view(-1, x.size(-1)) if len(x.shape) > 2 else x |
||||
z_q_x, indices = self.vq(z_e_x, self.codebook.weight.detach()) |
||||
vq_loss, commit_loss = None, None |
||||
if self.ema_loss and self.training: |
||||
self._updateEMA(z_e_x.detach(), indices.detach()) |
||||
# pick the graded embeddings after updating the codebook in order to have a more accurate commitment loss |
||||
z_q_x_grd = torch.index_select(self.codebook.weight, dim=0, index=indices) |
||||
if get_losses: |
||||
vq_loss = (z_q_x_grd - z_e_x.detach()).pow(2).mean() |
||||
commit_loss = (z_e_x - z_q_x_grd.detach()).pow(2).mean() |
||||
|
||||
z_q_x = z_q_x.view(x.shape) |
||||
if dim != -1: |
||||
z_q_x = z_q_x.movedim(-1, dim) |
||||
return z_q_x, (vq_loss, commit_loss), indices.view(x.shape[:-1]) |
||||
|
||||
|
||||
class ResBlock(nn.Module): |
||||
def __init__(self, c, c_hidden): |
||||
super().__init__() |
||||
# depthwise/attention |
||||
self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6) |
||||
self.depthwise = nn.Sequential( |
||||
nn.ReplicationPad2d(1), |
||||
nn.Conv2d(c, c, kernel_size=3, groups=c) |
||||
) |
||||
|
||||
# channelwise |
||||
self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6) |
||||
self.channelwise = nn.Sequential( |
||||
nn.Linear(c, c_hidden), |
||||
nn.GELU(), |
||||
nn.Linear(c_hidden, c), |
||||
) |
||||
|
||||
self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True) |
||||
|
||||
# Init weights |
||||
def _basic_init(module): |
||||
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): |
||||
torch.nn.init.xavier_uniform_(module.weight) |
||||
if module.bias is not None: |
||||
nn.init.constant_(module.bias, 0) |
||||
|
||||
self.apply(_basic_init) |
||||
|
||||
def _norm(self, x, norm): |
||||
return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) |
||||
|
||||
def forward(self, x): |
||||
mods = self.gammas |
||||
|
||||
x_temp = self._norm(x, self.norm1) * (1 + mods[0]) + mods[1] |
||||
x = x + self.depthwise(x_temp) * mods[2] |
||||
|
||||
x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4] |
||||
x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5] |
||||
|
||||
return x |
||||
|
||||
|
||||
class StageA(nn.Module): |
||||
def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192, |
||||
scale_factor=0.43): # 0.3764 |
||||
super().__init__() |
||||
self.c_latent = c_latent |
||||
self.scale_factor = scale_factor |
||||
c_levels = [c_hidden // (2 ** i) for i in reversed(range(levels))] |
||||
|
||||
# Encoder blocks |
||||
self.in_block = nn.Sequential( |
||||
nn.PixelUnshuffle(2), |
||||
nn.Conv2d(3 * 4, c_levels[0], kernel_size=1) |
||||
) |
||||
down_blocks = [] |
||||
for i in range(levels): |
||||
if i > 0: |
||||
down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1)) |
||||
block = ResBlock(c_levels[i], c_levels[i] * 4) |
||||
down_blocks.append(block) |
||||
down_blocks.append(nn.Sequential( |
||||
nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False), |
||||
nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1 |
||||
)) |
||||
self.down_blocks = nn.Sequential(*down_blocks) |
||||
self.down_blocks[0] |
||||
|
||||
self.codebook_size = codebook_size |
||||
self.vquantizer = VectorQuantize(c_latent, k=codebook_size) |
||||
|
||||
# Decoder blocks |
||||
up_blocks = [nn.Sequential( |
||||
nn.Conv2d(c_latent, c_levels[-1], kernel_size=1) |
||||
)] |
||||
for i in range(levels): |
||||
for j in range(bottleneck_blocks if i == 0 else 1): |
||||
block = ResBlock(c_levels[levels - 1 - i], c_levels[levels - 1 - i] * 4) |
||||
up_blocks.append(block) |
||||
if i < levels - 1: |
||||
up_blocks.append( |
||||
nn.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2, |
||||
padding=1)) |
||||
self.up_blocks = nn.Sequential(*up_blocks) |
||||
self.out_block = nn.Sequential( |
||||
nn.Conv2d(c_levels[0], 3 * 4, kernel_size=1), |
||||
nn.PixelShuffle(2), |
||||
) |
||||
|
||||
def encode(self, x, quantize=False): |
||||
x = self.in_block(x) |
||||
x = self.down_blocks(x) |
||||
if quantize: |
||||
qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1) |
||||
return qe / self.scale_factor, x / self.scale_factor, indices, vq_loss + commit_loss * 0.25 |
||||
else: |
||||
return x / self.scale_factor |
||||
|
||||
def decode(self, x): |
||||
x = x * self.scale_factor |
||||
x = self.up_blocks(x) |
||||
x = self.out_block(x) |
||||
return x |
||||
|
||||
def forward(self, x, quantize=False): |
||||
qe, x, _, vq_loss = self.encode(x, quantize) |
||||
x = self.decode(qe) |
||||
return x, vq_loss |
||||
|
||||
|
||||
class Discriminator(nn.Module): |
||||
def __init__(self, c_in=3, c_cond=0, c_hidden=512, depth=6): |
||||
super().__init__() |
||||
d = max(depth - 3, 3) |
||||
layers = [ |
||||
nn.utils.spectral_norm(nn.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)), |
||||
nn.LeakyReLU(0.2), |
||||
] |
||||
for i in range(depth - 1): |
||||
c_in = c_hidden // (2 ** max((d - i), 0)) |
||||
c_out = c_hidden // (2 ** max((d - 1 - i), 0)) |
||||
layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1))) |
||||
layers.append(nn.InstanceNorm2d(c_out)) |
||||
layers.append(nn.LeakyReLU(0.2)) |
||||
self.encoder = nn.Sequential(*layers) |
||||
self.shuffle = nn.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1) |
||||
self.logits = nn.Sigmoid() |
||||
|
||||
def forward(self, x, cond=None): |
||||
x = self.encoder(x) |
||||
if cond is not None: |
||||
cond = cond.view(cond.size(0), cond.size(1), 1, 1, ).expand(-1, -1, x.size(-2), x.size(-1)) |
||||
x = torch.cat([x, cond], dim=1) |
||||
x = self.shuffle(x) |
||||
x = self.logits(x) |
||||
return x |
Loading…
Reference in new issue