|
|
|
@ -12,8 +12,6 @@ from .sub_quadratic_attention import efficient_dot_product_attention
|
|
|
|
|
from comfy import model_management |
|
|
|
|
import comfy.ops |
|
|
|
|
|
|
|
|
|
from . import tomesd |
|
|
|
|
|
|
|
|
|
if model_management.xformers_enabled(): |
|
|
|
|
import xformers |
|
|
|
|
import xformers.ops |
|
|
|
@ -519,23 +517,39 @@ class BasicTransformerBlock(nn.Module):
|
|
|
|
|
self.norm2 = nn.LayerNorm(dim, dtype=dtype) |
|
|
|
|
self.norm3 = nn.LayerNorm(dim, dtype=dtype) |
|
|
|
|
self.checkpoint = checkpoint |
|
|
|
|
self.n_heads = n_heads |
|
|
|
|
self.d_head = d_head |
|
|
|
|
|
|
|
|
|
def forward(self, x, context=None, transformer_options={}): |
|
|
|
|
return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint) |
|
|
|
|
|
|
|
|
|
def _forward(self, x, context=None, transformer_options={}): |
|
|
|
|
extra_options = {} |
|
|
|
|
block = None |
|
|
|
|
block_index = 0 |
|
|
|
|
if "current_index" in transformer_options: |
|
|
|
|
extra_options["transformer_index"] = transformer_options["current_index"] |
|
|
|
|
if "block_index" in transformer_options: |
|
|
|
|
extra_options["block_index"] = transformer_options["block_index"] |
|
|
|
|
block_index = transformer_options["block_index"] |
|
|
|
|
extra_options["block_index"] = block_index |
|
|
|
|
if "original_shape" in transformer_options: |
|
|
|
|
extra_options["original_shape"] = transformer_options["original_shape"] |
|
|
|
|
if "block" in transformer_options: |
|
|
|
|
block = transformer_options["block"] |
|
|
|
|
extra_options["block"] = block |
|
|
|
|
if "patches" in transformer_options: |
|
|
|
|
transformer_patches = transformer_options["patches"] |
|
|
|
|
else: |
|
|
|
|
transformer_patches = {} |
|
|
|
|
|
|
|
|
|
extra_options["n_heads"] = self.n_heads |
|
|
|
|
extra_options["dim_head"] = self.d_head |
|
|
|
|
|
|
|
|
|
if "patches_replace" in transformer_options: |
|
|
|
|
transformer_patches_replace = transformer_options["patches_replace"] |
|
|
|
|
else: |
|
|
|
|
transformer_patches_replace = {} |
|
|
|
|
|
|
|
|
|
n = self.norm1(x) |
|
|
|
|
if self.disable_self_attn: |
|
|
|
|
context_attn1 = context |
|
|
|
@ -551,12 +565,29 @@ class BasicTransformerBlock(nn.Module):
|
|
|
|
|
for p in patch: |
|
|
|
|
n, context_attn1, value_attn1 = p(n, context_attn1, value_attn1, extra_options) |
|
|
|
|
|
|
|
|
|
if "tomesd" in transformer_options: |
|
|
|
|
m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"]) |
|
|
|
|
n = u(self.attn1(m(n), context=context_attn1, value=value_attn1)) |
|
|
|
|
transformer_block = (block[0], block[1], block_index) |
|
|
|
|
attn1_replace_patch = transformer_patches_replace.get("attn1", {}) |
|
|
|
|
block_attn1 = transformer_block |
|
|
|
|
if block_attn1 not in attn1_replace_patch: |
|
|
|
|
block_attn1 = block |
|
|
|
|
|
|
|
|
|
if block_attn1 in attn1_replace_patch: |
|
|
|
|
if context_attn1 is None: |
|
|
|
|
context_attn1 = n |
|
|
|
|
value_attn1 = n |
|
|
|
|
n = self.attn1.to_q(n) |
|
|
|
|
context_attn1 = self.attn1.to_k(context_attn1) |
|
|
|
|
value_attn1 = self.attn1.to_v(value_attn1) |
|
|
|
|
n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options) |
|
|
|
|
n = self.attn1.to_out(n) |
|
|
|
|
else: |
|
|
|
|
n = self.attn1(n, context=context_attn1, value=value_attn1) |
|
|
|
|
|
|
|
|
|
if "attn1_output_patch" in transformer_patches: |
|
|
|
|
patch = transformer_patches["attn1_output_patch"] |
|
|
|
|
for p in patch: |
|
|
|
|
n = p(n, extra_options) |
|
|
|
|
|
|
|
|
|
x += n |
|
|
|
|
if "middle_patch" in transformer_patches: |
|
|
|
|
patch = transformer_patches["middle_patch"] |
|
|
|
@ -573,6 +604,20 @@ class BasicTransformerBlock(nn.Module):
|
|
|
|
|
for p in patch: |
|
|
|
|
n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options) |
|
|
|
|
|
|
|
|
|
attn2_replace_patch = transformer_patches_replace.get("attn2", {}) |
|
|
|
|
block_attn2 = transformer_block |
|
|
|
|
if block_attn2 not in attn2_replace_patch: |
|
|
|
|
block_attn2 = block |
|
|
|
|
|
|
|
|
|
if block_attn2 in attn2_replace_patch: |
|
|
|
|
if value_attn2 is None: |
|
|
|
|
value_attn2 = context_attn2 |
|
|
|
|
n = self.attn2.to_q(n) |
|
|
|
|
context_attn2 = self.attn2.to_k(context_attn2) |
|
|
|
|
value_attn2 = self.attn2.to_v(value_attn2) |
|
|
|
|
n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options) |
|
|
|
|
n = self.attn2.to_out(n) |
|
|
|
|
else: |
|
|
|
|
n = self.attn2(n, context=context_attn2, value=value_attn2) |
|
|
|
|
|
|
|
|
|
if "attn2_output_patch" in transformer_patches: |
|
|
|
|