Add type check (fp16) in the forward pass
This commit is contained in:
parent
ea38d3d261
commit
c0daa62eaa
@ -130,6 +130,9 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, tot
|
|||||||
bool is_dropout = p_dropout > 0.0;
|
bool is_dropout = p_dropout > 0.0;
|
||||||
Launch_params<Fused_multihead_attention_fprop_params> launch_params(dprops, stream, is_dropout, return_softmax);
|
Launch_params<Fused_multihead_attention_fprop_params> launch_params(dprops, stream, is_dropout, return_softmax);
|
||||||
|
|
||||||
|
TORCH_CHECK(qkv.dtype() == torch::kFloat16);
|
||||||
|
TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32);
|
||||||
|
|
||||||
TORCH_CHECK(qkv.is_cuda())
|
TORCH_CHECK(qkv.is_cuda())
|
||||||
TORCH_CHECK(cu_seqlens.is_cuda())
|
TORCH_CHECK(cu_seqlens.is_cuda())
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user