Add type check (fp16) in the forward pass

Tri Dao 2022-06-26 11:41:30 -07:00
parent ea38d3d261
commit c0daa62eaa


@@ -130,6 +130,9 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, tot
bool is_dropout = p_dropout > 0.0;
Launch_params<Fused_multihead_attention_fprop_params> launch_params(dprops, stream, is_dropout, return_softmax);
TORCH_CHECK(qkv.dtype() == torch::kFloat16);
TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32);
TORCH_CHECK(qkv.is_cuda());
TORCH_CHECK(cu_seqlens.is_cuda());
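
For context, a minimal standalone sketch of how these TORCH_CHECK guards behave (not part of this diff; the helper name check_fwd_inputs and the tensor shapes are made up for illustration). A failed TORCH_CHECK throws c10::Error, which PyTorch reports to Python callers as a RuntimeError, so the extension rejects non-fp16 or CPU inputs up front instead of launching the fused kernel on unsupported data.

// Illustrative sketch only; compile against libtorch with CUDA support.
#include <torch/torch.h>

// Hypothetical helper mirroring the style of the guards added to mha_fwd:
// fp16 packed qkv, int32 cumulative sequence lengths, both on a CUDA device.
void check_fwd_inputs(const at::Tensor &qkv, const at::Tensor &cu_seqlens) {
    TORCH_CHECK(qkv.dtype() == torch::kFloat16, "qkv must be fp16");
    TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32, "cu_seqlens must be int32");
    TORCH_CHECK(qkv.is_cuda(), "qkv must be a CUDA tensor");
    TORCH_CHECK(cu_seqlens.is_cuda(), "cu_seqlens must be a CUDA tensor");
}

int main() {
    if (!torch::cuda::is_available()) return 0;  // skip on CPU-only builds
    // Example shapes (total x num_heads x 3 x head_size) chosen arbitrarily.
    auto qkv = torch::randn({128, 16, 3, 64},
                            torch::dtype(torch::kFloat16).device(torch::kCUDA));
    auto cu_seqlens = torch::tensor({0, 128},
                            torch::dtype(torch::kInt32).device(torch::kCUDA));
    check_fwd_inputs(qkv, cu_seqlens);  // passes
    // check_fwd_inputs(qkv.to(torch::kFloat32), cu_seqlens);  // would throw c10::Error
    return 0;
}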