diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 42653086..88ee2f32 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -318,11 +318,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None):
         return attention_pytorch(q, k, v, heads, mask)
 
     q, k, v = map(
-        lambda t: t.unsqueeze(3)
-        .reshape(b, -1, heads, dim_head)
-        .permute(0, 2, 1, 3)
-        .reshape(b * heads, -1, dim_head)
-        .contiguous(),
+        lambda t: t.reshape(b, -1, heads, dim_head),
         (q, k, v),
     )
 
@@ -335,10 +331,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None):
     out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
 
     out = (
-        out.unsqueeze(0)
-        .reshape(b, heads, -1, dim_head)
-        .permute(0, 2, 1, 3)
-        .reshape(b, -1, heads * dim_head)
+        out.reshape(b, -1, heads * dim_head)
     )
     return out
 
diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py
index fabc5c5e..04eb83b2 100644
--- a/comfy/ldm/modules/diffusionmodules/model.py
+++ b/comfy/ldm/modules/diffusionmodules/model.py
@@ -3,7 +3,6 @@
 import math
 import torch
 import torch.nn as nn
 import numpy as np
-from einops import rearrange
 from typing import Optional, Any
 import logging
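
For context (not part of the patch): xformers.ops.memory_efficient_attention also accepts inputs in the 4-D (batch, seq, heads, dim_head) layout, so the old fold-heads-into-batch round-trip of permute/reshape is unnecessary, and the output can be flattened straight back to (batch, seq, heads * dim_head). Below is a minimal torch-only sketch of that layout equivalence; the shapes are made up, and torch.nn.functional.scaled_dot_product_attention stands in for the xformers kernel purely for illustration.

```python
import torch
import torch.nn.functional as F

# Illustrative shapes only (assumed, not taken from the patch).
b, seq, heads, dim_head = 2, 16, 4, 8
q, k, v = (torch.randn(b, seq, heads * dim_head) for _ in range(3))

def old_layout(t):
    # Previous path: fold heads into the batch dim -> (b * heads, seq, dim_head).
    return (t.reshape(b, -1, heads, dim_head)
             .permute(0, 2, 1, 3)
             .reshape(b * heads, -1, dim_head)
             .contiguous())

def new_layout(t):
    # Patched path: the (b, seq, heads, dim_head) layout xformers accepts directly.
    return t.reshape(b, -1, heads, dim_head)

# scaled_dot_product_attention stands in for the xformers kernel; it expects
# (batch, heads, seq, dim_head), so both layouts are mapped into that form.
out_old = F.scaled_dot_product_attention(
    *(old_layout(t).reshape(b, heads, -1, dim_head) for t in (q, k, v)))
out_new = F.scaled_dot_product_attention(
    *(new_layout(t).transpose(1, 2) for t in (q, k, v)))

# Same data, same result: the removed permutes only changed the layout, not the math.
assert torch.allclose(out_old, out_new, atol=1e-6)

# The output side simplifies the same way: a (b, seq, heads, dim_head) result
# flattens directly to (b, seq, heads * dim_head) with a single reshape.
final = out_new.transpose(1, 2).reshape(b, -1, heads * dim_head)
print(final.shape)  # torch.Size([2, 16, 32])
```

The einops removal in diffusionmodules/model.py is just dropping an import that the file no longer uses after related rearranges were replaced with plain reshapes.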