You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
241 lines
8.3 KiB
241 lines
8.3 KiB
# original source:
#   https://github.com/AminRezaei0x443/memory-efficient-attention/blob/1bc0d9e6ac5f82ea43a375135c4e1d3896ee1694/memory_efficient_attention/attention_torch.py
# license:
#   MIT
# credit:
#   Amin Rezaei (original author)
#   Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks)
# implementation of:
#   "Self-attention Does Not Need O(n^2) Memory":
#   https://arxiv.org/abs/2112.05682v2
|
|
|
from functools import partial |
|
import torch |
|
from torch import Tensor |
|
from torch.utils.checkpoint import checkpoint |
|
import math |
|
from typing import Optional, NamedTuple, Protocol, List |
|
|
|
from torch import Tensor |
|
from typing import List |
|
|
|
def dynamic_slice(
    x: Tensor,
    starts: List[int],
    sizes: List[int],
) -> Tensor:
    """Return the region of ``x`` that begins at ``starts`` and spans ``sizes``.

    One ``slice(start, start + size)`` is built for each ``(start, size)``
    pair and applied to the leading dimensions of ``x`` in order; any
    remaining dimensions are taken whole.  As with ordinary slicing, a slice
    that runs past the end of a dimension is clamped rather than raising.

    Args:
        x: tensor to slice.
        starts: per-dimension start offsets.
        sizes: per-dimension extents, paired positionally with ``starts``.

    Returns:
        The selected region of ``x``.
    """
    # The index has to be a tuple: indexing with a *list* of slices depends on
    # deprecated "non-tuple sequence" handling for multidimensional indexing.
    slicing = tuple(slice(start, start + size) for start, size in zip(starts, sizes))
    return x[slicing]
|
|
|
class AttnChunk(NamedTuple):
    """Un-normalised softmax statistics for one key/value chunk.

    These are the three values `_summarize_chunk` returns and that
    `_query_chunk_attention` rescales and combines across chunks.
    """

    # exp(scores - max_score) matrix-multiplied with the chunk's values.
    exp_values: Tensor
    # Sum over the chunk's keys of exp(scores - max_score).
    exp_weights_sum: Tensor
    # Per-query maximum score within the chunk (the shift used above).
    max_score: Tensor
|
|
|
class SummarizeChunk(Protocol):
    """Structural type of a callable that summarises one key/value chunk.

    A conforming callable takes a query block, one chunk of transposed keys
    and the matching chunk of values, and returns that chunk's `AttnChunk`.
    """

    @staticmethod
    def __call__(
        query: Tensor,
        key_t: Tensor,
        value: Tensor,
    ) -> AttnChunk: ...
|
|
|
class ComputeQueryChunkAttn(Protocol):
    """Structural type of a callable that computes attention for one query chunk.

    A conforming callable takes a query block together with the transposed
    keys and the values, and returns the attention output tensor for that
    query block.
    """

    @staticmethod
    def __call__(
        query: Tensor,
        key_t: Tensor,
        value: Tensor,
    ) -> Tensor: ...
|
|
|
def _summarize_chunk(
    query: Tensor,
    key_t: Tensor,
    value: Tensor,
    scale: float,
    upcast_attention: bool,
) -> AttnChunk:
    """Compute the un-normalised softmax summary of one key/value chunk.

    Args:
        query: query block, the left operand of the batched matmul with ``key_t``.
        key_t: transposed key chunk.
        value: value chunk matching ``key_t``.
        scale: multiplier applied to the ``query @ key_t`` scores.
        upcast_attention: if True, ``query`` and ``key_t`` are converted to
            float32 and the scores are computed with CUDA autocast disabled.

    Returns:
        ``AttnChunk(exp_values, exp_weights_sum, max_score)``, where the
        exponentials are shifted by this chunk's per-query ``max_score``.
    """

    def scaled_scores(q: Tensor, k_t: Tensor) -> Tensor:
        # With beta=0 baddbmm ignores its first operand, so an uninitialised
        # 1x1x1 placeholder on the right device/dtype is all that is needed.
        placeholder = torch.empty(1, 1, 1, device=q.device, dtype=q.dtype)
        return torch.baddbmm(placeholder, q, k_t, alpha=scale, beta=0)

    if upcast_attention:
        with torch.autocast(enabled=False, device_type='cuda'):
            attn_weights = scaled_scores(query.float(), key_t.float())
    else:
        attn_weights = scaled_scores(query, key_t)

    max_score, _ = torch.max(attn_weights, -1, keepdim=True)
    # The max is only a numerical shift; keep it out of the autograd graph.
    max_score = max_score.detach()
    exp_weights = torch.exp(attn_weights - max_score)
    # bmm requires both operands to share a dtype.  After upcasting, the
    # weights are float32 while `value` may still be half/bfloat16, so match
    # the weights to `value` (a no-op when the dtypes already agree).
    exp_values = torch.bmm(exp_weights.to(value.dtype), value)
    max_score = max_score.squeeze(-1)
    return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
|
|
|
def _query_chunk_attention(
    query: Tensor,
    key_t: Tensor,
    value: Tensor,
    summarize_chunk: SummarizeChunk,
    kv_chunk_size: int,
) -> Tensor:
    """Attend ``query`` to all keys/values, one key/value chunk at a time.

    The keys (last dimension of ``key_t``) and values (middle dimension of
    ``value``) are walked in steps of ``kv_chunk_size``.  Each chunk is
    reduced to an `AttnChunk` by ``summarize_chunk``; the per-chunk results
    are then rescaled to a common maximum and combined into the final
    softmax-weighted output.

    Args:
        query: query block passed unchanged to ``summarize_chunk``.
        key_t: transposed keys, unpacked as
            ``(batch_x_heads, k_channels_per_head, k_tokens)``.
        value: values; its last dimension is ``v_channels_per_head``.
        summarize_chunk: callable producing the `AttnChunk` for one chunk.
        kv_chunk_size: number of key/value tokens per chunk.

    Returns:
        The summed, rescaled ``exp_values`` divided by the summed, rescaled
        ``exp_weights_sum``.
    """
    batch_x_heads, k_channels_per_head, k_tokens = key_t.shape
    _, _, v_channels_per_head = value.shape

    def chunk_scanner(chunk_idx: int) -> AttnChunk:
        # The final chunk may be shorter: slicing clamps at the tensor's end.
        key_chunk = dynamic_slice(
            key_t,
            [0, 0, chunk_idx],
            [batch_x_heads, k_channels_per_head, kv_chunk_size],
        )
        value_chunk = dynamic_slice(
            value,
            [0, chunk_idx, 0],
            [batch_x_heads, kv_chunk_size, v_channels_per_head],
        )
        return summarize_chunk(query, key_chunk, value_chunk)

    # Plain Python ints are what the slicing needs; iterating a torch.arange
    # would hand 0-dim tensors to code annotated as taking `int`.
    chunks: List[AttnChunk] = [
        chunk_scanner(chunk_start) for chunk_start in range(0, k_tokens, kv_chunk_size)
    ]
    # Stack field-wise, giving three tensors with a leading per-chunk dimension.
    acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))
    chunk_values, chunk_weights, chunk_max = acc_chunk

    # Each chunk was shifted by its own max; re-express all of them relative
    # to the global max so that they can be summed.
    global_max, _ = torch.max(chunk_max, 0, keepdim=True)
    max_diffs = torch.exp(chunk_max - global_max)
    chunk_values *= torch.unsqueeze(max_diffs, -1)
    chunk_weights *= max_diffs

    all_values = chunk_values.sum(dim=0)
    all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
    return all_values / all_weights
|
|
|
# TODO: refactor CrossAttention#get_attention_scores to share code with this |
|
def _get_attention_scores_no_kv_chunking( |
|
query: Tensor, |
|
key_t: Tensor, |
|
value: Tensor, |
|
scale: float, |
|
upcast_attention: bool, |
|
) -> Tensor: |
|
if upcast_attention: |
|
with torch.autocast(enabled=False, device_type = 'cuda'): |
|
query = query.float() |
|
key_t = key_t.float() |
|
attn_scores = torch.baddbmm( |
|
torch.empty(1, 1, 1, device=query.device, dtype=query.dtype), |
|
query, |
|
key_t, |
|
alpha=scale, |
|
beta=0, |
|
) |
|
else: |
|
attn_scores = torch.baddbmm( |
|
torch.empty(1, 1, 1, device=query.device, dtype=query.dtype), |
|
query, |
|
key_t, |
|
alpha=scale, |
|
beta=0, |
|
) |
|
|
|
try: |
|
attn_probs = attn_scores.softmax(dim=-1) |
|
del attn_scores |
|
except torch.cuda.OutOfMemoryError: |
|
print("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead") |
|
torch.exp(attn_scores, out=attn_scores) |
|
summed = torch.sum(attn_scores, dim=-1, keepdim=True) |
|
attn_scores /= summed |
|
attn_probs = attn_scores |
|
|
|
hidden_states_slice = torch.bmm(attn_probs, value) |
|
return hidden_states_slice |
|
|
|
class ScannedChunk(NamedTuple):
    """Pairs a chunk index with the `AttnChunk` computed for that chunk."""

    # NOTE(review): presumably the chunk's start offset or ordinal; nothing in
    # the visible code constructs this type, so confirm against callers.
    chunk_idx: int
    # Summary produced for the chunk.
    attn_chunk: AttnChunk
|
|
|
def efficient_dot_product_attention(
    query: Tensor,
    key_t: Tensor,
    value: Tensor,
    query_chunk_size: int = 1024,
    kv_chunk_size: Optional[int] = None,
    kv_chunk_size_min: Optional[int] = None,
    use_checkpoint: bool = True,
    upcast_attention: bool = False,
) -> Tensor:
    """Computes efficient dot-product attention given query, transposed key, and value.
      This is efficient version of attention presented in
      https://arxiv.org/abs/2112.05682v2 which comes with O(sqrt(n)) memory requirements.
      Args:
        query: queries for calculating attention with shape of
          `[batch * num_heads, tokens, channels_per_head]`.
        key_t: keys for calculating attention with shape of
          `[batch * num_heads, channels_per_head, tokens]`.
        value: values to be used in attention with shape of
          `[batch * num_heads, tokens, channels_per_head]`.
        query_chunk_size: int: query chunks size
        kv_chunk_size: Optional[int]: key/value chunks size. if None (or 0): defaults to
          sqrt(key_tokens). Either way it is capped at key_tokens.
        kv_chunk_size_min: Optional[int]: lower bound for the key/value chunk size. When given,
          it is applied to the resolved chunk size (explicit or sqrt default, after the cap
          above), to ensure our chunk sizes don't get too small
          (smaller chunks = more chunks = less concurrent work done).
        use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference)
        upcast_attention: bool: whether to compute the attention scores in float32
      Returns:
        Output of shape `[batch * num_heads, query_tokens, channels_per_head]`.
      """
    batch_x_heads, q_tokens, q_channels_per_head = query.shape
    _, _, k_tokens = key_t.shape
    scale = q_channels_per_head ** -0.5

    kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens)
    if kv_chunk_size_min is not None:
        kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)

    def get_query_chunk(chunk_idx: int) -> Tensor:
        return dynamic_slice(
            query,
            [0, chunk_idx, 0],
            [batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head],
        )

    summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale, upcast_attention=upcast_attention)
    if use_checkpoint:
        summarize_chunk = partial(checkpoint, summarize_chunk)

    compute_query_chunk_attn: ComputeQueryChunkAttn
    if k_tokens <= kv_chunk_size:
        # fast-path for when there's just 1 key-value chunk per query chunk (this is just sliced attention btw)
        compute_query_chunk_attn = partial(
            _get_attention_scores_no_kv_chunking,
            scale=scale,
            upcast_attention=upcast_attention,
        )
    else:
        compute_query_chunk_attn = partial(
            _query_chunk_attention,
            kv_chunk_size=kv_chunk_size,
            summarize_chunk=summarize_chunk,
        )

    if q_tokens <= query_chunk_size:
        # fast-path for when there's just 1 query chunk
        return compute_query_chunk_attn(
            query=query,
            key_t=key_t,
            value=value,
        )

    # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
    # and pass slices to be mutated, instead of torch.cat()ing the returned slices
    res = torch.cat([
        compute_query_chunk_attn(
            query=get_query_chunk(chunk_start),
            key_t=key_t,
            value=value,
        ) for chunk_start in range(0, q_tokens, query_chunk_size)
    ], dim=1)
    return res
|
|
|