|
|
|
@ -177,7 +177,7 @@ class ResBlock(TimestepBlock):
|
|
|
|
|
padding = kernel_size // 2 |
|
|
|
|
|
|
|
|
|
self.in_layers = nn.Sequential( |
|
|
|
|
nn.GroupNorm(32, channels, dtype=dtype, device=device), |
|
|
|
|
operations.GroupNorm(32, channels, dtype=dtype, device=device), |
|
|
|
|
nn.SiLU(), |
|
|
|
|
operations.conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device), |
|
|
|
|
) |
|
|
|
@ -206,12 +206,11 @@ class ResBlock(TimestepBlock):
|
|
|
|
|
), |
|
|
|
|
) |
|
|
|
|
self.out_layers = nn.Sequential( |
|
|
|
|
nn.GroupNorm(32, self.out_channels, dtype=dtype, device=device), |
|
|
|
|
operations.GroupNorm(32, self.out_channels, dtype=dtype, device=device), |
|
|
|
|
nn.SiLU(), |
|
|
|
|
nn.Dropout(p=dropout), |
|
|
|
|
zero_module( |
|
|
|
|
operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device) |
|
|
|
|
), |
|
|
|
|
operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device) |
|
|
|
|
, |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
if self.out_channels == channels: |
|
|
|
@ -810,13 +809,13 @@ class UNetModel(nn.Module):
|
|
|
|
|
self._feature_size += ch |
|
|
|
|
|
|
|
|
|
self.out = nn.Sequential( |
|
|
|
|
nn.GroupNorm(32, ch, dtype=self.dtype, device=device), |
|
|
|
|
operations.GroupNorm(32, ch, dtype=self.dtype, device=device), |
|
|
|
|
nn.SiLU(), |
|
|
|
|
zero_module(operations.conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device)), |
|
|
|
|
) |
|
|
|
|
if self.predict_codebook_ids: |
|
|
|
|
self.id_predictor = nn.Sequential( |
|
|
|
|
nn.GroupNorm(32, ch, dtype=self.dtype, device=device), |
|
|
|
|
operations.GroupNorm(32, ch, dtype=self.dtype, device=device), |
|
|
|
|
operations.conv_nd(dims, model_channels, n_embed, 1, dtype=self.dtype, device=device), |
|
|
|
|
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits |
|
|
|
|
) |
|
|
|
|