[Bugfix][Kernel] allow non-power-of-2 for prefix prefill with alibi (#4573)
This commit is contained in:
parent
f6a593093a
commit
0f9a6e3d22
@ -1,3 +1,4 @@
|
||||
import math
|
||||
import random
|
||||
import time
|
||||
|
||||
@ -6,11 +7,12 @@ import torch
|
||||
from xformers import ops as xops
|
||||
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
|
||||
|
||||
from vllm.attention.backends.xformers import _make_alibi_bias
|
||||
from vllm.attention.ops.prefix_prefill import context_attention_fwd
|
||||
|
||||
NUM_HEADS = [64]
|
||||
NUM_QUERIES_PER_KV = [1, 8, 64]
|
||||
HEAD_SIZES = [128, 96]
|
||||
HEAD_SIZES = [128, 96, 24]
|
||||
DTYPES = [torch.float16]
|
||||
CUDA_DEVICES = [
|
||||
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
||||
@ -207,3 +209,242 @@ def test_contexted_kv_attention(
|
||||
print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
|
||||
output_ref = output_ref.reshape(output.shape)
|
||||
assert torch.allclose(output_ref, output, atol=1e-6, rtol=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_contexted_kv_attention_alibi(
|
||||
num_heads: int,
|
||||
num_queries_per_kv: int,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
) -> None:
|
||||
random.seed(0)
|
||||
torch.manual_seed(0)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(0)
|
||||
torch.set_default_device(device)
|
||||
|
||||
# Need this, otherwise when we capture the graph the process
|
||||
# for GPU 1 would run on both GPU0 and GPU1 and things would hang
|
||||
#
|
||||
# see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523
|
||||
torch.cuda.set_device(device)
|
||||
|
||||
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
|
||||
# Fork from: vllm/vllm/model_executor/models/bloom.py#L44
|
||||
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
|
||||
base = torch.tensor(
|
||||
2**(-(2**-(math.log2(closest_power_of_2) - 3))),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
|
||||
slopes = torch.pow(base, powers)
|
||||
|
||||
if closest_power_of_2 != total_num_heads:
|
||||
extra_base = torch.tensor(
|
||||
2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
num_remaining_heads = min(closest_power_of_2,
|
||||
total_num_heads - closest_power_of_2)
|
||||
extra_powers = torch.arange(start=1,
|
||||
end=1 + 2 * num_remaining_heads,
|
||||
step=2,
|
||||
dtype=torch.int32)
|
||||
slopes = torch.cat(
|
||||
[slopes, torch.pow(extra_base, extra_powers)], dim=0)
|
||||
return slopes
|
||||
|
||||
alibi_slopes = _get_alibi_slopes(num_heads).to(device)
|
||||
|
||||
MAX_SEQ_LEN = 1024
|
||||
MAX_CTX_LEN = 1024
|
||||
BS = 10
|
||||
cache_size = 640
|
||||
block_size = 32
|
||||
max_block_per_request = 64
|
||||
query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
|
||||
ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
|
||||
seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
|
||||
num_kv_heads = num_heads // num_queries_per_kv
|
||||
|
||||
num_tokens = sum(query_lens)
|
||||
query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
|
||||
query.uniform_(-1e-3, 1e-3)
|
||||
output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
|
||||
|
||||
kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype)
|
||||
kv.uniform_(-1e-3, 1e-3)
|
||||
key, value = kv.unbind(dim=1)
|
||||
|
||||
k_cache = torch.zeros(cache_size,
|
||||
block_size,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
dtype=dtype)
|
||||
v_cache = torch.zeros(cache_size,
|
||||
block_size,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
dtype=dtype)
|
||||
k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
|
||||
v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
|
||||
values = torch.arange(0, cache_size, dtype=torch.long)
|
||||
values = values[torch.randperm(cache_size)]
|
||||
block_table = values[:BS * max_block_per_request].view(
|
||||
BS, max_block_per_request)
|
||||
b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
|
||||
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
|
||||
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1],
|
||||
dtype=torch.long),
|
||||
dim=0)
|
||||
max_input_len = MAX_SEQ_LEN
|
||||
# copy kv to cache
|
||||
b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1],
|
||||
dtype=torch.long),
|
||||
dim=0)
|
||||
for i in range(BS):
|
||||
for j in range(query_lens[i]):
|
||||
k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] +
|
||||
j])
|
||||
v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] +
|
||||
b_ctx_len[i] + j])
|
||||
cur_ctx = 0
|
||||
block_id = 0
|
||||
while cur_ctx < b_ctx_len[i]:
|
||||
start_loc = b_seq_start_loc[i] + cur_ctx
|
||||
if cur_ctx + block_size > b_ctx_len[i]:
|
||||
end_loc = b_seq_start_loc[i] + b_ctx_len[i]
|
||||
else:
|
||||
end_loc = start_loc + block_size
|
||||
start_slot = block_table[i, block_id] * block_size
|
||||
end_slot = start_slot + end_loc - start_loc
|
||||
k_cache.view(-1, num_kv_heads,
|
||||
head_size)[start_slot:end_slot].copy_(
|
||||
key[start_loc:end_loc])
|
||||
v_cache.view(-1, num_kv_heads,
|
||||
head_size)[start_slot:end_slot].copy_(
|
||||
value[start_loc:end_loc])
|
||||
cur_ctx += block_size
|
||||
block_id += 1
|
||||
# transpose K_cache[num_blocks, block_size, num_kv_heads, head_size]
|
||||
# to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8]
|
||||
k_cache = k_cache.view(-1, block_size, num_kv_heads, head_size // 8,
|
||||
8).permute(0, 2, 3, 1, 4).contiguous()
|
||||
# transpose V_cache[num_blocks, block_size, num_kv_heads, head_size]
|
||||
# to V_cache[num_blocks, num_kv_heads, head_size, block_size]
|
||||
v_cache = v_cache.view(-1, block_size, num_kv_heads,
|
||||
head_size).permute(0, 2, 3, 1).contiguous()
|
||||
|
||||
# Warm up the Triton kernel by calling it once before actually measuring
|
||||
# generation time
|
||||
context_attention_fwd(query,
|
||||
k,
|
||||
v,
|
||||
output,
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_table,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
b_ctx_len,
|
||||
max_input_len,
|
||||
alibi_slopes=alibi_slopes)
|
||||
torch.cuda.synchronize()
|
||||
start_time = time.time()
|
||||
context_attention_fwd(query,
|
||||
k,
|
||||
v,
|
||||
output,
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_table,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
b_ctx_len,
|
||||
max_input_len,
|
||||
alibi_slopes=alibi_slopes)
|
||||
torch.cuda.synchronize()
|
||||
end_time = time.time()
|
||||
print(f"triton Time: {(end_time - start_time)*1000:.2f} ms")
|
||||
scale = float(1.0 / (head_size**0.5))
|
||||
|
||||
# NOTE(DefTruth): In order to reuse _make_alibi_bias function,
|
||||
# we have to pad query tensor before MQA/GQA expanding.
|
||||
if query.shape[0] != key.shape[0]:
|
||||
query_pad = torch.empty(sum(seq_lens),
|
||||
num_heads,
|
||||
head_size,
|
||||
dtype=dtype)
|
||||
query_pad.uniform_(-1e-3, 1e-3)
|
||||
seq_start = 0
|
||||
query_start = 0
|
||||
for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
|
||||
seq_end = seq_start + seq_len
|
||||
query_end = query_start + query_len
|
||||
query_pad[seq_start:seq_end, ...] = torch.cat([
|
||||
torch.zeros(
|
||||
seq_len - query_len, num_heads, head_size, dtype=dtype),
|
||||
query[query_start:query_end, ...]
|
||||
],
|
||||
dim=0)
|
||||
seq_start += seq_len
|
||||
query_start += query_len
|
||||
query = query_pad
|
||||
|
||||
if num_kv_heads != num_heads:
|
||||
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
|
||||
# project the key and value tensors to the desired number of
|
||||
# heads.
|
||||
#
|
||||
# see also: vllm/model_executor/layers/attention.py
|
||||
query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv,
|
||||
query.shape[-1])
|
||||
key = key[:, :, None, :].expand(key.shape[0], num_kv_heads,
|
||||
num_queries_per_kv, key.shape[-1])
|
||||
value = value[:, :,
|
||||
None, :].expand(value.shape[0], num_kv_heads,
|
||||
num_queries_per_kv, value.shape[-1])
|
||||
|
||||
query = query.unsqueeze(0)
|
||||
key = key.unsqueeze(0)
|
||||
value = value.unsqueeze(0)
|
||||
|
||||
attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens)
|
||||
output_ref = torch.empty_like(output)
|
||||
seq_start = 0
|
||||
query_start = 0
|
||||
start_time = time.time()
|
||||
# Attention with alibi slopes.
|
||||
# FIXME(DefTruth): Because xformers does not support dynamic sequence
|
||||
# lengths with custom attention bias, we process each prompt one by
|
||||
# one. This is inefficient, especially when we have many short prompts.
|
||||
# modified from: vllm/attention/backends/xformers.py#L343
|
||||
for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
|
||||
seq_end = seq_start + seq_len
|
||||
query_end = query_start + query_len
|
||||
out = xops.memory_efficient_attention_forward(query[:,
|
||||
seq_start:seq_end],
|
||||
key[:,
|
||||
seq_start:seq_end],
|
||||
value[:,
|
||||
seq_start:seq_end],
|
||||
attn_bias=attn_bias[i],
|
||||
p=0.0,
|
||||
scale=scale)
|
||||
out = out.view_as(query[:, seq_start:seq_end]).view(
|
||||
seq_len, num_heads, head_size)
|
||||
output_ref[query_start:query_end, ...].copy_(out[seq_len - query_len:,
|
||||
...])
|
||||
seq_start += seq_len
|
||||
query_start += query_len
|
||||
torch.cuda.synchronize()
|
||||
end_time = time.time()
|
||||
print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
|
||||
assert torch.allclose(output_ref, output, atol=1e-6, rtol=0)
|
||||
|
||||
@ -472,7 +472,8 @@ if triton.__version__ >= "2.1.0":
|
||||
stride_v_cache_bl,
|
||||
num_queries_per_kv: int,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr, # head size
|
||||
BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
# attn_bias[]
|
||||
@ -493,21 +494,24 @@ if triton.__version__ >= "2.1.0":
|
||||
|
||||
# initialize offsets
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
||||
offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
off_q = (
|
||||
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
|
||||
cur_head * stride_qh + offs_d[None, :] * stride_qd)
|
||||
|
||||
q = tl.load(
|
||||
Q + off_q,
|
||||
mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
|
||||
other=0.0)
|
||||
dim_mask = tl.where(
|
||||
tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1)
|
||||
|
||||
q = tl.load(Q + off_q,
|
||||
mask=dim_mask[None, :] &
|
||||
(offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len),
|
||||
other=0.0)
|
||||
|
||||
# # initialize pointer to m and l
|
||||
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32)
|
||||
|
||||
alibi_slope = tl.load(Alibi_slopes + cur_head)
|
||||
alibi_start_q = tl.arange(
|
||||
@ -532,8 +536,9 @@ if triton.__version__ >= "2.1.0":
|
||||
offs_d[None, :] * stride_v_cache_d +
|
||||
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
|
||||
k = tl.load(K_cache + off_k,
|
||||
mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
|
||||
other=0.0)
|
||||
mask=dim_mask[:, None] &
|
||||
((start_n + offs_n[None, :]) < cur_batch_ctx_len),
|
||||
other=0.0) # [D,N]
|
||||
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.dot(q, k)
|
||||
@ -567,7 +572,8 @@ if triton.__version__ >= "2.1.0":
|
||||
acc = acc * acc_scale[:, None]
|
||||
# update acc
|
||||
v = tl.load(V_cache + off_v,
|
||||
mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
|
||||
mask=dim_mask[None, :] &
|
||||
((start_n + offs_n[:, None]) < cur_batch_ctx_len),
|
||||
other=0.0)
|
||||
|
||||
p = p.to(v.dtype)
|
||||
@ -600,8 +606,9 @@ if triton.__version__ >= "2.1.0":
|
||||
# -- compute qk ----
|
||||
k = tl.load(k_ptrs +
|
||||
(cur_batch_in_all_start_index + start_n) * stride_kbs,
|
||||
mask=(start_n + offs_n[None, :]) <
|
||||
cur_batch_seq_len - cur_batch_ctx_len,
|
||||
mask=dim_mask[:, None] &
|
||||
((start_n + offs_n[None, :]) <
|
||||
cur_batch_seq_len - cur_batch_ctx_len),
|
||||
other=0.0)
|
||||
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
@ -637,8 +644,9 @@ if triton.__version__ >= "2.1.0":
|
||||
# update acc
|
||||
v = tl.load(v_ptrs +
|
||||
(cur_batch_in_all_start_index + start_n) * stride_vbs,
|
||||
mask=(start_n + offs_n[:, None]) <
|
||||
cur_batch_seq_len - cur_batch_ctx_len,
|
||||
mask=dim_mask[None, :] &
|
||||
((start_n + offs_n[:, None]) <
|
||||
cur_batch_seq_len - cur_batch_ctx_len),
|
||||
other=0.0)
|
||||
|
||||
p = p.to(v.dtype)
|
||||
@ -656,7 +664,8 @@ if triton.__version__ >= "2.1.0":
|
||||
out_ptrs = Out + off_o
|
||||
tl.store(out_ptrs,
|
||||
acc,
|
||||
mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
|
||||
mask=dim_mask[None, :] &
|
||||
(offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len))
|
||||
return
|
||||
|
||||
@torch.inference_mode()
|
||||
@ -690,7 +699,6 @@ if triton.__version__ >= "2.1.0":
|
||||
|
||||
num_warps = 8 if Lk <= 64 else 8
|
||||
if alibi_slopes is not None:
|
||||
assert Lk == Lk_padded
|
||||
_fwd_kernel_alibi[grid](
|
||||
q,
|
||||
k,
|
||||
@ -735,6 +743,7 @@ if triton.__version__ >= "2.1.0":
|
||||
num_queries_per_kv=num_queries_per_kv,
|
||||
BLOCK_M=BLOCK,
|
||||
BLOCK_DMODEL=Lk,
|
||||
BLOCK_DMODEL_PADDED=Lk_padded,
|
||||
BLOCK_N=BLOCK,
|
||||
num_warps=num_warps,
|
||||
num_stages=1,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user