|
|
|
@ -163,11 +163,9 @@ class ResBlock(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class StageA(nn.Module): |
|
|
|
|
def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192, |
|
|
|
|
scale_factor=0.43): # 0.3764 |
|
|
|
|
def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192): |
|
|
|
|
super().__init__() |
|
|
|
|
self.c_latent = c_latent |
|
|
|
|
self.scale_factor = scale_factor |
|
|
|
|
c_levels = [c_hidden // (2 ** i) for i in reversed(range(levels))] |
|
|
|
|
|
|
|
|
|
# Encoder blocks |
|
|
|
@ -214,12 +212,11 @@ class StageA(nn.Module):
|
|
|
|
|
x = self.down_blocks(x) |
|
|
|
|
if quantize: |
|
|
|
|
qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1) |
|
|
|
|
return qe / self.scale_factor, x / self.scale_factor, indices, vq_loss + commit_loss * 0.25 |
|
|
|
|
return qe, x, indices, vq_loss + commit_loss * 0.25 |
|
|
|
|
else: |
|
|
|
|
return x / self.scale_factor |
|
|
|
|
return x |
|
|
|
|
|
|
|
|
|
def decode(self, x): |
|
|
|
|
x = x * self.scale_factor |
|
|
|
|
x = self.up_blocks(x) |
|
|
|
|
x = self.out_block(x) |
|
|
|
|
return x |
|
|
|
|