From 3123f151387d2afa49eaf3130bcee3556f2e87d2 Mon Sep 17 00:00:00 2001
From: Tao He
Date: Sat, 16 Mar 2024 11:58:10 +0800
Subject: [PATCH] Fixes the incorrect argument in the prefix-prefill test cases (#3246)

---
 tests/kernels/test_prefix_prefill.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py
index a0be658a..4d051593 100644
--- a/tests/kernels/test_prefix_prefill.py
+++ b/tests/kernels/test_prefix_prefill.py
@@ -18,7 +18,7 @@ CUDA_DEVICES = [
 
 
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
-@pytest.mark.parametrize("num_queries_per_kv", NUM_HEADS)
+@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
@@ -35,6 +35,13 @@ def test_contexted_kv_attention(
     if torch.cuda.is_available():
         torch.cuda.manual_seed(0)
     torch.set_default_device(device)
+
+    # Need this, otherwise when we capture the graph the process for GPU 1 would run on both
+    # GPU0 and GPU1 and things would hang
+    #
+    # see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523
+    torch.cuda.set_device(device)
+
     MAX_SEQ_LEN = 1024
     MAX_CTX_LEN = 1024
     BS = 10
@@ -172,5 +179,5 @@ def test_contexted_kv_attention(
     torch.cuda.synchronize()
     end_time = time.time()
     print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
-    output_ref = output_ref.squeeze(0, 2)
+    output_ref = output_ref.reshape(output.shape)
     assert torch.allclose(output_ref, output, atol=1e-6, rtol=0)
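
Note on the last hunk: the sketch below is a minimal plain-PyTorch illustration of why
reshape(output.shape) is the safer normalization once num_queries_per_kv is parametrized
correctly. Tensor.squeeze only drops size-1 dimensions, so it cannot collapse a grouped-query
dimension whose size exceeds 1, whereas reshaping to the kernel output's shape works for every
parametrization. All shape values and the assumed layout of the reference output are
hypothetical, chosen only for illustration, and are not taken from the test itself.

    import torch

    # Hypothetical sizes for a grouped-query configuration.
    num_tokens, num_kv_heads, num_queries_per_kv, head_size = 16, 4, 2, 8
    num_heads = num_kv_heads * num_queries_per_kv

    # Assumed grouped-query layout of the xformers reference output.
    output_ref = torch.randn(1, num_tokens, num_kv_heads, num_queries_per_kv,
                             head_size)
    # Layout of the kernel output that the assertion compares against.
    output = torch.randn(num_tokens, num_heads, head_size)

    # squeeze removes only size-1 dims: the kv-head dim (size 4) survives.
    print(output_ref.squeeze(0).shape)             # torch.Size([16, 4, 2, 8])
    # reshape to the kernel output's shape works for any num_queries_per_kv.
    print(output_ref.reshape(output.shape).shape)  # torch.Size([16, 8, 8])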