|
|
|
@ -53,14 +53,27 @@ def _summarize_chunk(
|
|
|
|
|
key_t: Tensor, |
|
|
|
|
value: Tensor, |
|
|
|
|
scale: float, |
|
|
|
|
upcast_attention: bool, |
|
|
|
|
) -> AttnChunk: |
|
|
|
|
attn_weights = torch.baddbmm( |
|
|
|
|
torch.empty(1, 1, 1, device=query.device, dtype=query.dtype), |
|
|
|
|
query, |
|
|
|
|
key_t, |
|
|
|
|
alpha=scale, |
|
|
|
|
beta=0, |
|
|
|
|
) |
|
|
|
|
if upcast_attention: |
|
|
|
|
with torch.autocast(enabled=False, device_type = 'cuda'): |
|
|
|
|
query = query.float() |
|
|
|
|
key_t = key_t.float() |
|
|
|
|
attn_weights = torch.baddbmm( |
|
|
|
|
torch.empty(1, 1, 1, device=query.device, dtype=query.dtype), |
|
|
|
|
query, |
|
|
|
|
key_t, |
|
|
|
|
alpha=scale, |
|
|
|
|
beta=0, |
|
|
|
|
) |
|
|
|
|
else: |
|
|
|
|
attn_weights = torch.baddbmm( |
|
|
|
|
torch.empty(1, 1, 1, device=query.device, dtype=query.dtype), |
|
|
|
|
query, |
|
|
|
|
key_t, |
|
|
|
|
alpha=scale, |
|
|
|
|
beta=0, |
|
|
|
|
) |
|
|
|
|
max_score, _ = torch.max(attn_weights, -1, keepdim=True) |
|
|
|
|
max_score = max_score.detach() |
|
|
|
|
exp_weights = torch.exp(attn_weights - max_score) |
|
|
|
@ -112,14 +125,27 @@ def _get_attention_scores_no_kv_chunking(
|
|
|
|
|
key_t: Tensor, |
|
|
|
|
value: Tensor, |
|
|
|
|
scale: float, |
|
|
|
|
upcast_attention: bool, |
|
|
|
|
) -> Tensor: |
|
|
|
|
attn_scores = torch.baddbmm( |
|
|
|
|
torch.empty(1, 1, 1, device=query.device, dtype=query.dtype), |
|
|
|
|
query, |
|
|
|
|
key_t, |
|
|
|
|
alpha=scale, |
|
|
|
|
beta=0, |
|
|
|
|
) |
|
|
|
|
if upcast_attention: |
|
|
|
|
with torch.autocast(enabled=False, device_type = 'cuda'): |
|
|
|
|
query = query.float() |
|
|
|
|
key_t = key_t.float() |
|
|
|
|
attn_scores = torch.baddbmm( |
|
|
|
|
torch.empty(1, 1, 1, device=query.device, dtype=query.dtype), |
|
|
|
|
query, |
|
|
|
|
key_t, |
|
|
|
|
alpha=scale, |
|
|
|
|
beta=0, |
|
|
|
|
) |
|
|
|
|
else: |
|
|
|
|
attn_scores = torch.baddbmm( |
|
|
|
|
torch.empty(1, 1, 1, device=query.device, dtype=query.dtype), |
|
|
|
|
query, |
|
|
|
|
key_t, |
|
|
|
|
alpha=scale, |
|
|
|
|
beta=0, |
|
|
|
|
) |
|
|
|
|
attn_probs = attn_scores.softmax(dim=-1) |
|
|
|
|
del attn_scores |
|
|
|
|
hidden_states_slice = torch.bmm(attn_probs, value) |
|
|
|
@ -137,6 +163,7 @@ def efficient_dot_product_attention(
|
|
|
|
|
kv_chunk_size: Optional[int] = None, |
|
|
|
|
kv_chunk_size_min: Optional[int] = None, |
|
|
|
|
use_checkpoint=True, |
|
|
|
|
upcast_attention=False, |
|
|
|
|
): |
|
|
|
|
"""Computes efficient dot-product attention given query, transposed key, and value. |
|
|
|
|
This is efficient version of attention presented in |
|
|
|
@ -170,11 +197,12 @@ def efficient_dot_product_attention(
|
|
|
|
|
(batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head) |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale) |
|
|
|
|
summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale, upcast_attention=upcast_attention) |
|
|
|
|
summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk |
|
|
|
|
compute_query_chunk_attn: ComputeQueryChunkAttn = partial( |
|
|
|
|
_get_attention_scores_no_kv_chunking, |
|
|
|
|
scale=scale |
|
|
|
|
scale=scale, |
|
|
|
|
upcast_attention=upcast_attention |
|
|
|
|
) if k_tokens <= kv_chunk_size else ( |
|
|
|
|
# fast-path for when there's just 1 key-value chunk per query chunk (this is just sliced attention btw) |
|
|
|
|
partial( |
|
|
|
|