Support block size 32 (#35)

Woosuk Kwon 2023-04-09 23:07:18 -07:00 committed by GitHub
parent ee88a7e5f3
commit b9926f7f66
4 changed files with 49 additions and 5 deletions

View File

@@ -15,9 +15,9 @@ class BlockAllocator:
         block_size: int,
         num_blocks: int,
     ) -> None:
-        if block_size not in [8, 16]:
+        if block_size not in [8, 16, 32]:
             raise ValueError(f'Unsupported block size: {block_size}'
-                             'The block size must be either 8 or 16.')
+                             'The block size must be one of {8, 16, 32}.')
         self.device = device
         self.block_size = block_size
         self.num_blocks = num_blocks
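
The allocator-side change just widens the set of accepted block sizes. Below is a minimal, self-contained sketch of the same check; the class name, constructor signature, and usage are illustrative stand-ins, not the project's actual API.

# Minimal sketch of the widened validation; names and signature are
# illustrative, only the accepted sizes {8, 16, 32} mirror the diff above.
class BlockAllocatorSketch:
    SUPPORTED_BLOCK_SIZES = (8, 16, 32)

    def __init__(self, device: str, block_size: int, num_blocks: int) -> None:
        if block_size not in self.SUPPORTED_BLOCK_SIZES:
            raise ValueError(f'Unsupported block size: {block_size}. '
                             'The block size must be one of {8, 16, 32}.')
        self.device = device
        self.block_size = block_size
        self.num_blocks = num_blocks

# With this change, a 32-token block size is accepted:
allocator = BlockAllocatorSketch('cuda', block_size=32, num_blocks=1024)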

View File

@@ -174,7 +174,7 @@ def add_server_arguments(parser: argparse.ArgumentParser):
     parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages')
     parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1, help='number of tensor parallel replicas')
     # KV cache arguments
-    parser.add_argument('--block-size', type=int, default=8, choices=[8, 16], help='token block size')
+    parser.add_argument('--block-size', type=int, default=8, choices=[8, 16, 32], help='token block size')
     # NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
     parser.add_argument('--dtype', type=str, default='half', choices=['half', 'float'], help='data type')
     # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
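
The CLI change is the matching one-liner: 32 is added to the accepted choices for --block-size. A small, self-contained argparse sketch of that argument follows; the standalone parser is illustrative, and only the --block-size definition mirrors the diff.

# Standalone sketch: only the --block-size argument mirrors the diff above;
# the bare parser is illustrative, not the project's add_server_arguments().
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--block-size', type=int, default=8, choices=[8, 16, 32],
                    help='token block size')
args = parser.parse_args(['--block-size', '32'])
assert args.block_size == 32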

View File

@@ -654,6 +654,16 @@ void single_query_cached_kv_attention(
       block_tables,
       context_lens,
       max_context_len);
+  } else if (block_size == 32) {
+    single_query_cached_kv_attention_launcher<uint16_t, 32>(
+      out,
+      query,
+      key_cache,
+      value_cache,
+      scale,
+      block_tables,
+      context_lens,
+      max_context_len);
   } else {
     assert(false);
   }
@@ -679,6 +689,16 @@ void single_query_cached_kv_attention(
       block_tables,
       context_lens,
       max_context_len);
+  } else if (block_size == 32) {
+    single_query_cached_kv_attention_launcher<float, 32>(
+      out,
+      query,
+      key_cache,
+      value_cache,
+      scale,
+      block_tables,
+      context_lens,
+      max_context_len);
   } else {
     assert(false);
   }
@@ -834,6 +854,18 @@ void multi_query_cached_kv_attention(
       block_tables,
       context_lens,
       max_context_len);
+  } else if (block_size == 32) {
+    multi_query_cached_kv_attention_launcher<uint16_t, 32>(
+      cu_query_lens,
+      seq_prompt_mapping,
+      out,
+      query,
+      key_cache,
+      value_cache,
+      scale,
+      block_tables,
+      context_lens,
+      max_context_len);
   } else {
     assert(false);
   }
@@ -863,6 +895,18 @@ void multi_query_cached_kv_attention(
       block_tables,
       context_lens,
       max_context_len);
+  } else if (block_size == 32) {
+    multi_query_cached_kv_attention_launcher<float, 32>(
+      cu_query_lens,
+      seq_prompt_mapping,
+      out,
+      query,
+      key_cache,
+      value_cache,
+      scale,
+      block_tables,
+      context_lens,
+      max_context_len);
   } else {
     assert(false);
   }
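
The kernel-side additions follow from the launchers being templated on the block size: each supported size needs its own compile-time instantiation plus an explicit branch in the runtime dispatch, and any size without a branch falls through to assert(false). A rough Python analogue of that dispatch pattern is sketched below, with illustrative names only; it is not the real kernel API.

# Rough analogue of the runtime-to-compile-time dispatch in the C++ code above:
# every (dtype, block_size) pair gets its own 'specialized' launcher, and
# unsupported block sizes fail loudly. All names here are illustrative.
def make_launcher(dtype: str, block_size: int):
    def launcher(*args):
        print(f'kernel specialized for dtype={dtype}, block_size={block_size}')
    return launcher

LAUNCHERS = {
    (dtype, block_size): make_launcher(dtype, block_size)
    for dtype in ('half', 'float')
    for block_size in (8, 16, 32)  # 32 is the newly added specialization
}

def dispatch_attention(dtype: str, block_size: int, *args):
    launcher = LAUNCHERS.get((dtype, block_size))
    if launcher is None:
        raise ValueError(f'Unsupported block size: {block_size}')
    launcher(*args)

dispatch_attention('half', 32)  # routes to the block-size-32 specialization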

View File

@@ -350,7 +350,7 @@ def test_attention(seed: int) -> None:
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
     for dtype in [torch.half, torch.float]:
-        for block_size in [8, 16]:
+        for block_size in [8, 16, 32]:
             for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
                 print(f'Testing single_query_cached_kv_attention with '
                       f'dtype={dtype}, block_size={block_size}, '
@@ -368,7 +368,7 @@ def test_attention(seed: int) -> None:
     # note that the test is also more likely to fail due to the much
     # larger amount of tokens in the input may increase the variance.
     for dtype in [torch.half, torch.float]:
-        for block_size in [8, 16]:
+        for block_size in [8, 16, 32]:
             for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
                 print(f'Testing multi_query_cached_kv_attention with '
                       f'dtype={dtype}, block_size={block_size}, '