Fix #54: set device for multi-GPU case

Tri Dao 2022-10-16 12:51:26 -07:00
parent 1b9facacc3
commit 52fb4b729b
2 changed files with 82 additions and 0 deletions


@@ -28,6 +28,7 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "fmha.h"
@@ -246,6 +247,9 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
int max_seqlen_q = ((max_seqlen_q_ + 16 - 1) / 16) * 16;
bool loop = max_seqlen_k > blocksize_c;
// Otherwise the kernel will be launched from cuda:0 device
at::cuda::CUDAGuard device_guard{q.get_device()};
auto opts = q.options();
// auto o = torch::empty({ total_q, num_heads, head_size }, opts);
@@ -400,6 +404,9 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
int max_seqlen_q = ((max_seqlen_q_ + 16 - 1) / 16) * 16;
bool loop = max_seqlen_k > blocksize_c;
// Otherwise the kernel will be launched from cuda:0 device
at::cuda::CUDAGuard device_guard{q.get_device()};
// It's possible the softmax_lse_ from the fwd has a different length since blocksize_c could be different.
auto softmax_lse = softmax_lse_.index({torch::indexing::Slice(), torch::indexing::Slice(), torch::indexing::Slice(torch::indexing::None, max_seqlen_q)}).contiguous();
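
The guard added in both mha_fwd and mha_bwd is a scope-based (RAII) device switch from c10: constructing at::cuda::CUDAGuard with q's device index makes that GPU the current device for the rest of the function, so the kernel launches and getCurrentCUDAStream() calls after it target the device that actually holds the inputs rather than the default cuda:0, and the previous device is restored when the guard goes out of scope. A minimal sketch of the same pattern in an isolated extension function follows; scale_inplace and scale_kernel are hypothetical illustrations, not part of this commit:

// Hypothetical example of the CUDAGuard pattern; not part of this commit.
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

__global__ void scale_kernel(float *x, float alpha, int64_t n) {
    int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
    if (i < n) { x[i] *= alpha; }
}

void scale_inplace(at::Tensor x, float alpha) {
    TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor");
    // Without the guard, the launch below would go to the current device
    // (cuda:0 by default) even when x lives on cuda:1.
    at::cuda::CUDAGuard device_guard{x.get_device()};
    // getCurrentCUDAStream() now returns a stream on x's device.
    auto stream = at::cuda::getCurrentCUDAStream();
    int64_t n = x.numel();
    int threads = 256;
    int blocks = (int)((n + threads - 1) / threads);
    scale_kernel<<<blocks, threads, 0, stream>>>(x.data_ptr<float>(), alpha, n);
}   // device_guard's destructor restores the previous current device here.

The test added in the second file exercises this path end to end by running the packed-QKV forward and backward on cuda:1, and is skipped unless at least two GPUs are visible.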


@@ -772,3 +772,78 @@ def test_flash_attn_race_condition(seqlen, d, dropout_p, causal, dtype):
    assert torch.equal(dq_unpad, dq_unpad_0)
    assert torch.equal(dk_unpad, dk_unpad_0)
    assert torch.equal(dv_unpad, dv_unpad_0)


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason='requires multiple GPUs')
def test_flash_attn_multigpu():
    seqlen = 256
    d = 64
    dropout_p = 0.0
    causal = False
    dtype = torch.float16
    device = 'cuda:1'
    torch.random.manual_seed(0)
    batch_size = 32
    nheads = 4
    x = torch.randn(batch_size, seqlen, nheads * d, device=device, dtype=dtype, requires_grad=True)
    Wqkv = torch.nn.Linear(nheads * d, 3 * nheads * d, device=device, dtype=dtype)
    key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='random')
    # key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full')
    qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv(
        x, Wqkv, nheads, key_padding_mask, key_padding_mask, qkvpacked=True
    )
    output_unpad, sm_lse, S_dmask = flash_attn_unpadded_qkvpacked_func(
        qkv_unpad, cu_seqlens, max_seqlen, dropout_p, return_attn_probs=True, causal=causal
    )
    output = output_pad_fn(output_unpad)
    S_dmask_converted = convert_flash_attn_S_to_softmax(
        S_dmask, key_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal
    )
    dropout_mask = S_dmask_converted >= 0
    attn_unnorm = S_dmask_converted.abs()
    attn = normalize_flash_attn_S(attn_unnorm, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2],
                                  key_padding_mask, key_padding_mask, dropout_p > 0.0, causal=causal)
    dropout_fraction = get_dropout_fraction(dropout_mask, key_padding_mask, key_padding_mask,
                                            causal=causal).item()
    output_ref, attn_ref = attention_qkvpacked_ref(qkv, key_padding_mask, dropout_p, dropout_mask,
                                                   causal=causal)
    output_pt, attn_pt = attention_qkvpacked_ref(qkv, key_padding_mask, dropout_p, dropout_mask,
                                                 causal=causal, upcast=False, reorder_ops=True)
    print(f'Actual dropout fraction: {dropout_fraction}')
    print(f'Output max diff: {(output - output_ref).abs().max().item()}')
    print(f'Output mean diff: {(output - output_ref).abs().mean().item()}')
    print(f'Pytorch max diff: {(output_pt - output_ref).abs().max().item()}')
    print(f'Pytorch mean diff: {(output_pt - output_ref).abs().mean().item()}')
    print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}')
    print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}')
    g = torch.randn_like(output)
    dqkv_unpad, = torch.autograd.grad(output, qkv_unpad, g)
    dqkv = dqkv_pad_fn(dqkv_unpad)
    dqkv_ref, = torch.autograd.grad(output_ref, qkv, g)
    dqkv_pt, = torch.autograd.grad(output_pt, qkv, g)
    print(f'dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}')
    print(f'dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}')
    print(f'dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}')
    print(f'dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}')
    print(f'dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}')
    print(f'dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}')
    print(f'dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}')
    print(f'dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}')
    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (output - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item()
    # assert torch.allclose(output, output_ref, rtol=rtol, atol=atol)
    assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
    # assert torch.allclose(attn, attn_ref, rtol=rtol, atol=atol)
    if dropout_p == 0.0:
        assert dropout_mask.all()
    else:
        assert 0.99 <= dropout_fraction / dropout_p <= 1.01
    assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()