# torch_ext/test_fa.py
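# Sanity-check flash_attn's varlen attention against PyTorch's
# scaled_dot_product_attention for decode-style attention (one query token
# per sequence) over three packed sequences.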
import os

# TORCH_USE_CUDA_DSA only takes effect if PyTorch was built with device-side
# assertions enabled, and it must be set before CUDA is initialized.
os.environ["TORCH_USE_CUDA_DSA"] = "1"

import torch
from flash_attn import flash_attn_varlen_func
from torch.nn import functional as F

def generate_data():
    """Create packed q/k/v tensors for three sequences and save them to disk."""
    nhead = 32
    head_dim = 128
    # Three sequences with different key/value lengths; one query token per
    # sequence (decode-style attention).
    k_past_len = [16, 32, 64]
    k_data = []
    v_data = []
    q_data = []
    for seq_len in k_past_len:
        k_data.append(
            torch.randn(size=(seq_len, nhead, head_dim), dtype=torch.float16, device="cuda")
        )
        v_data.append(
            torch.randn(size=(seq_len, nhead, head_dim), dtype=torch.float16, device="cuda")
        )
        q_data.append(
            torch.randn(size=(1, nhead, head_dim), dtype=torch.float16, device="cuda")
        )
    # Pack the sequences along the token dimension: (total_tokens, nhead, head_dim).
    k_data = torch.cat(k_data, dim=0)
    v_data = torch.cat(v_data, dim=0)
    q_data = torch.cat(q_data, dim=0)
    torch.save(k_data, "k.pkl")
    torch.save(v_data, "v.pkl")
    torch.save(q_data, "q.pkl")
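
# flash_attn_varlen_func consumes q/k/v packed along the token dimension,
# shape (total_tokens, nhead, head_dim); the per-sequence boundaries are
# given separately via the cu_seqlens_q / cu_seqlens_k offsets built below.
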
def test_fa():
    k_past_len = [16, 32, 64]
    k_data = torch.load("./k.pkl")
    q_data = torch.load("./q.pkl")
    v_data = torch.load("./v.pkl")
    print(k_data.size(), v_data.size(), q_data.size())
    # cu_seqlens_* must be the exclusive prefix sums of the per-sequence
    # lengths, i.e. [0, len0, len0+len1, ...], as int32 on the GPU.
    k_lens = torch.tensor(k_past_len, dtype=torch.int32, device="cuda")
    cu_seqlens_k = torch.cat(
        (
            torch.zeros(1, dtype=torch.int32, device="cuda"),
            torch.cumsum(k_lens, dim=0, dtype=torch.int32),
        )
    )
    # One query token per sequence, so the query offsets are [0, 1, 2, 3].
    cu_seqlens_q = torch.arange(len(k_past_len) + 1, dtype=torch.int32, device="cuda")
    torch.cuda.synchronize()
    fa_res = flash_attn_varlen_func(
        q_data,
        k_data,
        v_data,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=1,
        max_seqlen_k=max(k_past_len),
        causal=True,
    )
    torch.save(fa_res, "./fa.pkl")
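
# A minimal sketch (names are illustrative, not from the original test) that
# checks all three sequences against SDPA instead of only the first one,
# assuming the files written by generate_data() and test_fa() above.
def sdpa_reference_all():
    nhead = 32
    head_dim = 128
    k_past_len = [16, 32, 64]
    q_data = torch.load("./q.pkl")
    k_data = torch.load("./k.pkl")
    v_data = torch.load("./v.pkl")
    fa_res = torch.load("./fa.pkl")
    offset = 0
    for i, seq_len in enumerate(k_past_len):
        # Slice sequence i out of the packed tensors and reshape to SDPA's
        # (batch, nhead, seq, head_dim) layout.
        q = q_data[i].view(1, 1, nhead, head_dim).transpose(1, 2)
        k = k_data[offset : offset + seq_len].view(1, seq_len, nhead, head_dim).transpose(1, 2)
        v = v_data[offset : offset + seq_len].view(1, seq_len, nhead, head_dim).transpose(1, 2)
        ref = F.scaled_dot_product_attention(q, k, v, is_causal=False)
        diff = (ref.view(1, nhead, head_dim) - fa_res[i]).abs().max()
        print(f"sequence {i}: max diff {diff}")
        offset += seq_len
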
def test_torch_sdpa():
    nhead = 32
    head_dim = 128
    q_data = torch.load("./q.pkl")
    k_data = torch.load("./k.pkl")
    v_data = torch.load("./v.pkl")
    # Check only the first sequence: its single query token and its 16
    # keys/values, reshaped to SDPA's (batch, nhead, seq, head_dim) layout.
    q = q_data[0, :, :].view(1, -1, nhead, head_dim).transpose(1, 2).contiguous()
    k = k_data[:16, :, :].view(1, -1, nhead, head_dim).transpose(1, 2).contiguous()
    v = v_data[:16, :, :].view(1, -1, nhead, head_dim).transpose(1, 2).contiguous()
    # The lone query token sits at the end of its sequence and attends to all
    # keys, so no mask is needed. SDPA's is_causal=True aligns the mask to the
    # top left and would restrict this query to key 0 only.
    naive_res = F.scaled_dot_product_attention(q, k, v, is_causal=False)
    print(naive_res.size())
    naive_res = naive_res.view(1, nhead, head_dim)
    fa_res = torch.load("./fa.pkl")
    # fp16 needs looser tolerances than torch.allclose's defaults.
    if torch.allclose(naive_res, fa_res[0, :, :], atol=1e-3, rtol=1e-3):
        print("flash_attn matches the SDPA reference")
    else:
        print("max diff is", torch.abs(naive_res - fa_res[0, :, :]).max())

# Not sure whether this comparison method is right.
# Does chunked prefill use the same approach?
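# In flash-attn >= 2.1 the causal mask is aligned to the bottom right, so a
# single-token query attends to every key; that is why is_causal=False in the
# SDPA reference above reproduces flash_attn's causal=True result here.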

if __name__ == "__main__":
    generate_data()
    test_fa()
    test_torch_sdpa()