Support various block sizes & Change default block size to 16 (#38)

parent 84eee24e20
commit 0f4b32199e
@@ -268,6 +268,7 @@ if __name__ == '__main__':
     f'{model_name}-tp{args.tensor_parallel_size}',
     sample_dir,
     'cacheflow',
+    f'block{args.block_size}',
     f'req-rate-{args.request_rate}',
     f'seed{args.seed}',
     f'duration-{args.duration}',
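For reference, a minimal sketch of how these path components fit together once the block size is included (assuming the surrounding code joins them with os.path.join; the concrete values below are placeholders, not taken from this diff):

import os

model_name = 'opt-13b'                       # placeholder
sample_dir = 'n1-beam'                       # placeholder
class Args:                                  # stand-in for the parsed CLI namespace
    tensor_parallel_size = 1
    block_size = 16
    request_rate = 5.0
    seed = 0
    duration = 900
args = Args()

output_dir = os.path.join(
    f'{model_name}-tp{args.tensor_parallel_size}',
    sample_dir,
    'cacheflow',
    f'block{args.block_size}',               # new path component from this commit
    f'req-rate-{args.request_rate}',
    f'seed{args.seed}',
    f'duration-{args.duration}',
)
print(output_dir)  # e.g. opt-13b-tp1/n1-beam/cacheflow/block16/req-rate-5.0/seed0/duration-900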
@@ -15,9 +15,6 @@ class BlockAllocator:
         block_size: int,
         num_blocks: int,
     ) -> None:
-        if block_size not in [8, 16, 32]:
-            raise ValueError(f'Unsupported block size: {block_size}'
-                             'The block size must be one of {8, 16, 32}.')
         self.device = device
         self.block_size = block_size
         self.num_blocks = num_blocks
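With the hardcoded whitelist removed, BlockAllocator no longer rejects block sizes outside {8, 16, 32}; the accepted values are now governed by the --block-size argument further down. A rough usage sketch (the import path and the leading device parameter are assumptions inferred from `self.device = device`, not shown in this hunk):

from cacheflow.block_manager import BlockAllocator   # import path assumed

gpu_allocator = BlockAllocator(
    'cuda',             # device (assumed positional parameter)
    block_size=16,      # any size the kernels support is now accepted here
    num_blocks=1024,    # illustrative value
)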
@@ -125,7 +125,8 @@ class Scheduler:
 
         # Swap in the sequence groups in the SWAPPED state if possible.
         self.swapped = self.policy.sort_by_priority(now, self.swapped)
-        while self.swapped:
+        # FCFS
+        while self.swapped and not blocks_to_swap_out:
             seq_group = self.swapped[0]
             # If the sequence group has been preempted in this step, stop.
             if seq_group in preempted:
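The swap-in loop keeps FCFS order but now also stops as soon as the current scheduling step has queued any swap-out, presumably to avoid swapping groups in and out within the same step. A minimal sketch of the intended control flow (the surrounding allocation logic is assumed and elided):

blocks_to_swap_out = {}                   # filled earlier in the step when preempting
swapped = ['seq-group-a', 'seq-group-b']  # FCFS-ordered queue (illustrative)
preempted = set()

# FCFS
while swapped and not blocks_to_swap_out:
    seq_group = swapped[0]
    # If the sequence group has been preempted in this step, stop.
    if seq_group in preempted:
        break
    # ... check free blocks, record blocks_to_swap_in, then admit the group ...
    swapped.pop(0)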
@@ -180,9 +180,9 @@ def add_server_arguments(parser: argparse.ArgumentParser):
     parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages')
     parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1, help='number of tensor parallel replicas')
     # KV cache arguments
-    parser.add_argument('--block-size', type=int, default=8, choices=[8, 16, 32], help='token block size')
+    parser.add_argument('--block-size', type=int, default=16, choices=[1, 2, 4, 8, 16, 32, 64, 128, 256], help='token block size')
     # NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
-    parser.add_argument('--dtype', type=str, default='half', choices=['half', 'float'], help='data type')
+    parser.add_argument('--dtype', type=str, default='half', choices=['half'], help='data type')
     # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
     parser.add_argument('--seed', type=int, default=0, help='random seed')
     parser.add_argument('--swap-space', type=int, default=20, help='CPU swap space size (GiB) per GPU')
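The block-size flag now defaults to 16 and accepts powers of two from 1 through 256, while --dtype is restricted to 'half'. A self-contained argparse sketch of the new --block-size behaviour:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--block-size', type=int, default=16,
                    choices=[1, 2, 4, 8, 16, 32, 64, 128, 256],
                    help='token block size')

print(parser.parse_args([]).block_size)                       # 16 (new default)
print(parser.parse_args(['--block-size', '64']).block_size)   # 64 (newly allowed)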
@@ -11,25 +11,9 @@ void single_query_cached_kv_attention(
   int block_size,
   int max_context_len);
 
-void multi_query_cached_kv_attention(
-  torch::Tensor& cu_query_lens,
-  torch::Tensor& out,
-  torch::Tensor& query,
-  torch::Tensor& key_cache,
-  torch::Tensor& value_cache,
-  float scale,
-  torch::Tensor& block_tables,
-  torch::Tensor& context_lens,
-  int block_size,
-  int max_context_len);
-
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def(
     "single_query_cached_kv_attention",
     &single_query_cached_kv_attention,
     "Compute the attention between an input query and the cached key/value tensors");
-  m.def(
-    "multi_query_cached_kv_attention",
-    &multi_query_cached_kv_attention,
-    "Compute the attention between multiple input queries and the cached key/value tensors");
 }
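After this change the C++ extension exposes only the single-query kernel. A quick Python-side sanity check (the extension module name is an assumption; only the binding names come from this diff):

from cacheflow import attention_ops   # module name assumed

assert hasattr(attention_ops, 'single_query_cached_kv_attention')
assert not hasattr(attention_ops, 'multi_query_cached_kv_attention')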
(File diff suppressed because it is too large.)
@@ -1074,6 +1074,21 @@ inline __device__ float sum(Float8_ v)
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
+inline __device__ float dot(float a, float b)
+{
+  return a * b;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ float dot(float2 a, float2 b)
+{
+  float2 c = mul<float2, float2, float2>(a, b);
+  return c.x + c.y;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
 inline __device__ float dot(Float4_ a, Float4_ b)
 {
   float2 acc = mul<float2, float2, float2>(a.x, b.x);
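The new scalar and float2 dot overloads fill in the smaller vector widths, presumably so the attention kernel can vectorize over the wider range of block sizes. A pure-Python reference of what they compute (for inspection only; not part of the CUDA code):

def dot_scalar(a: float, b: float) -> float:
    # mirrors dot(float, float): plain product
    return a * b

def dot_float2(a, b):
    # mirrors dot(float2, float2): elementwise multiply, then horizontal add (c.x + c.y)
    return a[0] * b[0] + a[1] * b[1]

assert dot_scalar(3.0, 4.0) == 12.0
assert dot_float2((1.0, 2.0), (3.0, 4.0)) == 11.0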
@@ -1253,37 +1268,44 @@ inline __device__ float convert_to_float(uint4 u)
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-inline __device__ float cast_to_float(float u)
-{
-  return u;
-}
+// inline __device__ float cast_to_float(float u)
+// {
+//   return u;
+// }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-inline __device__ float2 cast_to_float(float2 u)
-{
-  return u;
-}
+// inline __device__ float2 cast_to_float(float2 u)
+// {
+//   return u;
+// }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-inline __device__ float4 cast_to_float(float4 u)
-{
-  return u;
-}
+// inline __device__ float4 cast_to_float(float4 u)
+// {
+//   return u;
+// }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-inline __device__ Float4_ cast_to_float(Float4_ u)
-{
-  return u;
-}
+// inline __device__ Float4_ cast_to_float(Float4_ u)
+// {
+//   return u;
+// }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-inline __device__ Float8_ cast_to_float(Float8_ u)
-{
-  return u;
-}
+// inline __device__ Float8_ cast_to_float(Float8_ u)
+// {
+//   return u;
+// }
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ float cast_to_float(uint16_t u)
+{
+  return half_to_float(u);
+}
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
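The identity cast_to_float overloads are commented out and replaced by a single uint16_t overload that reinterprets the 16-bit pattern as fp16 and widens it to fp32 via half_to_float. A pure-Python illustration of that conversion (numpy is used only to emulate the bit reinterpretation):

import numpy as np

def cast_to_float(u: int) -> float:
    # u holds the raw bits of a half-precision value, as uint16_t does in the kernel
    return float(np.uint16(u).view(np.float16))

assert cast_to_float(0x3C00) == 1.0    # 0x3C00 is the fp16 encoding of 1.0
assert cast_to_float(0xC000) == -2.0   # 0xC000 is the fp16 encoding of -2.0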