diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp index 9352d6c..5d7296f 100644 --- a/csrc/flash_attn/flash_api.cpp +++ b/csrc/flash_attn/flash_api.cpp @@ -387,6 +387,8 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); } + if (window_size_left >= seqlen_k) { window_size_left = -1; } if (window_size_right >= seqlen_k) { window_size_right = -1; } @@ -589,6 +591,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s const int head_size_og = sizes[2]; const int num_heads_k = paged_KV ? k.size(2) : k.size(1); + if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); } + const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1); const int num_blocks = !paged_KV ? 0 : k.size(0); const int page_block_size = !paged_KV ? 1 : k.size(1); diff --git a/csrc/flash_attn/src/flash_fwd_launch_template.h b/csrc/flash_attn/src/flash_fwd_launch_template.h index f5495ca..eb8bcea 100644 --- a/csrc/flash_attn/src/flash_fwd_launch_template.h +++ b/csrc/flash_attn/src/flash_fwd_launch_template.h @@ -73,7 +73,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { // If return_softmax, set IsEvenMNConst to false to reduce number of templates // If head dim > 128, set IsEvenMNConst to false to reduce number of templates // If Is_local, set Is_causal to false - auto kernel = &flash_fwd_kernel; + auto kernel = &flash_fwd_kernel; // auto kernel = &flash_fwd_kernel; // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout)); // auto kernel = &flash_fwd_kernel; diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py index 40e31cb..800dbd8 100644 --- a/tests/test_flash_attn.py +++ b/tests/test_flash_attn.py @@ -895,12 +895,14 @@ def test_flash_attn_output( and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30 ): pytest.skip() # Reference implementation OOM + if softcap > 0.0 and dropout_p > 0.0: + pytest.skip("Softcap and dropout not supported together") device = "cuda" # set seed torch.random.manual_seed(0) batch_size = 4 - nheads = 9 - nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) + nheads = 6 if softcap == 0.0 else 4 # softcap reference impl takes more memory + nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 2) assert nheads % nheads_k == 0 window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) @@ -1162,12 +1164,14 @@ def test_flash_attn_varlen_output( and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30 ): pytest.skip() # Reference implementation OOM + if softcap > 0.0 and dropout_p > 0.0: + pytest.skip("Softcap and dropout not supported together") device = "cuda" # set seed torch.random.manual_seed(0) batch_size = 4 - nheads = 9 - nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) + nheads = 6 if softcap == 0.0 else 4 # softcap reference impl takes more memory + nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 2) assert nheads % nheads_k == 0 window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)