import os

import torch
from flash_attn import flash_attn_varlen_func
from torch.nn import functional as F

# Ask for CUDA device-side assertions so kernel errors are easier to debug.
os.environ["TORCH_USE_CUDA_DSA"] = "1"


def generate_data():
    nhead = 32
    head_dim = 128

    # Three sequences with different numbers of cached key/value tokens and a
    # single new query token each (i.e. one decode step per sequence).
    k_past_len = [16, 32, 64]

    k_data = []
    v_data = []
    q_data = []
    for seq_len in k_past_len:
        k_data.append(
            torch.randn(size=(seq_len, nhead, head_dim), dtype=torch.float16, device="cuda")
        )
        v_data.append(
            torch.randn(size=(seq_len, nhead, head_dim), dtype=torch.float16, device="cuda")
        )
        q_data.append(
            torch.randn(size=(1, nhead, head_dim), dtype=torch.float16, device="cuda")
        )

    # Pack the variable-length sequences along the token dimension:
    # k/v become (16 + 32 + 64, nhead, head_dim) and q becomes (3, nhead, head_dim).
    k_data = torch.concat(k_data, dim=0)
    v_data = torch.concat(v_data, dim=0)
    q_data = torch.concat(q_data, dim=0)
    torch.save(k_data, "k.pkl")
    torch.save(v_data, "v.pkl")
    torch.save(q_data, "q.pkl")


def test_fa():
    batch_size = 3
    k_past_len = [16, 32, 64]

    k_data = torch.load("./k.pkl")
    q_data = torch.load("./q.pkl")
    v_data = torch.load("./v.pkl")
    print(k_data.size(), v_data.size(), q_data.size())

    # flash_attn_varlen_func expects cumulative sequence offsets with a leading
    # zero, so key lengths [16, 32, 64] become cu_seqlens_k = [0, 16, 48, 112].
    k_seqlens = torch.tensor(k_past_len, dtype=torch.int32, device="cuda")
    cu_seqlens_k = F.pad(torch.cumsum(k_seqlens, dim=-1, dtype=torch.int32), (1, 0))
    # One query token per sequence, so the query offsets are simply [0, 1, 2, 3].
    cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device="cuda")

    torch.cuda.synchronize()
    fa_res = flash_attn_varlen_func(
        q_data,
        k_data,
        v_data,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=1,
        max_seqlen_k=max(k_past_len),
        causal=True,
    )
    torch.cuda.synchronize()
    # The output has the same packed layout as q_data: (3, nhead, head_dim),
    # one row per sequence's new token.
    torch.save(fa_res, "./fa.pkl")
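

# Hypothetical helper (not part of the original script): the cu_seqlens
# construction in test_fa() is hard-coded for three sequences, but the same
# leading-zero cumulative offsets can be built for any list of lengths.
def make_cu_seqlens(lengths):
    # e.g. [16, 32, 64] -> tensor([0, 16, 48, 112], dtype=torch.int32)
    lengths = torch.tensor(lengths, dtype=torch.int32, device="cuda")
    return F.pad(torch.cumsum(lengths, dim=0, dtype=torch.int32), (1, 0))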


def test_torch_sdpa():
    nhead = 32
    head_dim = 128
    k_past_len = [16, 32, 64]

    q_data = torch.load("./q.pkl")
    k_data = torch.load("./k.pkl")
    v_data = torch.load("./v.pkl")

    # Reference check for the first sequence only: its single query token
    # against its 16 cached key/value tokens, reshaped to SDPA's
    # (batch, nhead, seq_len, head_dim) layout.
    q = q_data[0, :, :].view(1, -1, nhead, head_dim).transpose(1, 2).contiguous()
    k = k_data[: k_past_len[0], :, :].view(1, -1, nhead, head_dim).transpose(1, 2).contiguous()
    v = v_data[: k_past_len[0], :, :].view(1, -1, nhead, head_dim).transpose(1, 2).contiguous()

    # For a single trailing query token, causal attention over the past keys is
    # just full attention, so no causal mask is needed here. (SDPA's is_causal
    # mask is top-left aligned and would let this query see only the first key,
    # unlike flash-attn's bottom-right aligned causal mask.)
    naive_res = F.scaled_dot_product_attention(q, k, v)
    print(naive_res.size())
    naive_res = naive_res.view(1, nhead, head_dim)

    fa_res = torch.load("./fa.pkl")
    # fp16 kernels will not match bitwise, so compare with loose tolerances.
    if torch.allclose(naive_res, fa_res[0, :, :], rtol=1e-3, atol=1e-3):
        print("flash-attn matches the SDPA reference")
    else:
        print("max diff is", torch.abs(naive_res - fa_res[0, :, :]).max())
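

# A minimal sketch (not part of the original script) that extends the check in
# test_torch_sdpa() to all three sequences: slice each sequence's keys/values
# out of the packed tensors, run SDPA per sequence, and compare against the
# corresponding row of the flash-attn varlen output. Assumes generate_data()
# and test_fa() have already been run.
def test_torch_sdpa_all():
    nhead = 32
    head_dim = 128
    k_past_len = [16, 32, 64]

    q_data = torch.load("./q.pkl")
    k_data = torch.load("./k.pkl")
    v_data = torch.load("./v.pkl")
    fa_res = torch.load("./fa.pkl")

    offset = 0
    for i, seq_len in enumerate(k_past_len):
        q = q_data[i].reshape(1, 1, nhead, head_dim).transpose(1, 2)
        k = k_data[offset : offset + seq_len].reshape(1, -1, nhead, head_dim).transpose(1, 2)
        v = v_data[offset : offset + seq_len].reshape(1, -1, nhead, head_dim).transpose(1, 2)
        ref = F.scaled_dot_product_attention(q, k, v).reshape(1, nhead, head_dim)
        print(f"sequence {i}: max diff", torch.abs(ref - fa_res[i : i + 1]).max())
        offset += seq_len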


# I am not sure whether this is the right way to use the varlen interface.
# Is this how chunked prefill does it?
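# Note (an assumption, not verified in this script): engines that do chunked
# prefill generally pack sequences the same way and call the varlen kernel with
# leading-zero cumulative offsets for both queries and keys; the difference is
# that each chunk contributes many query tokens per sequence, so cu_seqlens_q
# is a real cumulative sum rather than simply [0, 1, 2, 3].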


if __name__ == "__main__":
    generate_data()
    test_fa()
    test_torch_sdpa()