diff --git a/csrc/flash_attn/fmha_api.cpp b/csrc/flash_attn/fmha_api.cpp index 36a3692..2783cb7 100644 --- a/csrc/flash_attn/fmha_api.cpp +++ b/csrc/flash_attn/fmha_api.cpp @@ -130,6 +130,9 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, tot bool is_dropout = p_dropout > 0.0; Launch_params launch_params(dprops, stream, is_dropout, return_softmax); + TORCH_CHECK(qkv.dtype() == torch::kFloat16, "FlashAttention only supports fp16 qkv"); + TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32, "cu_seqlens must be an int32 tensor"); + TORCH_CHECK(qkv.is_cuda()) TORCH_CHECK(cu_seqlens.is_cuda())