From 52fb4b729be7fc35e49af12910e38d141c66834d Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Sun, 16 Oct 2022 12:51:26 -0700
Subject: [PATCH] Fix #54: set device for multi-GPU case

---
 csrc/flash_attn/fmha_api.cpp |  7 ++++
 tests/test_flash_attn.py     | 75 ++++++++++++++++++++++++++++++++++++
 2 files changed, 82 insertions(+)

diff --git a/csrc/flash_attn/fmha_api.cpp b/csrc/flash_attn/fmha_api.cpp
index b6d976a..0539a60 100644
--- a/csrc/flash_attn/fmha_api.cpp
+++ b/csrc/flash_attn/fmha_api.cpp
@@ -28,6 +28,7 @@
 
 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
 
 #include "fmha.h"
 
@@ -246,6 +247,9 @@ mha_fwd(const at::Tensor &q,         // total_q x num_heads x head_size, total_
     int max_seqlen_q = ((max_seqlen_q_ + 16 - 1) / 16) * 16;
     bool loop = max_seqlen_k > blocksize_c;
 
+    // Otherwise the kernel will be launched from cuda:0 device
+    at::cuda::CUDAGuard device_guard{q.get_device()};
+
     auto opts = q.options();
 
     // auto o = torch::empty({ total_q, num_heads, head_size }, opts);
@@ -400,6 +404,9 @@ mha_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
     int max_seqlen_q = ((max_seqlen_q_ + 16 - 1) / 16) * 16;
     bool loop = max_seqlen_k > blocksize_c;
 
+    // Otherwise the kernel will be launched from cuda:0 device
+    at::cuda::CUDAGuard device_guard{q.get_device()};
+
     // It's possible the softmax_lse_ from the fwd has a different length since blocksize_c could be different.
     auto softmax_lse = softmax_lse_.index({torch::indexing::Slice(), torch::indexing::Slice(), torch::indexing::Slice(torch::indexing::None, max_seqlen_q)}).contiguous();
 
diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py
index b0edc3a..78885c0 100644
--- a/tests/test_flash_attn.py
+++ b/tests/test_flash_attn.py
@@ -772,3 +772,78 @@ def test_flash_attn_race_condition(seqlen, d, dropout_p, causal, dtype):
     assert torch.equal(dq_unpad, dq_unpad_0)
     assert torch.equal(dk_unpad, dk_unpad_0)
     assert torch.equal(dv_unpad, dv_unpad_0)
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason='requires multiple GPUs')
+def test_flash_attn_multigpu():
+    seqlen = 256
+    d = 64
+    dropout_p = 0.0
+    causal = False
+    dtype = torch.float16
+    device = 'cuda:1'
+    torch.random.manual_seed(0)
+    batch_size = 32
+    nheads = 4
+    x = torch.randn(batch_size, seqlen, nheads * d, device=device, dtype=dtype, requires_grad=True)
+    Wqkv = torch.nn.Linear(nheads * d, 3 * nheads * d, device=device, dtype=dtype)
+
+    key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='random')
+    # key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full')
+
+    qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv(
+        x, Wqkv, nheads, key_padding_mask, key_padding_mask, qkvpacked=True
+    )
+
+    output_unpad, sm_lse, S_dmask = flash_attn_unpadded_qkvpacked_func(
+        qkv_unpad, cu_seqlens, max_seqlen, dropout_p, return_attn_probs=True, causal=causal
+    )
+    output = output_pad_fn(output_unpad)
+    S_dmask_converted = convert_flash_attn_S_to_softmax(
+        S_dmask, key_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal
+    )
+    dropout_mask = S_dmask_converted >= 0
+    attn_unnorm = S_dmask_converted.abs()
+    attn = normalize_flash_attn_S(attn_unnorm, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2],
+                                  key_padding_mask, key_padding_mask, dropout_p > 0.0, causal=causal)
+    dropout_fraction = get_dropout_fraction(dropout_mask, key_padding_mask, key_padding_mask,
+                                            causal=causal).item()
+
+    output_ref, attn_ref = attention_qkvpacked_ref(qkv, key_padding_mask, dropout_p, dropout_mask,
+                                                   causal=causal)
+    output_pt, attn_pt = attention_qkvpacked_ref(qkv, key_padding_mask, dropout_p, dropout_mask,
+                                                 causal=causal, upcast=False, reorder_ops=True)
+    print(f'Actual dropout fraction: {dropout_fraction}')
+    print(f'Output max diff: {(output - output_ref).abs().max().item()}')
+    print(f'Output mean diff: {(output - output_ref).abs().mean().item()}')
+    print(f'Pytorch max diff: {(output_pt - output_ref).abs().max().item()}')
+    print(f'Pytorch mean diff: {(output_pt - output_ref).abs().mean().item()}')
+    print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}')
+    print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}')
+
+    g = torch.randn_like(output)
+    dqkv_unpad, = torch.autograd.grad(output, qkv_unpad, g)
+    dqkv = dqkv_pad_fn(dqkv_unpad)
+    dqkv_ref, = torch.autograd.grad(output_ref, qkv, g)
+    dqkv_pt, = torch.autograd.grad(output_pt, qkv, g)
+    print(f'dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}')
+    print(f'dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}')
+    print(f'dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}')
+    print(f'dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}')
+    print(f'dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}')
+    print(f'dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}')
+    print(f'dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}')
+    print(f'dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}')
+
+    # Check that FlashAttention's numerical error is at most twice the numerical error
+    # of a Pytorch implementation.
+    assert (output - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item()
+    # assert torch.allclose(output, output_ref, rtol=rtol, atol=atol)
+    assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
+    # assert torch.allclose(attn, attn_ref, rtol=rtol, atol=atol)
+    if dropout_p == 0.0:
+        assert dropout_mask.all()
+    else:
+        assert 0.99 <= dropout_fraction / dropout_p <= 1.01
+
+    assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
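
Note appended after the patch (not part of the commit): as the inline comment says, without the guard the kernels are launched from the cuda:0 device even when the input tensors live on another GPU, which is what the new test exercises by running on 'cuda:1'. Below is a minimal standalone sketch of the same device-guard pattern in a PyTorch C++ extension; `my_op` and its body are hypothetical placeholders, and only the `at::cuda::CUDAGuard` line mirrors the change made in mha_fwd/mha_bwd above.

    // Hypothetical extension op (not part of flash-attention) showing the
    // device-guard pattern used in the patch.
    #include <torch/extension.h>
    #include <ATen/cuda/CUDAContext.h>
    #include <c10/cuda/CUDAGuard.h>

    at::Tensor my_op(const at::Tensor &q) {
        TORCH_CHECK(q.is_cuda(), "q must be a CUDA tensor");
        // Same line as in the patch: make q's device the current CUDA device for
        // this scope; the previous device is restored when the guard is destroyed.
        at::cuda::CUDAGuard device_guard{q.get_device()};
        // The current stream (and any raw kernel launch) now targets q's device
        // instead of cuda:0.
        auto stream = at::cuda::getCurrentCUDAStream();
        // Tensors built from q's options land on q's device regardless of the guard.
        at::Tensor out = torch::empty_like(q);
        // ... my_kernel<<<grid, block, 0, stream>>>(...); would go here ...
        (void)stream;  // unused in this sketch
        return out;
    }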