Enable GQA support in the prefix prefill kernels (#3007)
Signed-off-by: Tao He <sighingnow@gmail.com>
parent 8b430d7dea
commit 71bcaf99e2
@@ -8,7 +8,8 @@ from vllm.model_executor.layers.triton_kernel.prefix_prefill import (
 from xformers import ops as xops
 from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
 
-NUM_HEADS = [12]
+NUM_HEADS = [64]
+NUM_QUERIES_PER_KV = [1, 8, 64]
 HEAD_SIZES = [128]
 DTYPES = [torch.float16]
 CUDA_DEVICES = [
@@ -17,12 +18,14 @@ CUDA_DEVICES = [
 
 
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("num_queries_per_kv", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
 def test_contexted_kv_attention(
     num_heads: int,
+    num_queries_per_kv: int,
     head_size: int,
     dtype: torch.dtype,
     device: str,
@@ -41,28 +44,29 @@ def test_contexted_kv_attention(
     subquery_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
     ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
     seq_lens = [a + b for a, b in zip(subquery_lens, ctx_lens)]
+    num_kv_heads = num_heads // num_queries_per_kv
 
     num_tokens = sum(subquery_lens)
     query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
     query.uniform_(-1e-3, 1e-3)
     output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
 
-    kv = torch.empty(sum(seq_lens), 2, num_heads, head_size, dtype=dtype)
+    kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype)
     kv.uniform_(-1e-3, 1e-3)
     key, value = kv.unbind(dim=1)
 
     k_cache = torch.zeros(cache_size,
                           block_size,
-                          num_heads,
+                          num_kv_heads,
                           head_size,
                           dtype=dtype)
     v_cache = torch.zeros(cache_size,
                           block_size,
-                          num_heads,
+                          num_kv_heads,
                           head_size,
                           dtype=dtype)
-    k = torch.zeros(sum(subquery_lens), num_heads, head_size, dtype=dtype)
-    v = torch.zeros(sum(subquery_lens), num_heads, head_size, dtype=dtype)
+    k = torch.zeros(sum(subquery_lens), num_kv_heads, head_size, dtype=dtype)
+    v = torch.zeros(sum(subquery_lens), num_kv_heads, head_size, dtype=dtype)
     values = torch.arange(0, cache_size, dtype=torch.long)
     values = values[torch.randperm(cache_size)]
     block_table = values[:BS * max_block_per_request].view(
@@ -93,19 +97,21 @@ def test_contexted_kv_attention(
             end_loc = start_loc + block_size
             start_slot = block_table[i, block_id] * block_size
             end_slot = start_slot + end_loc - start_loc
-            k_cache.view(-1, num_heads, head_size)[start_slot:end_slot].copy_(
-                key[start_loc:end_loc])
-            v_cache.view(-1, num_heads, head_size)[start_slot:end_slot].copy_(
-                value[start_loc:end_loc])
+            k_cache.view(-1, num_kv_heads,
+                         head_size)[start_slot:end_slot].copy_(
+                             key[start_loc:end_loc])
+            v_cache.view(-1, num_kv_heads,
+                         head_size)[start_slot:end_slot].copy_(
+                             value[start_loc:end_loc])
             cur_ctx += block_size
             block_id += 1
     # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size]
     # to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8]
-    k_cache = k_cache.view(-1, block_size, num_heads, head_size // 8,
+    k_cache = k_cache.view(-1, block_size, num_kv_heads, head_size // 8,
                            8).permute(0, 2, 3, 1, 4).contiguous()
     # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size]
     # to V_cache[num_blocks, num_kv_heads, head_size, block_size]
-    v_cache = v_cache.view(-1, block_size, num_heads,
+    v_cache = v_cache.view(-1, block_size, num_kv_heads,
                            head_size).permute(0, 2, 3, 1).contiguous()
 
     # Warm up the Triton kernel by calling it once before actually measuring generation time
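
Note: the two comments in the hunk above describe the paged-cache layout the Triton kernel reads. A minimal standalone sketch of the same packing, with example sizes that are not taken from the diff:

import torch

# Example sizes only.
num_blocks, block_size, num_kv_heads, head_size = 4, 16, 2, 128
k_cache = torch.randn(num_blocks, block_size, num_kv_heads, head_size)
v_cache = torch.randn(num_blocks, block_size, num_kv_heads, head_size)

# K: split head_size into groups of 8 and move block_size next to last,
# giving [num_blocks, num_kv_heads, head_size // 8, block_size, 8].
k_packed = k_cache.view(num_blocks, block_size, num_kv_heads, head_size // 8,
                        8).permute(0, 2, 3, 1, 4).contiguous()
# V: keep head_size whole and move block_size to the last dimension,
# giving [num_blocks, num_kv_heads, head_size, block_size].
v_packed = v_cache.permute(0, 2, 3, 1).contiguous()

assert k_packed.shape == (num_blocks, num_kv_heads, head_size // 8,
                          block_size, 8)
assert v_packed.shape == (num_blocks, num_kv_heads, head_size, block_size)
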
@@ -123,12 +129,29 @@ def test_contexted_kv_attention(
 
     attn_op = xops.fmha.cutlass.FwOp()
 
+    if num_kv_heads != num_heads:
+        # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
+        # project the key and value tensors to the desired number of
+        # heads.
+        #
+        # see also: vllm/model_executor/layers/attention.py
+        query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv,
+                           query.shape[-1])
+        key = key[:, :, None, :].expand(key.shape[0], num_kv_heads,
+                                        num_queries_per_kv, key.shape[-1])
+        value = value[:, :,
+                      None, :].expand(value.shape[0], num_kv_heads,
+                                      num_queries_per_kv, value.shape[-1])
+    query = query.unsqueeze(0)
+    key = key.unsqueeze(0)
+    value = value.unsqueeze(0)
+
     attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens(
         subquery_lens, seq_lens)
     output_ref = xops.memory_efficient_attention_forward(
-        query.unsqueeze(0),
-        key.unsqueeze(0),
-        value.unsqueeze(0),
+        query,
+        key,
+        value,
         attn_bias=attn_bias,
         p=0.0,
         scale=scale,
@@ -137,9 +160,9 @@ def test_contexted_kv_attention(
     torch.cuda.synchronize()
     start_time = time.time()
     output_ref = xops.memory_efficient_attention_forward(
-        query.unsqueeze(0),
-        key.unsqueeze(0),
-        value.unsqueeze(0),
+        query,
+        key,
+        value,
         attn_bias=attn_bias,
         p=0.0,
         scale=scale,
@@ -148,5 +171,5 @@ def test_contexted_kv_attention(
     torch.cuda.synchronize()
     end_time = time.time()
     print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
-    output_ref = output_ref.squeeze(0)
+    output_ref = output_ref.squeeze(0, 2)
     assert torch.allclose(output_ref, output, atol=1e-6, rtol=0)
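
The hunks above build the xformers reference for GQA by grouping query heads and broadcasting each KV head across its group, as the added in-code comment explains; the next hunk is apparently in vllm/model_executor/layers/attention.py, the file named by that "see also" comment. As a hedged standalone sketch (toy sizes and helper names are hypothetical, not from the commit), this is why an expanded-KV multi-head reference matches grouped-query attention exactly:

import torch

torch.manual_seed(0)
T, G, Hq, K = 5, 2, 4, 8   # tokens, kv heads, query heads per kv head, head dim
scale = K ** -0.5
q = torch.randn(T, G * Hq, K)
k = torch.randn(T, G, K)
v = torch.randn(T, G, K)

def gqa_reference(q, k, v):
    # Each query head h attends with kv head h // Hq (non-causal, for brevity).
    out = torch.empty_like(q)
    for h in range(q.shape[1]):
        g = h // Hq
        attn = torch.softmax((q[:, h] @ k[:, g].T) * scale, dim=-1)
        out[:, h] = attn @ v[:, g]
    return out

def mha(q, k, v):
    # Plain multi-head attention over per-head K/V (what xformers computes).
    attn = torch.softmax(torch.einsum("thd,shd->hts", q, k) * scale, dim=-1)
    return torch.einsum("hts,shd->thd", attn, v)

# The expand trick from the test: broadcast each kv head across its group.
k_exp = k[:, :, None, :].expand(T, G, Hq, K).reshape(T, G * Hq, K)
v_exp = v[:, :, None, :].expand(T, G, Hq, K).reshape(T, G * Hq, K)
assert torch.allclose(gqa_reference(q, k, v), mha(q, k_exp, v_exp), atol=1e-6)
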
@@ -137,25 +137,27 @@ class PagedAttention(nn.Module):
             )
 
         if input_metadata.is_prompt:
-            # Prompt run.
-            if self.num_kv_heads != self.num_heads:
-                # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
-                # project the key and value tensors to the desired number of
-                # heads.
-                # TODO(woosuk): Use MQA/GQA kernels for higher performance.
-                query = query.view(query.shape[0], self.num_kv_heads,
-                                   self.num_queries_per_kv, query.shape[-1])
-                key = key[:, :,
-                          None, :].expand(key.shape[0], self.num_kv_heads,
-                                          self.num_queries_per_kv,
-                                          key.shape[-1])
-                value = value[:, :, None, :].expand(value.shape[0],
-                                                    self.num_kv_heads,
-                                                    self.num_queries_per_kv,
-                                                    value.shape[-1])
             # normal attention
             if (key_cache is None or value_cache is None
                     or input_metadata.block_tables.numel() == 0):
+                if self.num_kv_heads != self.num_heads:
+                    # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
+                    # project the key and value tensors to the desired number of
+                    # heads.
+                    # TODO(woosuk): Use MQA/GQA kernels for higher performance.
+                    query = query.view(query.shape[0], self.num_kv_heads,
+                                       self.num_queries_per_kv,
+                                       query.shape[-1])
+                    key = key[:, :,
+                              None, :].expand(key.shape[0], self.num_kv_heads,
+                                              self.num_queries_per_kv,
+                                              key.shape[-1])
+                    value = value[:, :,
+                                  None, :].expand(value.shape[0],
+                                                  self.num_kv_heads,
+                                                  self.num_queries_per_kv,
+                                                  value.shape[-1])
+
             # Set attention bias if not provided. This typically happens at
             # the very attention layer of every iteration.
             # FIXME(woosuk): This is a hack.
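
With this hunk the host-side K/V head expansion survives only on the plain xformers path (no usable key/value cache or empty block tables); once prefix blocks are present, the Triton prefix-prefill kernel in the remaining hunks (the module imported at the top of the test) consumes the original KV heads and maps query heads onto KV heads itself. A small illustrative helper for that branching, hypothetical and simplified to the block-table check only:

import torch


def needs_host_side_expansion(num_heads: int, num_kv_heads: int,
                              block_tables: torch.Tensor) -> bool:
    # Hypothetical helper: expand K/V heads only when falling back to the
    # xformers path and the model has fewer KV heads than query heads.
    no_cached_prefix = block_tables.numel() == 0
    return no_cached_prefix and num_kv_heads != num_heads


# GQA model, prefix blocks present: the Triton prefix-prefill kernel runs,
# so no expansion is needed.
assert not needs_host_side_expansion(64, 8, torch.tensor([[0, 1]]))
# GQA model, no block tables: the xformers path still needs the expansion.
assert needs_host_side_expansion(64, 8, torch.empty(0, dtype=torch.long))
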
@@ -45,6 +45,7 @@ if triton.__version__ >= "2.1.0":
         stride_v_cache_h,
         stride_v_cache_d,
         stride_v_cache_bl,
+        num_queries_per_kv: int,
         BLOCK_M: tl.constexpr,
         BLOCK_DMODEL: tl.constexpr,
         BLOCK_N: tl.constexpr,
@@ -53,6 +54,8 @@ if triton.__version__ >= "2.1.0":
         cur_head = tl.program_id(1)
         start_m = tl.program_id(2)
 
+        cur_kv_head = cur_head // num_queries_per_kv
+
         cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
         cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
         cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
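
Note: the grid still assigns one program per query head; the added cur_kv_head = cur_head // num_queries_per_kv simply selects the KV head shared by that head's group. A tiny sketch of the mapping, with hypothetical sizes:

num_heads = 64          # query heads; grid axis 1 launches one program per head
num_queries_per_kv = 8  # hypothetical GQA group size
num_kv_heads = num_heads // num_queries_per_kv

# Query heads 0-7 read KV head 0, heads 8-15 read KV head 1, and so on.
kv_head_of = {cur_head: cur_head // num_queries_per_kv
              for cur_head in range(num_heads)}
assert kv_head_of[0] == 0 and kv_head_of[7] == 0 and kv_head_of[8] == 1
assert max(kv_head_of.values()) == num_kv_heads - 1
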
@@ -85,13 +88,14 @@ if triton.__version__ >= "2.1.0":
                          mask=(start_n + offs_n) < cur_batch_ctx_len,
                          other=0)
             off_k = (bn[None, :] * stride_k_cache_bs +
-                     cur_head * stride_k_cache_h +
+                     cur_kv_head * stride_k_cache_h +
                      (offs_d[:, None] // x) * stride_k_cache_d +
                      ((start_n + offs_n[None, :]) % block_size) *
                      stride_k_cache_bl +
                      (offs_d[:, None] % x) * stride_k_cache_x)
             off_v = (
-                bn[:, None] * stride_v_cache_bs + cur_head * stride_v_cache_h +
+                bn[:, None] * stride_v_cache_bs +
+                cur_kv_head * stride_v_cache_h +
                 offs_d[None, :] * stride_v_cache_d +
                 (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
             k = tl.load(K_cache + off_k,
@@ -131,9 +135,9 @@ if triton.__version__ >= "2.1.0":
             l_i = l_i_new
             m_i = m_i_new
 
-        off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh +
+        off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
                  offs_d[:, None] * stride_kd)
-        off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh +
+        off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
                  offs_d[None, :] * stride_vd)
         k_ptrs = K + off_k
         v_ptrs = V + off_v
@@ -232,6 +236,7 @@ if triton.__version__ >= "2.1.0":
         stride_v_cache_h,
         stride_v_cache_d,
         stride_v_cache_bl,
+        num_queries_per_kv: int,
         BLOCK_M: tl.constexpr,
         BLOCK_DMODEL: tl.constexpr,
         BLOCK_N: tl.constexpr,
@@ -240,6 +245,8 @@ if triton.__version__ >= "2.1.0":
         cur_head = tl.program_id(1)
         start_m = tl.program_id(2)
 
+        cur_kv_head = cur_head // num_queries_per_kv
+
         cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
         cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
         cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
@@ -272,13 +279,14 @@ if triton.__version__ >= "2.1.0":
                          mask=(start_n + offs_n) < cur_batch_ctx_len,
                          other=0)
             off_k = (bn[None, :] * stride_k_cache_bs +
-                     cur_head * stride_k_cache_h +
+                     cur_kv_head * stride_k_cache_h +
                      (offs_d[:, None] // x) * stride_k_cache_d +
                      ((start_n + offs_n[None, :]) % block_size) *
                      stride_k_cache_bl +
                      (offs_d[:, None] % x) * stride_k_cache_x)
             off_v = (
-                bn[:, None] * stride_v_cache_bs + cur_head * stride_v_cache_h +
+                bn[:, None] * stride_v_cache_bs +
+                cur_kv_head * stride_v_cache_h +
                 offs_d[None, :] * stride_v_cache_d +
                 (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
             k = tl.load(K_cache + off_k,
@@ -317,9 +325,9 @@ if triton.__version__ >= "2.1.0":
             l_i = l_i_new
             m_i = m_i_new
 
-        off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh +
+        off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
                  offs_d[:, None] * stride_kd)
-        off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh +
+        off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
                  offs_d[None, :] * stride_vd)
         k_ptrs = K + off_k
         v_ptrs = V + off_v
@@ -420,6 +428,7 @@ if triton.__version__ >= "2.1.0":
         stride_v_cache_h,
         stride_v_cache_d,
         stride_v_cache_bl,
+        num_queries_per_kv: int,
         BLOCK_M: tl.constexpr,
         BLOCK_DMODEL: tl.constexpr,
         BLOCK_N: tl.constexpr,
@@ -429,6 +438,8 @@ if triton.__version__ >= "2.1.0":
         cur_head = tl.program_id(1)
         start_m = tl.program_id(2)
 
+        cur_kv_head = cur_head // num_queries_per_kv
+
         # cur_batch_seq_len: the length of prompts
         # cur_batch_ctx_len: the length of prefix
         # cur_batch_in_all_start_index: the start id of the dim=0
@@ -468,13 +479,14 @@ if triton.__version__ >= "2.1.0":
                          mask=(start_n + offs_n) < cur_batch_ctx_len,
                          other=0)
             off_k = (bn[None, :] * stride_k_cache_bs +
-                     cur_head * stride_k_cache_h +
+                     cur_kv_head * stride_k_cache_h +
                      (offs_d[:, None] // x) * stride_k_cache_d +
                      ((start_n + offs_n[None, :]) % block_size) *
                      stride_k_cache_bl +
                      (offs_d[:, None] % x) * stride_k_cache_x)
             off_v = (
-                bn[:, None] * stride_v_cache_bs + cur_head * stride_v_cache_h +
+                bn[:, None] * stride_v_cache_bs +
+                cur_kv_head * stride_v_cache_h +
                 offs_d[None, :] * stride_v_cache_d +
                 (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
             k = tl.load(K_cache + off_k,
@@ -522,9 +534,9 @@ if triton.__version__ >= "2.1.0":
             l_i = l_i_new
             m_i = m_i_new
 
-        off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh +
+        off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
                  offs_d[:, None] * stride_kd)
-        off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh +
+        off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
                  offs_d[None, :] * stride_vd)
         k_ptrs = K + off_k
         v_ptrs = V + off_v
@@ -628,6 +640,7 @@ if triton.__version__ >= "2.1.0":
 
         sm_scale = 1.0 / (Lq**0.5)
         batch, head = b_seq_len.shape[0], q.shape[1]
+        num_queries_per_kv = q.shape[1] // k.shape[1]
 
         grid = (batch, head, triton.cdiv(max_input_len, BLOCK))  # batch, head,
 
@@ -674,6 +687,7 @@ if triton.__version__ >= "2.1.0":
             v_cache.stride(2),
             v_cache.stride(
                 3),  #[num_blocks, num_kv_heads, head_size, block_size]
+            num_queries_per_kv=num_queries_per_kv,
             BLOCK_M=BLOCK,
             BLOCK_DMODEL=Lk,
             BLOCK_N=BLOCK,
@@ -721,6 +735,7 @@ if triton.__version__ >= "2.1.0":
             v_cache.stride(2),
             v_cache.stride(
                 3),  #[num_blocks, num_kv_heads, head_size, block_size]
+            num_queries_per_kv=num_queries_per_kv,
             BLOCK_M=BLOCK,
             BLOCK_DMODEL=Lk,
             BLOCK_N=BLOCK,
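
Note: on the host side, the launcher derives the group size from the tensor shapes and keeps the same (batch, head, blocks) grid; only the new num_queries_per_kv argument tells each program how to fold its query head onto a KV head. A rough shape-level sketch with hypothetical sizes (cdiv written out instead of triton.cdiv):

# Hypothetical shapes: q is [num_tokens, num_query_heads, head_size],
# k and v are [num_tokens, num_kv_heads, head_size].
num_query_heads, num_kv_heads = 64, 8
batch, max_input_len, BLOCK = 4, 512, 128

head = num_query_heads                                # q.shape[1]
num_queries_per_kv = num_query_heads // num_kv_heads  # q.shape[1] // k.shape[1]

# One program per (sequence, query head, query block); KV heads are not a
# grid axis, so GQA changes the kernel arguments but not the launch shape.
grid = (batch, head, (max_input_len + BLOCK - 1) // BLOCK)
assert grid == (4, 64, 4)
assert num_queries_per_kv == 8
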