From 97e13de2b40a95ea8510766e8073a9ca6ba6d945 Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Mon, 24 Oct 2022 15:59:49 -0700
Subject: [PATCH] Cast q.get_device() to char to avoid compiler warning (narrowing)

---
 csrc/flash_attn/fmha_api.cpp | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/csrc/flash_attn/fmha_api.cpp b/csrc/flash_attn/fmha_api.cpp
index 3083cdf..fc7c105 100644
--- a/csrc/flash_attn/fmha_api.cpp
+++ b/csrc/flash_attn/fmha_api.cpp
@@ -253,7 +253,8 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
     bool loop = max_seqlen_k > blocksize_c;
 
     // Otherwise the kernel will be launched from cuda:0 device
-    at::cuda::CUDAGuard device_guard{q.get_device()};
+    // Cast to char to avoid compiler warning about narrowing
+    at::cuda::CUDAGuard device_guard{(char)q.get_device()};
 
     auto opts = q.options();
 
@@ -412,7 +413,8 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
     bool loop = max_seqlen_k > blocksize_c;
 
     // Otherwise the kernel will be launched from cuda:0 device
-    at::cuda::CUDAGuard device_guard{q.get_device()};
+    // Cast to char to avoid compiler warning about narrowing
+    at::cuda::CUDAGuard device_guard{(char)q.get_device()};
 
     // It's possible the softmax_lse_ from the fwd has a different length since blocksize_c could be different.
     auto softmax_lse = softmax_lse_.index({torch::indexing::Slice(), torch::indexing::Slice(), torch::indexing::Slice(torch::indexing::None, max_seqlen_q)}).contiguous();
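
For context on the warning being silenced: at::cuda::CUDAGuard is brace-initialized with the value returned by q.get_device(), which is a wider integer type than the c10::DeviceIndex (an int8_t in the PyTorch of this era) that the guard's constructor takes, and list-initialization flags that narrowing conversion. The sketch below reproduces the diagnostic with hypothetical stand-in types (DeviceGuard, get_device) rather than the real ATen API, so it can be compiled standalone:

```cpp
#include <cstdint>

// Hypothetical stand-ins for the PyTorch types involved (assumption:
// DeviceIndex is an int8_t, get_device() returns a wider integer).
using DeviceIndex = int8_t;

struct DeviceGuard {
    explicit DeviceGuard(DeviceIndex idx) : idx_(idx) {}
    DeviceIndex idx_;
};

int64_t get_device() { return 0; }  // stand-in for q.get_device()

int main() {
    // Braced init of an int64_t into an int8_t parameter is a narrowing
    // conversion; compilers warn (-Wnarrowing) or reject it outright.
    // DeviceGuard bad{get_device()};        // narrowing diagnostic
    DeviceGuard ok{(char)get_device()};      // explicit cast, as in the patch
    return static_cast<int>(ok.idx_);
}
```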