Cast q.get_device() to char to avoid compiler warning (narrowing)
This commit is contained in:
parent
ed553e9238
commit
97e13de2b4
@ -253,7 +253,8 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
|
||||
bool loop = max_seqlen_k > blocksize_c;
|
||||
|
||||
// Otherwise the kernel will be launched from cuda:0 device
|
||||
at::cuda::CUDAGuard device_guard{q.get_device()};
|
||||
// Cast to char to avoid compiler warning about narrowing
|
||||
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
|
||||
|
||||
auto opts = q.options();
|
||||
|
||||
@ -412,7 +413,8 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
|
||||
bool loop = max_seqlen_k > blocksize_c;
|
||||
|
||||
// Otherwise the kernel will be launched from cuda:0 device
|
||||
at::cuda::CUDAGuard device_guard{q.get_device()};
|
||||
// Cast to char to avoid compiler warning about narrowing
|
||||
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
|
||||
|
||||
// It's possible the softmax_lse_ from the fwd has a different length since blocksize_c could be different.
|
||||
auto softmax_lse = softmax_lse_.index({torch::indexing::Slice(), torch::indexing::Slice(), torch::indexing::Slice(torch::indexing::None, max_seqlen_q)}).contiguous();
|
||||
|
||||
Loading…
Reference in New Issue
Block a user