diff --git a/csrc/quantization/awq/gemm_kernels.cu b/csrc/quantization/awq/gemm_kernels.cu index 3c5d08a1..63542990 100644 --- a/csrc/quantization/awq/gemm_kernels.cu +++ b/csrc/quantization/awq/gemm_kernels.cu @@ -90,7 +90,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i + (((int)threadIdx.x) % (128 / 8)) * 8; half* C_ptr = C - + blockIdx_z * M * OC // blockIdz.x -> split_k dim + + static_cast(blockIdx_z) * M * OC // blockIdz.x -> split_k dim + (((int)blockIdx_y) % j_factors1) * 128 + ((int)threadIdx.y) * 64 + (((int)threadIdx.x) % 4) * 2; @@ -323,7 +323,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in + (((int)threadIdx.x) % (64 / 8)) * 8; half* C_ptr = C - + blockIdx_z * M * OC // blockIdz.x -> split_k dim + + static_cast(blockIdx_z) * M * OC // blockIdz.x -> split_k dim + (((int)blockIdx_y) % j_factors1) * 64 + ((int)threadIdx.y) * 32 + (((int)threadIdx.x) % 4) * 2;