Replace the head_mapping parameter with num_kv_heads in the attention kernel (#1997)

Co-authored-by: wangguoya <wangguoya@baidu.com>
Co-authored-by: Yang Zhao <zhaoyangstar@foxmail.com>
Authored by wbn on 2023-12-11 02:12:53 +08:00, committed by GitHub
parent 24cde76a15
commit dacaf5a400
5 changed files with 26 additions and 37 deletions
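The change is easiest to see in isolation: the head_mapping tensor that callers previously built with repeat_interleave is fully determined by num_kv_heads, so the kernel can recover the same query-head-to-KV-head assignment with an integer division. A minimal sketch of the equivalence (the head counts below are illustrative, not taken from the commit, and the device argument is dropped so it runs on CPU):

```python
import torch

num_query_heads = 12   # illustrative sizes
num_kv_heads = 4
assert num_query_heads % num_kv_heads == 0
num_queries_per_kv = num_query_heads // num_kv_heads

# Old interface: callers materialized an explicit query-head -> KV-head table.
head_mapping = torch.repeat_interleave(
    torch.arange(num_kv_heads, dtype=torch.int32), num_queries_per_kv)

# New interface: the kernel derives the same index from num_kv_heads alone.
computed = torch.arange(num_query_heads, dtype=torch.int32) // num_queries_per_kv

assert torch.equal(head_mapping, computed)
print(head_mapping.tolist())  # [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
```

Dropping the tensor removes an extra allocation on the Python side and replaces a per-head global-memory lookup in the kernel with a single integer division.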

View File

@@ -37,10 +37,6 @@ def main(
query.uniform_(-scale, scale)
assert num_query_heads % num_kv_heads == 0
num_queries_per_kv = num_query_heads // num_kv_heads
head_mapping = torch.repeat_interleave(
torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
num_queries_per_kv)
alibi_slopes = None
if use_alibi:
alibi_slopes = torch.randn(num_query_heads,
@@ -103,7 +99,7 @@ def main(
query,
key_cache,
value_cache,
head_mapping,
num_kv_heads,
scale,
block_tables,
context_lens,
@@ -120,7 +116,7 @@ def main(
query,
key_cache,
value_cache,
head_mapping,
num_kv_heads,
scale,
block_tables,
context_lens,

View File

@@ -89,7 +89,7 @@ __device__ void paged_attention_kernel(
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
const int* __restrict__ head_mapping, // [num_heads]
const int num_kv_heads, // number of KV heads
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs]
@@ -132,7 +132,8 @@ __device__ void paged_attention_kernel(
const int head_idx = blockIdx.x;
const int num_heads = gridDim.x;
const int kv_head_idx = head_mapping[head_idx];
const int num_queries_per_kv = num_heads / num_kv_heads;
const int kv_head_idx = head_idx / num_queries_per_kv;
const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
// A vector type to store a part of a key or a query.
@@ -401,7 +402,7 @@ __global__ void paged_attention_v1_kernel(
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
const int* __restrict__ head_mapping, // [num_heads]
const int num_kv_heads, // number of KV heads
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs]
@@ -412,7 +413,7 @@ __global__ void paged_attention_v1_kernel(
const int kv_head_stride) {
paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>(
/* exp_sums */ nullptr, /* max_logits */ nullptr,
out, q, k_cache, v_cache, head_mapping, scale, block_tables, context_lens,
out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens,
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride);
}
@@ -430,7 +431,7 @@ __global__ void paged_attention_v2_kernel(
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
const int* __restrict__ head_mapping, // [num_heads]
const int num_kv_heads, // number of KV heads
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs]
@@ -440,7 +441,7 @@ __global__ void paged_attention_v2_kernel(
const int kv_block_stride,
const int kv_head_stride) {
paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE>(
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, head_mapping, scale,
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
q_stride, kv_block_stride, kv_head_stride);
}
@@ -556,7 +557,7 @@ __global__ void paged_attention_v2_reduce_kernel(
query_ptr, \
key_cache_ptr, \
value_cache_ptr, \
head_mapping_ptr, \
num_kv_heads, \
scale, \
block_tables_ptr, \
context_lens_ptr, \
@@ -576,7 +577,7 @@ void paged_attention_v1_launcher(
torch::Tensor& query,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
torch::Tensor& head_mapping,
int num_kv_heads,
float scale,
torch::Tensor& block_tables,
torch::Tensor& context_lens,
@@ -602,7 +603,6 @@ void paged_attention_v1_launcher(
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
int* head_mapping_ptr = reinterpret_cast<int*>(head_mapping.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* context_lens_ptr = context_lens.data_ptr<int>();
@@ -651,7 +651,7 @@ void paged_attention_v1_launcher(
query, \
key_cache, \
value_cache, \
head_mapping, \
num_kv_heads, \
scale, \
block_tables, \
context_lens, \
@@ -681,7 +681,7 @@ void paged_attention_v1(
torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
torch::Tensor& head_mapping, // [num_heads]
int num_kv_heads, // number of KV heads
float scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& context_lens, // [num_seqs]
@@ -708,7 +708,7 @@ void paged_attention_v1(
query_ptr, \
key_cache_ptr, \
value_cache_ptr, \
head_mapping_ptr, \
num_kv_heads, \
scale, \
block_tables_ptr, \
context_lens_ptr, \
@@ -739,7 +739,7 @@ void paged_attention_v2_launcher(
torch::Tensor& query,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
torch::Tensor& head_mapping,
int num_kv_heads,
float scale,
torch::Tensor& block_tables,
torch::Tensor& context_lens,
@@ -768,7 +768,6 @@ void paged_attention_v2_launcher(
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
int* head_mapping_ptr = reinterpret_cast<int*>(head_mapping.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* context_lens_ptr = context_lens.data_ptr<int>();
@@ -823,7 +822,7 @@ void paged_attention_v2_launcher(
query, \
key_cache, \
value_cache, \
head_mapping, \
num_kv_heads, \
scale, \
block_tables, \
context_lens, \
@@ -856,7 +855,7 @@ void paged_attention_v2(
torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
torch::Tensor& head_mapping, // [num_heads]
int num_kv_heads, // number of KV heads
float scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& context_lens, // [num_seqs]
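
For context, the kernel-side edit above amounts to each query head attending against the KV head it is grouped under, rather than looking that head up in a table. A plain, unpaged PyTorch reference of that grouping (a sketch only: the function name and shapes are illustrative, and paging, block tables, and ALiBi are omitted):

```python
import torch

def grouped_query_attention(q, k, v, scale):
    # q: [num_query_heads, head_size]; k, v: [context_len, num_kv_heads, head_size]
    num_query_heads = q.shape[0]
    num_kv_heads = k.shape[1]
    num_queries_per_kv = num_query_heads // num_kv_heads
    out = torch.empty_like(q)
    for head_idx in range(num_query_heads):
        # Replaces the old head_mapping[head_idx] lookup.
        kv_head_idx = head_idx // num_queries_per_kv
        logits = scale * (k[:, kv_head_idx] @ q[head_idx])  # [context_len]
        probs = torch.softmax(logits, dim=-1)
        out[head_idx] = probs @ v[:, kv_head_idx]            # [head_size]
    return out

# Example: 12 query heads sharing 4 KV heads, 100 cached tokens, head_size 64.
q = torch.randn(12, 64)
k = torch.randn(100, 4, 64)
v = torch.randn(100, 4, 64)
out = grouped_query_attention(q, k, v, scale=64 ** -0.5)
print(out.shape)  # torch.Size([12, 64])
```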

View File

@@ -5,7 +5,7 @@ void paged_attention_v1(
torch::Tensor& query,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
torch::Tensor& head_mapping,
int num_kv_heads,
float scale,
torch::Tensor& block_tables,
torch::Tensor& context_lens,
@@ -21,7 +21,7 @@ void paged_attention_v2(
torch::Tensor& query,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
torch::Tensor& head_mapping,
int num_kv_heads,
float scale,
torch::Tensor& block_tables,
torch::Tensor& context_lens,

View File

@@ -131,9 +131,6 @@ def test_paged_attention(
assert num_query_heads % num_kv_heads == 0
num_queries_per_kv = num_query_heads // num_kv_heads
head_mapping = torch.repeat_interleave(
torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
num_queries_per_kv)
alibi_slopes = None
if use_alibi:
alibi_slopes = torch.randn(num_query_heads,
@@ -170,7 +167,7 @@ def test_paged_attention(
query,
key_cache,
value_cache,
head_mapping,
num_kv_heads,
scale,
block_tables,
context_lens,
@@ -202,7 +199,7 @@ def test_paged_attention(
query,
key_cache,
value_cache,
head_mapping,
num_kv_heads,
scale,
block_tables,
context_lens,

View File

@@ -54,9 +54,6 @@ class PagedAttention(nn.Module):
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.head_mapping = torch.repeat_interleave(
torch.arange(self.num_kv_heads, dtype=torch.int32, device="cuda"),
self.num_queries_per_kv)
if self.head_size not in _SUPPORTED_HEAD_SIZES:
raise ValueError(f"head_size ({self.head_size}) is not supported. "
@@ -77,7 +74,7 @@ class PagedAttention(nn.Module):
Args:
query: shape = [batch_size, seq_len, num_heads * head_size]
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
value: shape = [batch_size, num_kv_heads * head_size]
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
block_size, x]
value_cache: shape = [num_blocks, num_kv_heads, head_size,
@@ -172,7 +169,7 @@ class PagedAttention(nn.Module):
key_cache,
value_cache,
input_metadata,
self.head_mapping,
self.num_kv_heads,
self.scale,
self.alibi_slopes,
)
@@ -217,7 +214,7 @@ def _paged_attention(
key_cache: torch.Tensor,
value_cache: torch.Tensor,
input_metadata: InputMetadata,
head_mapping: torch.Tensor,
num_kv_heads: int,
scale: float,
alibi_slopes: Optional[torch.Tensor],
) -> torch.Tensor:
@@ -244,7 +241,7 @@ def _paged_attention(
query,
key_cache,
value_cache,
head_mapping,
num_kv_heads,
scale,
input_metadata.block_tables,
input_metadata.context_lens,
@@ -274,7 +271,7 @@ def _paged_attention(
query,
key_cache,
value_cache,
head_mapping,
num_kv_heads,
scale,
input_metadata.block_tables,
input_metadata.context_lens,