[Kernel] Fix deprecation function warnings squeezellm quant_cuda_kernel (#6901)

This commit is contained in:
Tyler Michael Smith 2024-07-29 12:59:02 -04:00 committed by GitHub
parent db9e5708a9
commit 60d1c6e584
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -197,13 +197,13 @@ void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads, 0, stream>>>(
#ifndef USE_ROCM
(half2*)vec.data<at::Half>(),
(half2*)vec.data_ptr<at::Half>(),
#else
(__half2*)vec.data_ptr<at::Half>(),
#endif
mat.data_ptr<int>(),
#ifndef USE_ROCM
(half2*)mul.data<at::Half>(), (__half*)lookup_table.data<at::Half>(),
(half2*)mul.data<at::Half>(), (__half*)lookup_table.data_ptr<at::Half>(),
#else
(float2*)mul.data_ptr<float>(),
(__half*)lookup_table.data_ptr<at::Half>(),