From 6c875d846b54ee34eb032fdb790a2fd1621d18f4 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Sat, 17 Feb 2024 07:46:17 -0500
Subject: [PATCH] Fix clip attention mask issues on some hardware.

---
 comfy/clip_model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/comfy/clip_model.py b/comfy/clip_model.py
index 09e7bbca..9ba4e039 100644
--- a/comfy/clip_model.py
+++ b/comfy/clip_model.py
@@ -97,7 +97,7 @@ class CLIPTextModel_(torch.nn.Module):
         x = self.embeddings(input_tokens)
         mask = None
         if attention_mask is not None:
-            mask = 1.0 - attention_mask.to(x.dtype).unsqueeze(1).unsqueeze(1).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
+            mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], attention_mask.shape[-1], attention_mask.shape[-1])
             mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))

         causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)