Add type check (fp16) in the forward pass
This commit is contained in:
parent
ea38d3d261
commit
c0daa62eaa
@ -130,6 +130,9 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, tot
|
||||
bool is_dropout = p_dropout > 0.0;
|
||||
Launch_params<Fused_multihead_attention_fprop_params> launch_params(dprops, stream, is_dropout, return_softmax);
|
||||
|
||||
TORCH_CHECK(qkv.dtype() == torch::kFloat16);
|
||||
TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32);
|
||||
|
||||
TORCH_CHECK(qkv.is_cuda())
|
||||
TORCH_CHECK(cu_seqlens.is_cuda())
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user