Browse Source

Move cascade scale factor from stage_a to latent_formats.py

pull/3069/head
comfyanonymous 8 months ago
parent
commit
d7897fff2c
  1. 2
      comfy/latent_formats.py
  2. 9
      comfy/ldm/cascade/stage_a.py

2
comfy/latent_formats.py

@ -95,7 +95,7 @@ class SC_Prior(LatentFormat):
class SC_B(LatentFormat): class SC_B(LatentFormat):
def __init__(self): def __init__(self):
self.scale_factor = 1.0 self.scale_factor = 1.0 / 0.43
self.latent_rgb_factors = [ self.latent_rgb_factors = [
[ 0.1121, 0.2006, 0.1023], [ 0.1121, 0.2006, 0.1023],
[-0.2093, -0.0222, -0.0195], [-0.2093, -0.0222, -0.0195],

9
comfy/ldm/cascade/stage_a.py

@ -163,11 +163,9 @@ class ResBlock(nn.Module):
class StageA(nn.Module): class StageA(nn.Module):
def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192, def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192):
scale_factor=0.43): # 0.3764
super().__init__() super().__init__()
self.c_latent = c_latent self.c_latent = c_latent
self.scale_factor = scale_factor
c_levels = [c_hidden // (2 ** i) for i in reversed(range(levels))] c_levels = [c_hidden // (2 ** i) for i in reversed(range(levels))]
# Encoder blocks # Encoder blocks
@ -214,12 +212,11 @@ class StageA(nn.Module):
x = self.down_blocks(x) x = self.down_blocks(x)
if quantize: if quantize:
qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1) qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1)
return qe / self.scale_factor, x / self.scale_factor, indices, vq_loss + commit_loss * 0.25 return qe, x, indices, vq_loss + commit_loss * 0.25
else: else:
return x / self.scale_factor return x
def decode(self, x): def decode(self, x):
x = x * self.scale_factor
x = self.up_blocks(x) x = self.up_blocks(x)
x = self.out_block(x) x = self.out_block(x)
return x return x

Loading…
Cancel
Save