diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index d51a2fae..de66db4f 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -22,9 +22,9 @@ ops = comfy.ops.disable_weight_init # CrossAttn precision handling if args.dont_upcast_attention: logging.info("disabling upcasting of attention") - _ATTN_PRECISION = "fp16" + _ATTN_PRECISION = None else: - _ATTN_PRECISION = "fp32" + _ATTN_PRECISION = torch.float32 def exists(val): @@ -85,7 +85,7 @@ class FeedForward(nn.Module): def Normalize(in_channels, dtype=None, device=None): return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device) -def attention_basic(q, k, v, heads, mask=None): +def attention_basic(q, k, v, heads, mask=None, attn_precision=None): b, _, dim_head = q.shape dim_head //= heads scale = dim_head ** -0.5 @@ -101,7 +101,7 @@ def attention_basic(q, k, v, heads, mask=None): ) # force cast to fp32 to avoid overflowing - if _ATTN_PRECISION =="fp32": + if attn_precision == torch.float32: sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale else: sim = einsum('b i d, b j d -> b i j', q, k) * scale @@ -135,7 +135,7 @@ def attention_basic(q, k, v, heads, mask=None): return out -def attention_sub_quad(query, key, value, heads, mask=None): +def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None): b, _, dim_head = query.shape dim_head //= heads @@ -146,7 +146,7 @@ def attention_sub_quad(query, key, value, heads, mask=None): key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1) dtype = query.dtype - upcast_attention = _ATTN_PRECISION =="fp32" and query.dtype != torch.float32 + upcast_attention = attn_precision == torch.float32 and query.dtype != torch.float32 if upcast_attention: bytes_per_token = torch.finfo(torch.float32).bits//8 else: @@ -195,7 +195,7 @@ def attention_sub_quad(query, key, value, heads, mask=None): hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2) return hidden_states -def attention_split(q, k, v, heads, mask=None): +def attention_split(q, k, v, heads, mask=None, attn_precision=None): b, _, dim_head = q.shape dim_head //= heads scale = dim_head ** -0.5 @@ -214,10 +214,12 @@ def attention_split(q, k, v, heads, mask=None): mem_free_total = model_management.get_free_memory(q.device) - if _ATTN_PRECISION =="fp32": + if attn_precision == torch.float32: element_size = 4 + upcast = True else: element_size = q.element_size() + upcast = False gb = 1024 ** 3 tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size @@ -251,7 +253,7 @@ def attention_split(q, k, v, heads, mask=None): slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] for i in range(0, q.shape[1], slice_size): end = i + slice_size - if _ATTN_PRECISION =="fp32": + if upcast: with torch.autocast(enabled=False, device_type = 'cuda'): s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale else: @@ -302,7 +304,7 @@ try: except: pass -def attention_xformers(q, k, v, heads, mask=None): +def attention_xformers(q, k, v, heads, mask=None, attn_precision=None): b, _, dim_head = q.shape dim_head //= heads if BROKEN_XFORMERS: @@ -334,7 +336,7 @@ def attention_xformers(q, k, v, heads, mask=None): ) return out -def attention_pytorch(q, k, v, heads, mask=None): +def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None): b, _, dim_head = q.shape dim_head //= heads q, k, v = map( @@ -409,9 +411,9 @@ class CrossAttention(nn.Module): v = self.to_v(context) if mask is None: - out = optimized_attention(q, k, v, self.heads) + out = optimized_attention(q, k, v, self.heads, attn_precision=_ATTN_PRECISION) else: - out = optimized_attention_masked(q, k, v, self.heads, mask) + out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=_ATTN_PRECISION) return self.to_out(out)