Don't support softcap and dropout at the same time
These tests are failing so I'm just disabling this case for now
This commit is contained in:
parent
81e01efd4b
commit
dca6d89da4
@ -387,6 +387,8 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
|
||||
TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
|
||||
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
|
||||
|
||||
if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
|
||||
|
||||
if (window_size_left >= seqlen_k) { window_size_left = -1; }
|
||||
if (window_size_right >= seqlen_k) { window_size_right = -1; }
|
||||
|
||||
@ -589,6 +591,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
|
||||
const int head_size_og = sizes[2];
|
||||
const int num_heads_k = paged_KV ? k.size(2) : k.size(1);
|
||||
|
||||
if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
|
||||
|
||||
const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
|
||||
const int num_blocks = !paged_KV ? 0 : k.size(0);
|
||||
const int page_block_size = !paged_KV ? 1 : k.size(1);
|
||||
|
||||
@ -73,7 +73,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
// If return_softmax, set IsEvenMNConst to false to reduce number of templates
|
||||
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates
|
||||
// If Is_local, set Is_causal to false
|
||||
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, ReturnSoftmaxConst && Is_dropout>;
|
||||
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, ReturnSoftmaxConst && Is_dropout && !Is_softcap>;
|
||||
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
|
||||
// printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
|
||||
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
|
||||
|
||||
@ -895,12 +895,14 @@ def test_flash_attn_output(
|
||||
and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
|
||||
):
|
||||
pytest.skip() # Reference implementation OOM
|
||||
if softcap > 0.0 and dropout_p > 0.0:
|
||||
pytest.skip("Softcap and dropout not supported together")
|
||||
device = "cuda"
|
||||
# set seed
|
||||
torch.random.manual_seed(0)
|
||||
batch_size = 4
|
||||
nheads = 9
|
||||
nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
|
||||
nheads = 6 if softcap == 0.0 else 4 # softcap reference impl takes more memory
|
||||
nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 2)
|
||||
assert nheads % nheads_k == 0
|
||||
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
|
||||
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
|
||||
@ -1162,12 +1164,14 @@ def test_flash_attn_varlen_output(
|
||||
and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
|
||||
):
|
||||
pytest.skip() # Reference implementation OOM
|
||||
if softcap > 0.0 and dropout_p > 0.0:
|
||||
pytest.skip("Softcap and dropout not supported together")
|
||||
device = "cuda"
|
||||
# set seed
|
||||
torch.random.manual_seed(0)
|
||||
batch_size = 4
|
||||
nheads = 9
|
||||
nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
|
||||
nheads = 6 if softcap == 0.0 else 4 # softcap reference impl takes more memory
|
||||
nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 2)
|
||||
assert nheads % nheads_k == 0
|
||||
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
|
||||
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user