Don't support softcap and dropout at the same time

These tests are failing so I'm just disabling this case for now
This commit is contained in:
Tri Dao 2024-07-10 11:23:12 -07:00
parent 81e01efd4b
commit dca6d89da4
3 changed files with 13 additions and 5 deletions

View File

@ -387,6 +387,8 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
if (window_size_left >= seqlen_k) { window_size_left = -1; }
if (window_size_right >= seqlen_k) { window_size_right = -1; }
@ -589,6 +591,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
const int head_size_og = sizes[2];
const int num_heads_k = paged_KV ? k.size(2) : k.size(1);
if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
const int num_blocks = !paged_KV ? 0 : k.size(0);
const int page_block_size = !paged_KV ? 1 : k.size(1);

View File

@ -73,7 +73,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
// If return_softmax, set IsEvenMNConst to false to reduce number of templates
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates
// If Is_local, set Is_causal to false
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, ReturnSoftmaxConst && Is_dropout>;
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, ReturnSoftmaxConst && Is_dropout && !Is_softcap>;
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
// printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;

View File

@ -895,12 +895,14 @@ def test_flash_attn_output(
and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
):
pytest.skip() # Reference implementation OOM
if softcap > 0.0 and dropout_p > 0.0:
pytest.skip("Softcap and dropout not supported together")
device = "cuda"
# set seed
torch.random.manual_seed(0)
batch_size = 4
nheads = 9
nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
nheads = 6 if softcap == 0.0 else 4 # softcap reference impl takes more memory
nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 2)
assert nheads % nheads_k == 0
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
@ -1162,12 +1164,14 @@ def test_flash_attn_varlen_output(
and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
):
pytest.skip() # Reference implementation OOM
if softcap > 0.0 and dropout_p > 0.0:
pytest.skip("Softcap and dropout not supported together")
device = "cuda"
# set seed
torch.random.manual_seed(0)
batch_size = 4
nheads = 9
nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
nheads = 6 if softcap == 0.0 else 4 # softcap reference impl takes more memory
nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 2)
assert nheads % nheads_k == 0
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)