diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index aa74b632..74a2fd99 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -625,7 +625,7 @@ class SpatialTransformer(nn.Module):
         x = self.norm(x)
         if not self.use_linear:
             x = self.proj_in(x)
-        x = x.movedim(1, -1).flatten(1, 2).contiguous()
+        x = x.movedim(1, 3).flatten(1, 2).contiguous()
         if self.use_linear:
             x = self.proj_in(x)
         for i, block in enumerate(self.transformer_blocks):
@@ -633,7 +633,7 @@ class SpatialTransformer(nn.Module):
             x = block(x, context=context[i], transformer_options=transformer_options)
         if self.use_linear:
             x = self.proj_out(x)
-        x = x.reshape(x.shape[0], h, w, x.shape[-1]).movedim(-1, 1).contiguous()
+        x = x.reshape(x.shape[0], h, w, x.shape[-1]).movedim(3, 1).contiguous()
         if not self.use_linear:
             x = self.proj_out(x)
         return x + x_in
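
Note on the change: SpatialTransformer.forward operates on 4-D (batch, channels, height, width) activations, so destination index 3 and index -1 name the same (last) axis; swapping the negative index for an explicit positive one is behavior-preserving. A minimal standalone sketch verifying this, not part of the diff (shapes are illustrative only):

import torch

# Illustrative (batch, channels, height, width) feature map, matching the
# 4-D layout SpatialTransformer operates on.
b, c, h, w = 2, 16, 8, 8
x = torch.randn(b, c, h, w)

# On a 4-D tensor, destination dim 3 is the last axis, so the two calls
# produce identical results; the patch just makes the index explicit.
assert torch.equal(x.movedim(1, 3), x.movedim(1, -1))

# Round trip used around the transformer blocks:
# (b, c, h, w) -> (b, h*w, c) tokens -> back to (b, c, h, w).
tokens = x.movedim(1, 3).flatten(1, 2).contiguous()     # (b, h*w, c)
y = tokens.reshape(b, h, w, c)
assert torch.equal(y.movedim(3, 1), y.movedim(-1, 1))
assert torch.equal(y.movedim(3, 1), x)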