Support block size 32 (#35)
commit b9926f7f66
parent ee88a7e5f3
@@ -15,9 +15,9 @@ class BlockAllocator:
         block_size: int,
         num_blocks: int,
     ) -> None:
-        if block_size not in [8, 16]:
+        if block_size not in [8, 16, 32]:
             raise ValueError(f'Unsupported block size: {block_size}'
-                             'The block size must be either 8 or 16.')
+                             'The block size must be one of {8, 16, 32}.')
         self.device = device
         self.block_size = block_size
         self.num_blocks = num_blocks
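
For reference, the widened check behaves as sketched below. This is a minimal standalone version of the constructor from the hunk above (the rest of BlockAllocator is omitted). Note that the committed message concatenates its two string literals with no separator; the sketch inserts a period and space that the real error message lacks.

# Minimal sketch of the widened validation; everything outside the
# __init__ body shown in the hunk above is omitted.
class BlockAllocator:
    def __init__(self, device: str, block_size: int, num_blocks: int) -> None:
        if block_size not in [8, 16, 32]:
            raise ValueError(f'Unsupported block size: {block_size}. '
                             'The block size must be one of {8, 16, 32}.')
        self.device = device
        self.block_size = block_size
        self.num_blocks = num_blocks

BlockAllocator('cuda', 32, 1024)  # accepted after this commit
try:
    BlockAllocator('cuda', 64, 1024)
except ValueError as err:
    print(err)  # Unsupported block size: 64. The block size must be one of {8, 16, 32}.
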
@@ -174,7 +174,7 @@ def add_server_arguments(parser: argparse.ArgumentParser):
     parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages')
     parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1, help='number of tensor parallel replicas')
     # KV cache arguments
-    parser.add_argument('--block-size', type=int, default=8, choices=[8, 16], help='token block size')
+    parser.add_argument('--block-size', type=int, default=8, choices=[8, 16, 32], help='token block size')
     # NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
     parser.add_argument('--dtype', type=str, default='half', choices=['half', 'float'], help='data type')
     # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
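
The new choice can be sanity-checked in isolation; a minimal sketch of just this argument, with the surrounding server parser omitted:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--block-size', type=int, default=8, choices=[8, 16, 32],
                    help='token block size')

print(parser.parse_args(['--block-size', '32']).block_size)  # prints: 32
# Unsupported sizes are rejected by argparse before any allocator code runs;
# parsing ['--block-size', '64'] exits with an "invalid choice: 64" error.
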
@@ -654,6 +654,16 @@ void single_query_cached_kv_attention(
       block_tables,
       context_lens,
       max_context_len);
+  } else if (block_size == 32) {
+    single_query_cached_kv_attention_launcher<uint16_t, 32>(
+      out,
+      query,
+      key_cache,
+      value_cache,
+      scale,
+      block_tables,
+      context_lens,
+      max_context_len);
   } else {
     assert(false);
   }
@@ -679,6 +689,16 @@ void single_query_cached_kv_attention(
       block_tables,
       context_lens,
       max_context_len);
+  } else if (block_size == 32) {
+    single_query_cached_kv_attention_launcher<float, 32>(
+      out,
+      query,
+      key_cache,
+      value_cache,
+      scale,
+      block_tables,
+      context_lens,
+      max_context_len);
   } else {
     assert(false);
   }
@@ -834,6 +854,18 @@ void multi_query_cached_kv_attention(
       block_tables,
       context_lens,
       max_context_len);
+  } else if (block_size == 32) {
+    multi_query_cached_kv_attention_launcher<uint16_t, 32>(
+      cu_query_lens,
+      seq_prompt_mapping,
+      out,
+      query,
+      key_cache,
+      value_cache,
+      scale,
+      block_tables,
+      context_lens,
+      max_context_len);
   } else {
     assert(false);
   }
@@ -863,6 +895,18 @@ void multi_query_cached_kv_attention(
       block_tables,
       context_lens,
       max_context_len);
+  } else if (block_size == 32) {
+    multi_query_cached_kv_attention_launcher<float, 32>(
+      cu_query_lens,
+      seq_prompt_mapping,
+      out,
+      query,
+      key_cache,
+      value_cache,
+      scale,
+      block_tables,
+      context_lens,
+      max_context_len);
   } else {
     assert(false);
   }
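
All four C++ hunks follow the same pattern: block_size is a runtime value, but the launcher template takes the block size as a compile-time parameter (uint16_t serving as the raw 16-bit storage for the half dtype), so each supported size needs its own explicitly instantiated branch per data type and per kernel. Below is a hypothetical Python analogue of that dispatch shape; the names mirror the hunks, but nothing here is the actual CUDA binding.

# Each (dtype, block_size) pair maps to a separately specialized launcher,
# so supporting 32 adds one branch per dtype in every kernel's dispatcher.
LAUNCHERS = {
    (dtype, bs): f'single_query_cached_kv_attention_launcher<{ctype}, {bs}>'
    for dtype, ctype in [('half', 'uint16_t'), ('float', 'float')]
    for bs in (8, 16, 32)
}

def dispatch(dtype: str, block_size: int) -> str:
    try:
        return LAUNCHERS[(dtype, block_size)]
    except KeyError:
        # mirrors the assert(false) fallthrough in the C++ dispatchers
        raise ValueError(f'unsupported dtype={dtype}, block_size={block_size}')

print(dispatch('half', 32))
# single_query_cached_kv_attention_launcher<uint16_t, 32>

The multi_query variants differ only in taking cu_query_lens and seq_prompt_mapping as extra leading arguments; their dispatch structure is identical.
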
@@ -350,7 +350,7 @@ def test_attention(seed: int) -> None:
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
     for dtype in [torch.half, torch.float]:
-        for block_size in [8, 16]:
+        for block_size in [8, 16, 32]:
             for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
                 print(f'Testing single_query_cached_kv_attention with '
                       f'dtype={dtype}, block_size={block_size}, '
@@ -368,7 +368,7 @@ def test_attention(seed: int) -> None:
     # note that the test is also more likely to fail due to the much
     # larger amount of tokens in the input may increase the variance.
     for dtype in [torch.half, torch.float]:
-        for block_size in [8, 16]:
+        for block_size in [8, 16, 32]:
             for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
                 print(f'Testing multi_query_cached_kv_attention with '
                       f'dtype={dtype}, block_size={block_size}, '
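
Both test hunks widen the same sweep. A hedged standalone sketch of the loop structure follows; the actual kernel invocations and reference checks inside test_attention are elided and replaced by a print.

import itertools
import torch

# Sweep every dtype/block_size/head_size combination the tests now cover.
for dtype, block_size, head_size in itertools.product(
        [torch.half, torch.float],
        [8, 16, 32],
        [32, 64, 80, 96, 128, 160, 192, 256]):
    print(f'Testing with dtype={dtype}, block_size={block_size}, '
          f'head_size={head_size}')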