@@ -114,7 +114,8 @@ def attention_basic(q, k, v, heads, mask=None):
             mask = repeat(mask, 'b j -> (b h) () j', h=h)
             sim.masked_fill_(~mask, max_neg_value)
         else:
-            sim += mask
+            mask = mask.reshape(mask.shape[0], -1, mask.shape[-2], mask.shape[-1]).expand(-1, heads, -1, -1).reshape(sim.shape)
+            sim.add_(mask)
 
     # attention, what we cannot get enough of
     sim = sim.softmax(dim=-1)
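
The hunk above replaces the plain "sim += mask" with an explicit broadcast of the additive mask across attention heads: sim in attention_basic is laid out as (batch * heads, q_len, k_len), while a float mask is typically supplied per batch item, so it is expanded over the head dimension before the in-place add. A minimal shape check, as a sketch with illustrative sizes B, H, Lq, Lk that are not part of the patch:

import torch

B, H, Lq, Lk = 2, 8, 16, 16        # illustrative sizes, not from the patch
sim = torch.zeros(B * H, Lq, Lk)   # attention scores, heads folded into the batch dim
mask = torch.randn(B, Lq, Lk)      # one additive (float) mask per batch item

# same expression as the new else-branch above, with H standing in for heads
m = mask.reshape(mask.shape[0], -1, mask.shape[-2], mask.shape[-1]).expand(-1, H, -1, -1).reshape(sim.shape)
sim.add_(m)

assert m.shape == sim.shape            # (B*H, Lq, Lk)
assert torch.equal(m[0], m[H - 1])     # all H heads of batch item 0 see the same mask
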
@@ -165,6 +166,9 @@ def attention_sub_quad(query, key, value, heads, mask=None):
     if query_chunk_size is None:
         query_chunk_size = 512
 
+    if mask is not None:
+        mask = mask.reshape(mask.shape[0], -1, mask.shape[-2], mask.shape[-1]).expand(-1, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
+
     hidden_states = efficient_dot_product_attention(
         query,
         key,

@@ -223,6 +227,9 @@ def attention_split(q, k, v, heads, mask=None):
         raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
                            f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
 
+    if mask is not None:
+        mask = mask.reshape(mask.shape[0], -1, mask.shape[-2], mask.shape[-1]).expand(-1, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
+
     # print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size)
     first_op_done = False
     cleared_cache = False
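
attention_sub_quad and attention_split gain the same head broadcast, except the result is flattened back to (batch * heads, q_len, k_len) with reshape(-1, mask.shape[-2], mask.shape[-1]) rather than reshape(sim.shape), since at this point in those functions there is no sim tensor to borrow a shape from. As far as the expression itself goes, it appears to accept either a per-batch mask or one that already carries a head dimension; a sketch using a hypothetical helper name broadcast_mask (not part of the patch):

import torch

B, H, Lq, Lk = 2, 8, 16, 16            # illustrative sizes, not from the patch

def broadcast_mask(mask, heads):
    # hypothetical wrapper around the expression used in the two hunks above
    return mask.reshape(mask.shape[0], -1, mask.shape[-2], mask.shape[-1]) \
               .expand(-1, heads, -1, -1) \
               .reshape(-1, mask.shape[-2], mask.shape[-1])

per_batch = torch.randn(B, Lq, Lk)     # one mask per batch item
per_head = torch.randn(B, H, Lq, Lk)   # mask already split per head

assert broadcast_mask(per_batch, H).shape == (B * H, Lq, Lk)
assert broadcast_mask(per_head, H).shape == (B * H, Lq, Lk)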