diff --git a/csrc/flash_attn/fmha_api.cpp b/csrc/flash_attn/fmha_api.cpp
index 3083cdf..fc7c105 100644
--- a/csrc/flash_attn/fmha_api.cpp
+++ b/csrc/flash_attn/fmha_api.cpp
@@ -253,7 +253,8 @@ mha_fwd(const at::Tensor &q,         // total_q x num_heads x head_size, total_q
     bool loop = max_seqlen_k > blocksize_c;
 
     // Otherwise the kernel will be launched from cuda:0 device
-    at::cuda::CUDAGuard device_guard{q.get_device()};
+    // Cast to char to avoid compiler warning about narrowing
+    at::cuda::CUDAGuard device_guard{(char)q.get_device()};
 
     auto opts = q.options();
 
@@ -412,7 +413,8 @@ mha_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
     bool loop = max_seqlen_k > blocksize_c;
 
     // Otherwise the kernel will be launched from cuda:0 device
-    at::cuda::CUDAGuard device_guard{q.get_device()};
+    // Cast to char to avoid compiler warning about narrowing
+    at::cuda::CUDAGuard device_guard{(char)q.get_device()};
 
     // It's possible the softmax_lse_ from the fwd has a different length since blocksize_c could be different.
     auto softmax_lse = softmax_lse_.index({torch::indexing::Slice(), torch::indexing::Slice(), torch::indexing::Slice(torch::indexing::None, max_seqlen_q)}).contiguous();
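
For context, here is a minimal sketch of why the brace-initialization triggered a narrowing warning and what the cast changes. The guard_on_device wrapper below is a hypothetical illustration, not code from this repository, and the assumption that c10::DeviceIndex is a narrow (int8_t) integer reflects the PyTorch headers this patch targets; check your own build.

#include <ATen/ATen.h>
#include <c10/cuda/CUDAGuard.h>

// Hypothetical helper mirroring the patched lines; fmha_api.cpp does the same
// thing inline through the at::cuda::CUDAGuard alias.
void guard_on_device(const at::Tensor &q) {
    // Tensor::get_device() returns a wider integer (int64_t in the PyTorch
    // versions this patch targets), while the CUDAGuard constructor takes a
    // c10::DeviceIndex (assumed int8_t here). List-initialization rejects the
    // implicit wide-to-narrow conversion:
    //     c10::cuda::CUDAGuard bad{q.get_device()};   // narrowing warning/error
    //
    // The explicit cast makes the narrowing intentional. The guard then switches
    // the current CUDA device to q's device for the rest of this scope and
    // restores the previous device when it is destroyed, so kernels launch on
    // the device that actually holds q instead of cuda:0.
    c10::cuda::CUDAGuard device_guard{(char)q.get_device()};

    // ... kernel launches for q would go here ...
}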