|
|
|
@ -194,7 +194,14 @@ class CrossAttentionBirchSan(nn.Module):
|
|
|
|
|
|
|
|
|
|
kv_chunk_size_min = None |
|
|
|
|
|
|
|
|
|
query_chunk_size_x = 1024 * 4 |
|
|
|
|
#not sure at all about the math here |
|
|
|
|
#TODO: tweak this |
|
|
|
|
if mem_free_total > 8192 * 1024 * 1024 * 1.3: |
|
|
|
|
query_chunk_size_x = 1024 * 4 |
|
|
|
|
elif mem_free_total > 4096 * 1024 * 1024 * 1.3: |
|
|
|
|
query_chunk_size_x = 1024 * 2 |
|
|
|
|
else: |
|
|
|
|
query_chunk_size_x = 1024 |
|
|
|
|
kv_chunk_size_min_x = None |
|
|
|
|
kv_chunk_size_x = (int((chunk_threshold_bytes // (batch_x_heads * bytes_per_token * query_chunk_size_x)) * 2.0) // 1024) * 1024 |
|
|
|
|
if kv_chunk_size_x < 1024: |
|
|
|
|