Cast q.get_device() to char to avoid compiler warning (narrowing)

Tri Dao 2022-10-24 15:59:49 -07:00
parent ed553e9238
commit 97e13de2b4


@@ -253,7 +253,8 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
 bool loop = max_seqlen_k > blocksize_c;
 // Otherwise the kernel will be launched from cuda:0 device
-at::cuda::CUDAGuard device_guard{q.get_device()};
+// Cast to char to avoid compiler warning about narrowing
+at::cuda::CUDAGuard device_guard{(char)q.get_device()};
 auto opts = q.options();
@@ -412,7 +413,8 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
 bool loop = max_seqlen_k > blocksize_c;
 // Otherwise the kernel will be launched from cuda:0 device
-at::cuda::CUDAGuard device_guard{q.get_device()};
+// Cast to char to avoid compiler warning about narrowing
+at::cuda::CUDAGuard device_guard{(char)q.get_device()};
 // It's possible the softmax_lse_ from the fwd has a different length since blocksize_c could be different.
 auto softmax_lse = softmax_lse_.index({torch::indexing::Slice(), torch::indexing::Slice(), torch::indexing::Slice(torch::indexing::None, max_seqlen_q)}).contiguous();
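
For readers unfamiliar with the warning being silenced: at::Tensor::get_device() returns an int64_t, while at::cuda::CUDAGuard's constructor takes a c10::DeviceIndex (a typedef for int8_t), and brace-initialization forbids implicit narrowing conversions, hence the -Wnarrowing diagnostic. Below is a minimal self-contained sketch of the same situation; DeviceGuard and the free function get_device() are hypothetical stand-ins, not PyTorch APIs.

#include <cstdint>

// Stand-in for at::cuda::CUDAGuard: the real constructor takes a
// c10::DeviceIndex, which is a typedef for int8_t.
struct DeviceGuard {
    explicit DeviceGuard(int8_t index) : index_(index) {}
    int8_t index_;
};

// Stand-in for at::Tensor::get_device(), which returns int64_t.
int64_t get_device() { return 0; }

int main() {
    // Brace-initialization rejects implicit narrowing, so this line
    // would draw a -Wnarrowing diagnostic: int64_t -> int8_t.
    // DeviceGuard bad{get_device()};

    // An explicit cast marks the narrowing as intentional, which is
    // what this commit does. (It assumes char is a signed 8-bit type,
    // as on typical x86-64 toolchains.)
    DeviceGuard good{(char)get_device()};
    return 0;
}

A static_cast<int8_t>(...) would sidestep the platform-defined signedness of char, but the (char) cast used in the commit suffices on the toolchains where the warning appeared.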