|
|
|
@ -6,7 +6,7 @@ from einops import rearrange, repeat
|
|
|
|
|
from typing import Optional, Any |
|
|
|
|
import logging |
|
|
|
|
|
|
|
|
|
from .diffusionmodules.util import checkpoint, AlphaBlender, timestep_embedding |
|
|
|
|
from .diffusionmodules.util import AlphaBlender, timestep_embedding |
|
|
|
|
from .sub_quadratic_attention import efficient_dot_product_attention |
|
|
|
|
|
|
|
|
|
from comfy import model_management |
|
|
|
@ -454,15 +454,11 @@ class BasicTransformerBlock(nn.Module):
|
|
|
|
|
|
|
|
|
|
self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device) |
|
|
|
|
self.norm3 = operations.LayerNorm(inner_dim, dtype=dtype, device=device) |
|
|
|
|
self.checkpoint = checkpoint |
|
|
|
|
self.n_heads = n_heads |
|
|
|
|
self.d_head = d_head |
|
|
|
|
self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa |
|
|
|
|
|
|
|
|
|
def forward(self, x, context=None, transformer_options={}): |
|
|
|
|
return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint) |
|
|
|
|
|
|
|
|
|
def _forward(self, x, context=None, transformer_options={}): |
|
|
|
|
extra_options = {} |
|
|
|
|
block = transformer_options.get("block", None) |
|
|
|
|
block_index = transformer_options.get("block_index", 0) |
|
|
|
@ -629,7 +625,7 @@ class SpatialTransformer(nn.Module):
|
|
|
|
|
x = self.norm(x) |
|
|
|
|
if not self.use_linear: |
|
|
|
|
x = self.proj_in(x) |
|
|
|
|
x = rearrange(x, 'b c h w -> b (h w) c').contiguous() |
|
|
|
|
x = x.movedim(1, -1).flatten(1, 2).contiguous() |
|
|
|
|
if self.use_linear: |
|
|
|
|
x = self.proj_in(x) |
|
|
|
|
for i, block in enumerate(self.transformer_blocks): |
|
|
|
@ -637,7 +633,7 @@ class SpatialTransformer(nn.Module):
|
|
|
|
|
x = block(x, context=context[i], transformer_options=transformer_options) |
|
|
|
|
if self.use_linear: |
|
|
|
|
x = self.proj_out(x) |
|
|
|
|
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() |
|
|
|
|
x = x.reshape(x.shape[0], h, w, x.shape[-1]).movedim(-1, 1).contiguous() |
|
|
|
|
if not self.use_linear: |
|
|
|
|
x = self.proj_out(x) |
|
|
|
|
return x + x_in |
|
|
|
|