@@ -193,40 +193,7 @@ def slice_attention(q, k, v):

    return r1

class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = comfy.ops.Conv2d(in_channels,
                                  in_channels,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0)
        self.k = comfy.ops.Conv2d(in_channels,
                                  in_channels,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0)
        self.v = comfy.ops.Conv2d(in_channels,
                                  in_channels,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0)
        self.proj_out = comfy.ops.Conv2d(in_channels,
                                         in_channels,
                                         kernel_size=1,
                                         stride=1,
                                         padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

def normal_attention(q, k, v):
    # compute attention
    b,c,h,w = q.shape
@@ -238,51 +205,9 @@ class AttnBlock(nn.Module):

        r1 = slice_attention(q, k, v)
        h_ = r1.reshape(b,c,h,w)
        del r1

        h_ = self.proj_out(h_)

        return x+h_


class MemoryEfficientAttnBlock(nn.Module):
    """
        Uses xformers efficient implementation,
        see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
        Note: this is a single-head self-attention operation
    """
    #
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = comfy.ops.Conv2d(in_channels,
                                  in_channels,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0)
        self.k = comfy.ops.Conv2d(in_channels,
                                  in_channels,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0)
        self.v = comfy.ops.Conv2d(in_channels,
                                  in_channels,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0)
        self.proj_out = comfy.ops.Conv2d(in_channels,
                                         in_channels,
                                         kernel_size=1,
                                         stride=1,
                                         padding=0)
        self.attention_op: Optional[Any] = None

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)
        return h_

def xformers_attention(q, k, v):
    # compute attention
    B, C, H, W = q.shape
    q, k, v = map(
@@ -291,15 +216,30 @@ class MemoryEfficientAttnBlock(nn.Module):
    )

    try:
        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
        out = out.transpose(1, 2).reshape(B, C, H, W)
    except NotImplementedError as e:
        out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
    return out

def pytorch_attention(q, k, v):
    # compute attention
    B, C, H, W = q.shape
    q, k, v = map(
        lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
        (q, k, v),
    )

    try:
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
        out = out.transpose(2, 3).reshape(B, C, H, W)
    except model_management.OOM_EXCEPTION as e:
        print("scaled_dot_product_attention OOMed: switched to slice attention")
        out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
    return out
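# Editor's note (not part of the patch): pytorch_attention above folds the
# (B, C, H, W) feature map into a single-head sequence of length H*W, e.g. for a
# hypothetical (2, 8, 4, 4) input:
#   t.view(2, 1, 8, -1).transpose(2, 3)       -> (2, 1, 16, 8)
#   scaled_dot_product_attention(q, k, v)     -> (2, 1, 16, 8)
#   out.transpose(2, 3).reshape(2, 8, 4, 4)   -> back to the (B, C, H, W) layout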
        out = self.proj_out(out)
        return x+out

class MemoryEfficientAttnBlockPytorch(nn.Module):
class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels
@@ -325,7 +265,16 @@ class MemoryEfficientAttnBlockPytorch(nn.Module):
                                         kernel_size=1,
                                         stride=1,
                                         padding=0)
        self.attention_op: Optional[Any] = None

        if model_management.xformers_enabled_vae():
            print("Using xformers attention in VAE")
            self.optimized_attention = xformers_attention
        elif model_management.pytorch_attention_enabled():
            print("Using pytorch attention in VAE")
            self.optimized_attention = pytorch_attention
        else:
            print("Using split attention in VAE")
            self.optimized_attention = normal_attention
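        # Editor's note (not in the original patch): the attention backend is picked
        # once at construction time, based on what model_management reports as
        # available; forward() below simply calls self.optimized_attention(q, k, v).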

    def forward(self, x):
        h_ = x
@@ -334,42 +283,15 @@ class MemoryEfficientAttnBlockPytorch(nn.Module):
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        B, C, H, W = q.shape
        q, k, v = map(
            lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
            (q, k, v),
        )

        h_ = self.optimized_attention(q, k, v)

        try:
            out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
            out = out.transpose(2, 3).reshape(B, C, H, W)
        except model_management.OOM_EXCEPTION as e:
            print("scaled_dot_product_attention OOMed: switched to slice attention")
            out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)

        h_ = self.proj_out(h_)

        return x+h_

        out = self.proj_out(out)
        return x+out

def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
    assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
    if model_management.xformers_enabled_vae() and attn_type == "vanilla":
        attn_type = "vanilla-xformers"
    elif model_management.pytorch_attention_enabled() and attn_type == "vanilla":
        attn_type = "vanilla-pytorch"
    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
    if attn_type == "vanilla":
        assert attn_kwargs is None
        return AttnBlock(in_channels)
    elif attn_type == "vanilla-xformers":
        print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
        return MemoryEfficientAttnBlock(in_channels)
    elif attn_type == "vanilla-pytorch":
        return MemoryEfficientAttnBlockPytorch(in_channels)
    elif attn_type == "none":
        return nn.Identity(in_channels)
    else:
        raise NotImplementedError()
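# Editor's sketch (not part of the patch): with the remapping shown in make_attn,
# callers that request attn_type="vanilla" are transparently upgraded, e.g.
#   attn = make_attn(512, attn_type="vanilla")
#   # prints "making attention of type 'vanilla-pytorch' with 512 in_channels"
#   # when pytorch attention is enabled, and returns the matching block.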

class Model(nn.Module):