|
|
@ -83,7 +83,8 @@ def _summarize_chunk( |
|
|
|
) |
|
|
|
) |
|
|
|
max_score, _ = torch.max(attn_weights, -1, keepdim=True) |
|
|
|
max_score, _ = torch.max(attn_weights, -1, keepdim=True) |
|
|
|
max_score = max_score.detach() |
|
|
|
max_score = max_score.detach() |
|
|
|
torch.exp(attn_weights - max_score, out=attn_weights) |
|
|
|
attn_weights -= max_score |
|
|
|
|
|
|
|
torch.exp(attn_weights, out=attn_weights) |
|
|
|
exp_weights = attn_weights.to(value.dtype) |
|
|
|
exp_weights = attn_weights.to(value.dtype) |
|
|
|
exp_values = torch.bmm(exp_weights, value) |
|
|
|
exp_values = torch.bmm(exp_weights, value) |
|
|
|
max_score = max_score.squeeze(-1) |
|
|
|
max_score = max_score.squeeze(-1) |
|
|
|