|
|
|
@ -146,8 +146,17 @@ def _get_attention_scores_no_kv_chunking(
|
|
|
|
|
alpha=scale, |
|
|
|
|
beta=0, |
|
|
|
|
) |
|
|
|
|
attn_probs = attn_scores.softmax(dim=-1) |
|
|
|
|
del attn_scores |
|
|
|
|
|
|
|
|
|
try: |
|
|
|
|
attn_probs = attn_scores.softmax(dim=-1) |
|
|
|
|
del attn_scores |
|
|
|
|
except torch.cuda.OutOfMemoryError: |
|
|
|
|
print("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead") |
|
|
|
|
torch.exp(attn_scores, out=attn_scores) |
|
|
|
|
summed = torch.sum(attn_scores, dim=-1, keepdim=True) |
|
|
|
|
attn_scores /= summed |
|
|
|
|
attn_probs = attn_scores |
|
|
|
|
|
|
|
|
|
hidden_states_slice = torch.bmm(attn_probs, value) |
|
|
|
|
return hidden_states_slice |
|
|
|
|
|
|
|
|
|