@@ -10,7 +10,6 @@ from .diffusionmodules.util import checkpoint
 from .sub_quadratic_attention import efficient_dot_product_attention
 from comfy import model_management
-import comfy.ops
 if model_management.xformers_enabled():
     import xformers
@@ -52,9 +51,9 @@ def init_(tensor):
 # feedforward
 class GEGLU(nn.Module):
-    def __init__(self, dim_in, dim_out, dtype=None, device=None):
+    def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=None):
         super().__init__()
-        self.proj = comfy.ops.Linear(dim_in, dim_out * 2, dtype=dtype, device=device)
+        self.proj = operations.Linear(dim_in, dim_out * 2, dtype=dtype, device=device)
     def forward(self, x):
         x, gate = self.proj(x).chunk(2, dim=-1)
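Note (not part of the patch): the new `operations` argument is assumed to be any namespace exposing drop-in replacements for nn.Linear (and, later in the file, nn.Conv2d), which is how comfy.ops was used before this change. A minimal sketch of such a namespace, with purely illustrative names:

import torch.nn as nn

class CustomOps:  # hypothetical; comfy.ops plays this role in practice
    class Linear(nn.Linear):
        def reset_parameters(self):
            # Assumption: skip default weight init, since weights are loaded from a checkpoint.
            return None

    class Conv2d(nn.Conv2d):
        def reset_parameters(self):
            return None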
@@ -62,19 +61,19 @@ class GEGLU(nn.Module):
 class FeedForward(nn.Module):
-    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None):
+    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=None):
         super().__init__()
         inner_dim = int(dim * mult)
         dim_out = default(dim_out, dim)
         project_in = nn.Sequential(
-            comfy.ops.Linear(dim, inner_dim, dtype=dtype, device=device),
+            operations.Linear(dim, inner_dim, dtype=dtype, device=device),
             nn.GELU()
-        ) if not glu else GEGLU(dim, inner_dim, dtype=dtype, device=device)
+        ) if not glu else GEGLU(dim, inner_dim, dtype=dtype, device=device, operations=operations)
         self.net = nn.Sequential(
             project_in,
             nn.Dropout(dropout),
-            comfy.ops.Linear(inner_dim, dim_out, dtype=dtype, device=device)
+            operations.Linear(inner_dim, dim_out, dtype=dtype, device=device)
         )
     def forward(self, x):
@@ -148,7 +147,7 @@ class SpatialSelfAttention(nn.Module):
 class CrossAttentionBirchSan(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=None):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)
@@ -156,12 +155,12 @@ class CrossAttentionBirchSan(nn.Module):
         self.scale = dim_head ** -0.5
         self.heads = heads
-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
         self.to_out = nn.Sequential(
-            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device),
+            operations.Linear(inner_dim, query_dim, dtype=dtype, device=device),
             nn.Dropout(dropout)
         )
@@ -245,7 +244,7 @@ class CrossAttentionBirchSan(nn.Module):
 class CrossAttentionDoggettx(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=None):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)
@@ -253,12 +252,12 @@ class CrossAttentionDoggettx(nn.Module):
         self.scale = dim_head ** -0.5
         self.heads = heads
-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
         self.to_out = nn.Sequential(
-            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device),
+            operations.Linear(inner_dim, query_dim, dtype=dtype, device=device),
             nn.Dropout(dropout)
         )
@@ -343,7 +342,7 @@ class CrossAttentionDoggettx(nn.Module):
         return self.to_out(r2)
 class CrossAttention(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=None):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)
@@ -351,12 +350,12 @@ class CrossAttention(nn.Module):
         self.scale = dim_head ** -0.5
         self.heads = heads
-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
         self.to_out = nn.Sequential(
-            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device),
+            operations.Linear(inner_dim, query_dim, dtype=dtype, device=device),
             nn.Dropout(dropout)
         )
@@ -399,7 +398,7 @@ class CrossAttention(nn.Module):
 class MemoryEfficientCrossAttention(nn.Module):
     # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None, device=None):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None, device=None, operations=None):
         super().__init__()
         print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
               f"{heads} heads.")
@@ -409,11 +408,11 @@ class MemoryEfficientCrossAttention(nn.Module):
         self.heads = heads
         self.dim_head = dim_head
-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
+        self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
         self.attention_op: Optional[Any] = None
     def forward(self, x, context=None, value=None, mask=None):
@@ -450,7 +449,7 @@ class MemoryEfficientCrossAttention(nn.Module):
         return self.to_out(out)
 class CrossAttentionPytorch(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=None):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)
@@ -458,11 +457,11 @@ class CrossAttentionPytorch(nn.Module):
         self.heads = heads
         self.dim_head = dim_head
-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
+        self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
         self.attention_op: Optional[Any] = None
     def forward(self, x, context=None, value=None, mask=None):
@@ -508,14 +507,14 @@ else:
 class BasicTransformerBlock(nn.Module):
     def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
-                 disable_self_attn=False, dtype=None, device=None):
+                 disable_self_attn=False, dtype=None, device=None, operations=None):
         super().__init__()
         self.disable_self_attn = disable_self_attn
         self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
-                                    context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device)  # is a self-attention if not self.disable_self_attn
-        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device)
+                                    context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device, operations=operations)  # is a self-attention if not self.disable_self_attn
+        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
         self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
-                                    heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device)  # is self-attn if context is none
+                                    heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations)  # is self-attn if context is none
         self.norm1 = nn.LayerNorm(dim, dtype=dtype, device=device)
         self.norm2 = nn.LayerNorm(dim, dtype=dtype, device=device)
         self.norm3 = nn.LayerNorm(dim, dtype=dtype, device=device)
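For reference, a hedged caller sketch (not part of the patch): BasicTransformerBlock now expects `operations` to be threaded in by its caller. Passing the existing comfy.ops module preserves the previous behaviour, since the removed lines called comfy.ops.Linear directly; the dimensions below are illustrative only.

import comfy.ops

block = BasicTransformerBlock(
    dim=320, n_heads=8, d_head=40,
    dropout=0.0, context_dim=768, gated_ff=True, checkpoint=True,
    disable_self_attn=False, dtype=None, device=None,
    operations=comfy.ops,  # any namespace with a compatible .Linear works here
)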
@@ -648,7 +647,7 @@ class SpatialTransformer(nn.Module):
     def __init__(self, in_channels, n_heads, d_head,
                  depth=1, dropout=0., context_dim=None,
                  disable_self_attn=False, use_linear=False,
-                 use_checkpoint=True, dtype=None, device=None):
+                 use_checkpoint=True, dtype=None, device=None, operations=None):
         super().__init__()
         if exists(context_dim) and not isinstance(context_dim, list):
             context_dim = [context_dim] * depth
@@ -656,26 +655,26 @@ class SpatialTransformer(nn.Module):
         inner_dim = n_heads * d_head
         self.norm = Normalize(in_channels, dtype=dtype, device=device)
         if not use_linear:
-            self.proj_in = nn.Conv2d(in_channels,
+            self.proj_in = operations.Conv2d(in_channels,
                                      inner_dim,
                                      kernel_size=1,
                                      stride=1,
                                      padding=0, dtype=dtype, device=device)
         else:
-            self.proj_in = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype, device=device)
+            self.proj_in = operations.Linear(in_channels, inner_dim, dtype=dtype, device=device)
         self.transformer_blocks = nn.ModuleList(
             [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
-                                   disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype, device=device)
+                                   disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype, device=device, operations=operations)
                 for d in range(depth)]
         )
         if not use_linear:
-            self.proj_out = nn.Conv2d(inner_dim, in_channels,
+            self.proj_out = operations.Conv2d(inner_dim, in_channels,
                                       kernel_size=1,
                                       stride=1,
                                       padding=0, dtype=dtype, device=device)
         else:
-            self.proj_out = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype, device=device)
+            self.proj_out = operations.Linear(in_channels, inner_dim, dtype=dtype, device=device)
         self.use_linear = use_linear
     def forward(self, x, context=None, transformer_options={}):
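Closing sketch (assumptions flagged, not part of the patch): SpatialTransformer forwards `operations` to its projection layers and to every BasicTransformerBlock. With use_linear=False it additionally needs operations.Conv2d; the example below uses the linear path, which only requires operations.Linear, and assumes comfy.ops remains a valid namespace to pass. All dimensions are illustrative.

import comfy.ops

st = SpatialTransformer(
    in_channels=320, n_heads=8, d_head=40,
    depth=1, dropout=0.0, context_dim=768,
    disable_self_attn=False, use_linear=True,  # linear path only needs operations.Linear
    use_checkpoint=True, dtype=None, device=None,
    operations=comfy.ops,
)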