Fix #54: set device for multi-GPU case
commit 52fb4b729b
parent 1b9facacc3
@@ -28,6 +28,7 @@
 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
 
 #include "fmha.h"
 
@@ -246,6 +247,9 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
     int max_seqlen_q = ((max_seqlen_q_ + 16 - 1) / 16) * 16;
     bool loop = max_seqlen_k > blocksize_c;
 
+    // Otherwise the kernel will be launched from cuda:0 device
+    at::cuda::CUDAGuard device_guard{q.get_device()};
+
     auto opts = q.options();
 
     // auto o = torch::empty({ total_q, num_heads, head_size }, opts);
@@ -400,6 +404,9 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
     int max_seqlen_q = ((max_seqlen_q_ + 16 - 1) / 16) * 16;
     bool loop = max_seqlen_k > blocksize_c;
 
+    // Otherwise the kernel will be launched from cuda:0 device
+    at::cuda::CUDAGuard device_guard{q.get_device()};
+
     // It's possible the softmax_lse_ from the fwd has a different length since blocksize_c could be different.
     auto softmax_lse = softmax_lse_.index({torch::indexing::Slice(), torch::indexing::Slice(), torch::indexing::Slice(torch::indexing::None, max_seqlen_q)}).contiguous();
 
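For context, a minimal sketch of the pattern both hunks above apply, under the assumption that an extension entry point (here the hypothetical scale_on_own_device) receives tensors that may live on any GPU. at::cuda::CUDAGuard is an RAII guard: its constructor records the current CUDA device and switches to the requested index, and its destructor restores the previous device when the guard leaves scope, so kernel launches and allocations inside the function target the device that owns q rather than the default cuda:0.

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

// Hypothetical entry point, for illustration only; not part of this commit.
at::Tensor scale_on_own_device(const at::Tensor &q) {
    // Same guard as the fix: scope all CUDA work below to q's device.
    at::cuda::CUDAGuard device_guard{q.get_device()};
    // The current stream now belongs to q's device, so a custom kernel
    // launch issued on it would run on the correct GPU.
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    (void)stream;  // a real extension would pass this to its kernel launch
    // Allocations made inside the scope also land on q's device.
    auto out = torch::empty_like(q);
    out.copy_(q * 2.0f);
    return out;
}  // guard destroyed: the previously current device is restored

In the commit itself the guard sits at the top of mha_fwd and mha_bwd, before any tensors are created or kernels launched, which is what lets the multi-GPU test below run everything on cuda:1.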
@@ -772,3 +772,78 @@ def test_flash_attn_race_condition(seqlen, d, dropout_p, causal, dtype):
     assert torch.equal(dq_unpad, dq_unpad_0)
     assert torch.equal(dk_unpad, dk_unpad_0)
     assert torch.equal(dv_unpad, dv_unpad_0)
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason='requires multiple GPUs')
+def test_flash_attn_multigpu():
+    seqlen = 256
+    d = 64
+    dropout_p = 0.0
+    causal = False
+    dtype = torch.float16
+    device = 'cuda:1'
+    torch.random.manual_seed(0)
+    batch_size = 32
+    nheads = 4
+    x = torch.randn(batch_size, seqlen, nheads * d, device=device, dtype=dtype, requires_grad=True)
+    Wqkv = torch.nn.Linear(nheads * d, 3 * nheads * d, device=device, dtype=dtype)
+
+    key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='random')
+    # key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full')
+
+    qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv(
+        x, Wqkv, nheads, key_padding_mask, key_padding_mask, qkvpacked=True
+    )
+
+    output_unpad, sm_lse, S_dmask = flash_attn_unpadded_qkvpacked_func(
+        qkv_unpad, cu_seqlens, max_seqlen, dropout_p, return_attn_probs=True, causal=causal
+    )
+    output = output_pad_fn(output_unpad)
+    S_dmask_converted = convert_flash_attn_S_to_softmax(
+        S_dmask, key_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal
+    )
+    dropout_mask = S_dmask_converted >= 0
+    attn_unnorm = S_dmask_converted.abs()
+    attn = normalize_flash_attn_S(attn_unnorm, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2],
+                                  key_padding_mask, key_padding_mask, dropout_p > 0.0, causal=causal)
+    dropout_fraction = get_dropout_fraction(dropout_mask, key_padding_mask, key_padding_mask,
+                                            causal=causal).item()
+
+    output_ref, attn_ref = attention_qkvpacked_ref(qkv, key_padding_mask, dropout_p, dropout_mask,
+                                                   causal=causal)
+    output_pt, attn_pt = attention_qkvpacked_ref(qkv, key_padding_mask, dropout_p, dropout_mask,
+                                                 causal=causal, upcast=False, reorder_ops=True)
+    print(f'Actual dropout fraction: {dropout_fraction}')
+    print(f'Output max diff: {(output - output_ref).abs().max().item()}')
+    print(f'Output mean diff: {(output - output_ref).abs().mean().item()}')
+    print(f'Pytorch max diff: {(output_pt - output_ref).abs().max().item()}')
+    print(f'Pytorch mean diff: {(output_pt - output_ref).abs().mean().item()}')
+    print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}')
+    print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}')
+
+    g = torch.randn_like(output)
+    dqkv_unpad, = torch.autograd.grad(output, qkv_unpad, g)
+    dqkv = dqkv_pad_fn(dqkv_unpad)
+    dqkv_ref, = torch.autograd.grad(output_ref, qkv, g)
+    dqkv_pt, = torch.autograd.grad(output_pt, qkv, g)
+    print(f'dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}')
+    print(f'dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}')
+    print(f'dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}')
+    print(f'dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}')
+    print(f'dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}')
+    print(f'dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}')
+    print(f'dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}')
+    print(f'dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}')
+
+    # Check that FlashAttention's numerical error is at most twice the numerical error
+    # of a Pytorch implementation.
+    assert (output - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item()
+    # assert torch.allclose(output, output_ref, rtol=rtol, atol=atol)
+    assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
+    # assert torch.allclose(attn, attn_ref, rtol=rtol, atol=atol)
+    if dropout_p == 0.0:
+        assert dropout_mask.all()
+    else:
+        assert 0.99 <= dropout_fraction / dropout_p <= 1.01
+
+    assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()