@@ -299,6 +299,64 @@ class MemoryEfficientAttnBlock(nn.Module):
        out = self.proj_out(out)
        return x + out

class MemoryEfficientAttnBlockPytorch(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)
        self.attention_op: Optional[Any] = None

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        B, C, H, W = q.shape
        q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))

        # add a single attention head and fold it into the batch dim so the
        # contiguous (B*1, H*W, C) tensors match what scaled_dot_product_attention expects
        q, k, v = map(
            lambda t: t.unsqueeze(3)
            .reshape(B, t.shape[1], 1, C)
            .permute(0, 2, 1, 3)
            .reshape(B * 1, t.shape[1], C)
            .contiguous(),
            (q, k, v),
        )
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)

        # undo the single-head reshape, then restore the spatial (B, C, H, W) layout
        out = (
            out.unsqueeze(0)
            .reshape(B, 1, out.shape[1], C)
            .permute(0, 2, 1, 3)
            .reshape(B, out.shape[1], C)
        )
        out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
        out = self.proj_out(out)
        return x + out
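
# A minimal smoke-test sketch (not part of the patch): it assumes this module's
# existing helpers (Normalize, rearrange) and imports (torch, nn, Optional, Any)
# are in scope; `block` and `dummy` are hypothetical names used only for illustration.
block = MemoryEfficientAttnBlockPytorch(in_channels=512)
dummy = torch.randn(1, 512, 64, 64)      # (B, C, H, W) latent feature map
with torch.no_grad():
    out = block(dummy)                   # residual output, same shape as the input
assert out.shape == dummy.shape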

class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
    def forward(self, x, context=None, mask=None):
@@ -313,6 +371,8 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
    assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
    if model_management.xformers_enabled() and attn_type == "vanilla":
        attn_type = "vanilla-xformers"
    if model_management.pytorch_attention_enabled() and attn_type == "vanilla":
        attn_type = "vanilla-pytorch"
    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
    if attn_type == "vanilla":
        assert attn_kwargs is None
@@ -320,6 +380,8 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
    elif attn_type == "vanilla-xformers":
        print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
        return MemoryEfficientAttnBlock(in_channels)
    elif attn_type == "vanilla-pytorch":
        return MemoryEfficientAttnBlockPytorch(in_channels)
    elif attn_type == "memory-efficient-cross-attn":
        attn_kwargs["query_dim"] = in_channels
        return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
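
# Illustrative sketch (not part of the patch): with the new branches above, callers
# keep requesting plain "vanilla" attention and make_attn upgrades it at runtime.
# Which class comes back depends on model_management's checks: xformers is consulted
# first, then PyTorch SDPA, with the original AttnBlock as the fallback. The call
# below is hypothetical and only shows the dispatch.
attn = make_attn(512, attn_type="vanilla")
print(type(attn).__name__)  # AttnBlock, MemoryEfficientAttnBlock, or MemoryEfficientAttnBlockPytorch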