import os

# Request device-side assertions (only takes effect in builds compiled with DSA support).
os.environ["TORCH_USE_CUDA_DSA"] = "1"

import torch
from flash_attn import flash_attn_varlen_func
from torch.nn import functional as F


def generate_data():
    nhead = 32
    head_dim = 128
    k_past_len = [16, 32, 64]  # per-sequence KV-cache lengths for 3 sequences

    k_data = []
    v_data = []
    q_data = []
    for seq_len in k_past_len:
        k_data.append(
            torch.randn(size=(seq_len, nhead, head_dim), dtype=torch.float16, device="cuda")
        )
        v_data.append(
            torch.randn(size=(seq_len, nhead, head_dim), dtype=torch.float16, device="cuda")
        )
        # one new query token per sequence (single-token decode step)
        q_data.append(
            torch.randn(size=(1, nhead, head_dim), dtype=torch.float16, device="cuda")
        )

    k_data = torch.concat(k_data, dim=0)  # (16 + 32 + 64, nhead, head_dim)
    v_data = torch.concat(v_data, dim=0)
    q_data = torch.concat(q_data, dim=0)  # (3, nhead, head_dim)

    torch.save(k_data, "k.pkl")
    torch.save(v_data, "v.pkl")
    torch.save(q_data, "q.pkl")


def test_fa():
    k_past_len = [16, 32, 64]

    k_data = torch.load("./k.pkl")
    q_data = torch.load("./q.pkl")
    v_data = torch.load("./v.pkl")
    print(k_data.size(), v_data.size(), q_data.size())

    # cu_seqlens_* are cumulative sequence lengths and must start at 0.
    k_lens = torch.tensor(k_past_len, dtype=torch.int32, device="cuda")
    cu_seqlens_k = torch.cat(
        (
            torch.zeros(1, dtype=torch.int32, device="cuda"),
            torch.cumsum(k_lens, dim=0, dtype=torch.int32),
        )
    )  # [0, 16, 48, 112]
    # one query token per sequence -> [0, 1, 2, 3]
    cu_seqlens_q = torch.arange(len(k_past_len) + 1, dtype=torch.int32, device="cuda")

    torch.cuda.synchronize()
    fa_res = flash_attn_varlen_func(
        q_data,
        k_data,
        v_data,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=1,
        max_seqlen_k=max(k_past_len),
        causal=True,
    )
    torch.cuda.synchronize()
    torch.save(fa_res, "./fa.pkl")


def test_torch_sdpa():
    nhead = 32
    head_dim = 128

    q_data = torch.load("./q.pkl")
    k_data = torch.load("./k.pkl")
    v_data = torch.load("./v.pkl")

    # Reference for the first sequence: its 1 query token against its 16 cached KV tokens.
    q = q_data[0, :, :].view(1, -1, nhead, head_dim).transpose(1, 2).contiguous()
    k = k_data[:16, :, :].view(1, -1, nhead, head_dim).transpose(1, 2).contiguous()
    v = v_data[:16, :, :].view(1, -1, nhead, head_dim).transpose(1, 2).contiguous()

    # flash-attn aligns its causal mask to the bottom-right corner, so a single
    # decode query attends to every cached key. PyTorch's is_causal=True aligns
    # the mask to the top-left corner and would keep only the first key, so the
    # equivalent SDPA reference here uses no causal mask at all.
    naive_res = F.scaled_dot_product_attention(q, k, v, is_causal=False)
    print(naive_res.size())
    naive_res = naive_res.view(1, nhead, head_dim)

    fa_res = torch.load("./fa.pkl")
    # fp16 outputs: compare with a loose tolerance.
    if torch.allclose(naive_res, fa_res[0, :, :], atol=1e-3, rtol=1e-3):
        print("results match")
    else:
        print("max diff is ", torch.abs(naive_res - fa_res[0, :, :]).max())

    # Chunked prefill appears to use the same varlen layout, just with more than
    # one query token per sequence.


if __name__ == "__main__":
    generate_data()
    test_fa()
    test_torch_sdpa()
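

# A minimal sketch (not called from __main__): reproduces flash-attn's causal
# behaviour with plain SDPA by building the mask explicitly. PyTorch's
# is_causal=True aligns the causal mask to the top-left corner, while
# flash-attn (>= 2.1) aligns it to the bottom-right corner, so for q_len < k_len
# the two disagree unless the mask is constructed by hand. The function name and
# the diagonal formula below are illustrative choices, not part of either
# library's API.
def sdpa_bottom_right_causal(q, k, v):
    # q: (batch, nhead, q_len, head_dim); k, v: (batch, nhead, k_len, head_dim)
    q_len, k_len = q.size(-2), k.size(-2)
    # Query i may attend to key j iff j <= i + (k_len - q_len), i.e. the causal
    # diagonal is shifted so it ends at the bottom-right corner of the mask.
    mask = torch.ones(q_len, k_len, dtype=torch.bool, device=q.device).tril(
        diagonal=k_len - q_len
    )
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)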