|
|
|
@ -437,9 +437,6 @@ class UNetModel(nn.Module):
|
|
|
|
|
operations=ops, |
|
|
|
|
): |
|
|
|
|
super().__init__() |
|
|
|
|
assert use_spatial_transformer == True, "use_spatial_transformer has to be true" |
|
|
|
|
if use_spatial_transformer: |
|
|
|
|
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' |
|
|
|
|
|
|
|
|
|
if context_dim is not None: |
|
|
|
|
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' |
|
|
|
@ -456,7 +453,6 @@ class UNetModel(nn.Module):
|
|
|
|
|
if num_head_channels == -1: |
|
|
|
|
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' |
|
|
|
|
|
|
|
|
|
self.image_size = image_size |
|
|
|
|
self.in_channels = in_channels |
|
|
|
|
self.model_channels = model_channels |
|
|
|
|
self.out_channels = out_channels |
|
|
|
|