From c0daa62eaafbc59a03ba470a2a05ba8d9d43b1ce Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 26 Jun 2022 11:41:30 -0700 Subject: [PATCH] Add type check (fp16) in the forward pass --- csrc/flash_attn/fmha_api.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/csrc/flash_attn/fmha_api.cpp b/csrc/flash_attn/fmha_api.cpp index 36a3692..2783cb7 100644 --- a/csrc/flash_attn/fmha_api.cpp +++ b/csrc/flash_attn/fmha_api.cpp @@ -130,6 +130,9 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, tot bool is_dropout = p_dropout > 0.0; Launch_params launch_params(dprops, stream, is_dropout, return_softmax); + TORCH_CHECK(qkv.dtype() == torch::kFloat16); + TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32); + TORCH_CHECK(qkv.is_cuda()) TORCH_CHECK(cu_seqlens.is_cuda())