Browse Source

Refactor attention upcasting code part 1.

pull/2569/head
comfyanonymous 6 months ago
parent
commit
b0ab31d06c
  1. 28
      comfy/ldm/modules/attention.py

28
comfy/ldm/modules/attention.py

@ -22,9 +22,9 @@ ops = comfy.ops.disable_weight_init
# CrossAttn precision handling # CrossAttn precision handling
if args.dont_upcast_attention: if args.dont_upcast_attention:
logging.info("disabling upcasting of attention") logging.info("disabling upcasting of attention")
_ATTN_PRECISION = "fp16" _ATTN_PRECISION = None
else: else:
_ATTN_PRECISION = "fp32" _ATTN_PRECISION = torch.float32
def exists(val): def exists(val):
@ -85,7 +85,7 @@ class FeedForward(nn.Module):
def Normalize(in_channels, dtype=None, device=None): def Normalize(in_channels, dtype=None, device=None):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device) return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
def attention_basic(q, k, v, heads, mask=None): def attention_basic(q, k, v, heads, mask=None, attn_precision=None):
b, _, dim_head = q.shape b, _, dim_head = q.shape
dim_head //= heads dim_head //= heads
scale = dim_head ** -0.5 scale = dim_head ** -0.5
@ -101,7 +101,7 @@ def attention_basic(q, k, v, heads, mask=None):
) )
# force cast to fp32 to avoid overflowing # force cast to fp32 to avoid overflowing
if _ATTN_PRECISION =="fp32": if attn_precision == torch.float32:
sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
else: else:
sim = einsum('b i d, b j d -> b i j', q, k) * scale sim = einsum('b i d, b j d -> b i j', q, k) * scale
@ -135,7 +135,7 @@ def attention_basic(q, k, v, heads, mask=None):
return out return out
def attention_sub_quad(query, key, value, heads, mask=None): def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None):
b, _, dim_head = query.shape b, _, dim_head = query.shape
dim_head //= heads dim_head //= heads
@ -146,7 +146,7 @@ def attention_sub_quad(query, key, value, heads, mask=None):
key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1) key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1)
dtype = query.dtype dtype = query.dtype
upcast_attention = _ATTN_PRECISION =="fp32" and query.dtype != torch.float32 upcast_attention = attn_precision == torch.float32 and query.dtype != torch.float32
if upcast_attention: if upcast_attention:
bytes_per_token = torch.finfo(torch.float32).bits//8 bytes_per_token = torch.finfo(torch.float32).bits//8
else: else:
@ -195,7 +195,7 @@ def attention_sub_quad(query, key, value, heads, mask=None):
hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2) hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
return hidden_states return hidden_states
def attention_split(q, k, v, heads, mask=None): def attention_split(q, k, v, heads, mask=None, attn_precision=None):
b, _, dim_head = q.shape b, _, dim_head = q.shape
dim_head //= heads dim_head //= heads
scale = dim_head ** -0.5 scale = dim_head ** -0.5
@ -214,10 +214,12 @@ def attention_split(q, k, v, heads, mask=None):
mem_free_total = model_management.get_free_memory(q.device) mem_free_total = model_management.get_free_memory(q.device)
if _ATTN_PRECISION =="fp32": if attn_precision == torch.float32:
element_size = 4 element_size = 4
upcast = True
else: else:
element_size = q.element_size() element_size = q.element_size()
upcast = False
gb = 1024 ** 3 gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size
@ -251,7 +253,7 @@ def attention_split(q, k, v, heads, mask=None):
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size): for i in range(0, q.shape[1], slice_size):
end = i + slice_size end = i + slice_size
if _ATTN_PRECISION =="fp32": if upcast:
with torch.autocast(enabled=False, device_type = 'cuda'): with torch.autocast(enabled=False, device_type = 'cuda'):
s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale
else: else:
@ -302,7 +304,7 @@ try:
except: except:
pass pass
def attention_xformers(q, k, v, heads, mask=None): def attention_xformers(q, k, v, heads, mask=None, attn_precision=None):
b, _, dim_head = q.shape b, _, dim_head = q.shape
dim_head //= heads dim_head //= heads
if BROKEN_XFORMERS: if BROKEN_XFORMERS:
@ -334,7 +336,7 @@ def attention_xformers(q, k, v, heads, mask=None):
) )
return out return out
def attention_pytorch(q, k, v, heads, mask=None): def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None):
b, _, dim_head = q.shape b, _, dim_head = q.shape
dim_head //= heads dim_head //= heads
q, k, v = map( q, k, v = map(
@ -409,9 +411,9 @@ class CrossAttention(nn.Module):
v = self.to_v(context) v = self.to_v(context)
if mask is None: if mask is None:
out = optimized_attention(q, k, v, self.heads) out = optimized_attention(q, k, v, self.heads, attn_precision=_ATTN_PRECISION)
else: else:
out = optimized_attention_masked(q, k, v, self.heads, mask) out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=_ATTN_PRECISION)
return self.to_out(out) return self.to_out(out)

Loading…
Cancel
Save