diff --git a/README.md b/README.md index 312468a9..80de21bc 100644 --- a/README.md +++ b/README.md @@ -207,12 +207,6 @@ To use a textual inversion concepts/embeddings in a text prompt put them in the ```embedding:embedding_filename.pt``` -## How to increase generation speed? - -On non Nvidia hardware you can set this command line setting to disable the upcasting to fp32 in some cross attention operations which will increase your speed. Note that this will very likely give you black images on SD2.x models. If you use xformers or pytorch attention this option does not do anything. - -```--dont-upcast-attention``` - ## How to show high-quality previews? Use ```--preview-method auto``` to enable previews. diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 569c7938..2759f4e9 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -51,7 +51,6 @@ cm_group = parser.add_mutually_exclusive_group() cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).") cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.") -parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.") fp_group = parser.add_mutually_exclusive_group() fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).") diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index de66db4f..2515bac5 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -19,14 +19,6 @@ from comfy.cli_args import args import comfy.ops ops = comfy.ops.disable_weight_init -# CrossAttn precision handling -if args.dont_upcast_attention: - logging.info("disabling upcasting of attention") - _ATTN_PRECISION = None -else: - _ATTN_PRECISION = torch.float32 - - def exists(val): return val is not None @@ -386,10 +378,11 @@ def optimized_attention_for_device(device, mask=False, small_input=False): class CrossAttention(nn.Module): - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=ops): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=ops): super().__init__() inner_dim = dim_head * heads context_dim = default(context_dim, query_dim) + self.attn_precision = attn_precision self.heads = heads self.dim_head = dim_head @@ -411,15 +404,15 @@ class CrossAttention(nn.Module): v = self.to_v(context) if mask is None: - out = optimized_attention(q, k, v, self.heads, attn_precision=_ATTN_PRECISION) + out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision) else: - out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=_ATTN_PRECISION) + out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision) return self.to_out(out) class BasicTransformerBlock(nn.Module): def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, ff_in=False, inner_dim=None, - disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, dtype=None, device=None, operations=ops): + disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, attn_precision=None, dtype=None, device=None, operations=ops): super().__init__() self.ff_in = ff_in or inner_dim is not None @@ -434,7 +427,7 @@ class BasicTransformerBlock(nn.Module): self.disable_self_attn = disable_self_attn self.attn1 = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout, - context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn + context_dim=context_dim if self.disable_self_attn else None, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations) if disable_temporal_crossattention: @@ -448,7 +441,7 @@ class BasicTransformerBlock(nn.Module): context_dim_attn2 = context_dim self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2, - heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none + heads=n_heads, dim_head=d_head, dropout=dropout, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations) # is self-attn if context is none self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device) self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device) @@ -588,7 +581,7 @@ class SpatialTransformer(nn.Module): def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0., context_dim=None, disable_self_attn=False, use_linear=False, - use_checkpoint=True, dtype=None, device=None, operations=ops): + use_checkpoint=True, attn_precision=None, dtype=None, device=None, operations=ops): super().__init__() if exists(context_dim) and not isinstance(context_dim, list): context_dim = [context_dim] * depth @@ -606,7 +599,7 @@ class SpatialTransformer(nn.Module): self.transformer_blocks = nn.ModuleList( [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], - disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype, device=device, operations=operations) + disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations) for d in range(depth)] ) if not use_linear: @@ -662,6 +655,7 @@ class SpatialVideoTransformer(SpatialTransformer): disable_self_attn=False, disable_temporal_crossattention=False, max_time_embed_period: int = 10000, + attn_precision=None, dtype=None, device=None, operations=ops ): super().__init__( @@ -674,6 +668,7 @@ class SpatialVideoTransformer(SpatialTransformer): context_dim=context_dim, use_linear=use_linear, disable_self_attn=disable_self_attn, + attn_precision=attn_precision, dtype=dtype, device=device, operations=operations ) self.time_depth = time_depth @@ -703,6 +698,7 @@ class SpatialVideoTransformer(SpatialTransformer): inner_dim=time_mix_inner_dim, disable_self_attn=disable_self_attn, disable_temporal_crossattention=disable_temporal_crossattention, + attn_precision=attn_precision, dtype=dtype, device=device, operations=operations ) for _ in range(self.depth) diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index d782eff3..1f5a4ded 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -431,6 +431,7 @@ class UNetModel(nn.Module): video_kernel_size=None, disable_temporal_crossattention=False, max_ddpm_temb_period=10000, + attn_precision=None, device=None, operations=ops, ): @@ -550,13 +551,14 @@ class UNetModel(nn.Module): disable_self_attn=disable_self_attn, disable_temporal_crossattention=disable_temporal_crossattention, max_time_embed_period=max_ddpm_temb_period, + attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations ) else: return SpatialTransformer( ch, num_heads, dim_head, depth=depth, context_dim=context_dim, disable_self_attn=disable_self_attn, use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations + use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations ) def get_resblock( diff --git a/comfy/supported_models.py b/comfy/supported_models.py index b3b69e05..6ca32e8e 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -65,6 +65,12 @@ class SD20(supported_models_base.BASE): "use_temporal_attention": False, } + unet_extra_config = { + "num_heads": -1, + "num_head_channels": 64, + "attn_precision": torch.float32, + } + latent_format = latent_formats.SD15 def model_type(self, state_dict, prefix=""): @@ -276,6 +282,12 @@ class SVD_img2vid(supported_models_base.BASE): "use_temporal_resblock": True } + unet_extra_config = { + "num_heads": -1, + "num_head_channels": 64, + "attn_precision": torch.float32, + } + clip_vision_prefix = "conditioner.embedders.0.open_clip.model.visual." latent_format = latent_formats.SD15