diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 2759f4e9..b8ac9bc6 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -95,6 +95,11 @@ attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", he parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.") +upcast = parser.add_mutually_exclusive_group() +upcast.add_argument("--force-upcast-attention", action="store_true", help="Force enable attention upcasting, please report if it fixes black images.") +upcast.add_argument("--dont-upcast-attention", action="store_true", help="Disable all upcasting of attention. Should be unnecessary except for debugging.") + + vram_group = parser.add_mutually_exclusive_group() vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc... on the GPU).") vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.") diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 1d5cf0da..42653086 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -19,6 +19,14 @@ from comfy.cli_args import args import comfy.ops ops = comfy.ops.disable_weight_init + +def get_attn_precision(attn_precision): + if args.dont_upcast_attention: + return None + if attn_precision is None and args.force_upcast_attention: + return torch.float32 + return attn_precision + def exists(val): return val is not None @@ -78,6 +86,8 @@ def Normalize(in_channels, dtype=None, device=None): return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device) def attention_basic(q, k, v, heads, mask=None, attn_precision=None): + attn_precision = get_attn_precision(attn_precision) + b, _, dim_head = q.shape dim_head //= heads scale = dim_head ** -0.5 @@ -128,6 +138,8 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None): def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None): + attn_precision = get_attn_precision(attn_precision) + b, _, dim_head = query.shape dim_head //= heads @@ -188,6 +200,8 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None) return hidden_states def attention_split(q, k, v, heads, mask=None, attn_precision=None): + attn_precision = get_attn_precision(attn_precision) + b, _, dim_head = q.shape dim_head //= heads scale = dim_head ** -0.5