From 75acdaa4b616c2e95c55a47d3158ceec9c72c503 Mon Sep 17 00:00:00 2001 From: Alexander Matveev <59768536+alexm-neuralmagic@users.noreply.github.com> Date: Sat, 27 Jul 2024 17:52:33 -0400 Subject: [PATCH] [Kernel] Increase precision of GPTQ/AWQ Marlin kernel (#6795) --- benchmarks/kernels/benchmark_marlin.py | 23 ++- csrc/ops.h | 3 +- csrc/quantization/gptq_marlin/gptq_marlin.cu | 150 ++++++++++++++---- tests/kernels/test_marlin_gemm.py | 13 +- vllm/_custom_ops.py | 6 +- .../layers/quantization/utils/marlin_utils.py | 17 +- 6 files changed, 168 insertions(+), 44 deletions(-) diff --git a/benchmarks/kernels/benchmark_marlin.py b/benchmarks/kernels/benchmark_marlin.py index 3da4cecd..684985b8 100644 --- a/benchmarks/kernels/benchmark_marlin.py +++ b/benchmarks/kernels/benchmark_marlin.py @@ -10,7 +10,7 @@ from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, - GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS) + MARLIN_SUPPORTED_GROUP_SIZES, MARLIN_SUPPORTED_NUM_BITS) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( MarlinWorkspace, marlin_quantize) from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import ( @@ -56,6 +56,8 @@ def bench_run(results: List[benchmark.Measurement], model: str, (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) = marlin_24_quantize(b, num_bits, group_size) + marlin_zp = torch.empty(0, dtype=torch.int, device=b.device) + # GPTQ quant (w_ref, q_w, s, g_idx, rand_perm) = quantize_weights(b, num_bits, group_size, act_order) @@ -87,6 +89,7 @@ def bench_run(results: List[benchmark.Measurement], model: str, "marlin_w_ref": marlin_w_ref, "marlin_q_w": marlin_q_w, "marlin_s": marlin_s, + "marlin_zp": marlin_zp, "marlin_g_idx": marlin_g_idx, "marlin_sort_indices": marlin_sort_indices, "marlin_rand_perm": marlin_rand_perm, @@ -125,11 +128,21 @@ def bench_run(results: List[benchmark.Measurement], model: str, results.append( benchmark.Timer( stmt= - "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, num_bits, size_m, size_n, size_k, is_k_full)", # noqa: E501 + "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, num_bits, size_m, size_n, size_k, is_k_full, False, False)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, - description="gptq_marlin_gemm", + description="gptq_marlin_gemm_fp16", + ).blocked_autorange(min_run_time=min_run_time)) + + results.append( + benchmark.Timer( + stmt= + "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, num_bits, size_m, size_n, size_k, is_k_full, False, True)", # noqa: E501 + globals=globals, + label=label, + sub_label=sub_label, + description="gptq_marlin_gemm_fp32", ).blocked_autorange(min_run_time=min_run_time)) if (num_bits in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS @@ -183,12 +196,12 @@ def main(args): ) > 0 and is_k_full not in args.limit_k_full: continue - for num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS: + for num_bits in MARLIN_SUPPORTED_NUM_BITS: if len(args.limit_num_bits ) > 0 and num_bits not in args.limit_num_bits: continue - for group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES: + for group_size in MARLIN_SUPPORTED_GROUP_SIZES: if 
len( args.limit_group_size ) > 0 and group_size not in args.limit_group_size: diff --git a/csrc/ops.h b/csrc/ops.h index 9ef1fcb4..f0758502 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -93,7 +93,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& g_idx, torch::Tensor& perm, torch::Tensor& workspace, int64_t num_bits, int64_t size_m, int64_t size_n, int64_t size_k, - bool is_k_full, bool has_zp); + bool is_k_full, bool has_zp, + bool use_fp32_reduce); torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index 122c5c16..36ae2bfa 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -59,14 +59,16 @@ __global__ void Marlin( const int4* __restrict__ A, // fp16 input matrix of shape mxk const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn int4* __restrict__ C, // fp16 output buffer of shape mxn + int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce) const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape // (k/groupsize)xn const int* __restrict__ g_idx, // int32 group indices of shape k - int num_groups, // number of scale groups per output channel - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks, // extra global storage for barrier synchronization + bool use_fp32_reduce // whether to use fp32 global reduce ) {} } // namespace gptq_marlin @@ -532,16 +534,18 @@ __global__ void Marlin( const int4* __restrict__ A, // fp16 input matrix of shape mxk const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn int4* __restrict__ C, // fp16 output buffer of shape mxn + int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce) const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape // (k/groupsize)xn const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape // (k/groupsize)x(n/pack_factor) const int* __restrict__ g_idx, // int32 group indices of shape k - int num_groups, // number of scale groups per output channel - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks, // extra global storage for barrier synchronization + bool use_fp32_reduce // whether to use fp32 global reduce ) { // Each threadblock processes one "stripe" of the B matrix with (roughly) the // same size, which might involve multiple column "slices" (of width 16 * @@ -595,6 +599,8 @@ __global__ void Marlin( int slice_idx; // index of threadblock in current slice; numbered bottom to // top + int par_id = 0; + // We can easily implement parallel problem execution by just remapping // indices and advancing global pointers if (slice_col_par >= n_tiles) { @@ -602,6 +608,7 @@ __global__ void Marlin( C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; locks += 
(slice_col_par / n_tiles) * n_tiles;
     slice_col = slice_col_par % n_tiles;
+    par_id = slice_col_par / n_tiles;
   }
 
   // Compute all information about the current slice which is required for
   // synchronization.
   auto init_slice = [&]() {
@@ -632,6 +639,7 @@ __global__ void Marlin(
       C += 16 * thread_m_blocks * prob_n / 8;
       locks += n_tiles;
       slice_col = 0;
+      par_id++;
     }
   };
   init_slice();
@@ -1321,7 +1329,7 @@ __global__ void Marlin(
   // finally have to globally reduce over the results. As the striped
   // partitioning minimizes the number of such reductions and our outputs are
   // usually rather small, we perform this reduction serially in L2 cache.
-  auto global_reduce = [&](bool first = false, bool last = false) {
+  auto global_reduce_fp16 = [&](bool first = false, bool last = false) {
     // We are very careful here to reduce directly in the output buffer to
     // maximize L2 cache utilization in this step. To do this, we write out
     // results in FP16 (but still reduce with FP32 compute).
@@ -1382,6 +1390,53 @@ __global__ void Marlin(
     }
   };
 
+  // Globally reduce over threadblocks that compute the same column block.
+  // We use a tmp C buffer to reduce in full fp32 precision.
+  auto global_reduce_fp32 = [&](bool first = false, bool last = false) {
+    constexpr int tb_m = thread_m_blocks * 16;
+    constexpr int tb_n = thread_n_blocks * 16;
+
+    constexpr int c_size = tb_m * tb_n * sizeof(float) / 16;
+
+    constexpr int active_threads = 32 * thread_n_blocks / 4;
+    bool is_th_active = threadIdx.x < active_threads;
+
+    int par_offset = c_size * n_tiles * par_id;
+    int slice_offset = c_size * slice_col;
+
+    constexpr int num_floats = thread_m_blocks * 4 * 2 * 4;
+    constexpr int th_size = num_floats * sizeof(float) / 16;
+
+    int c_cur_offset = par_offset + slice_offset;
+
+    if (!is_th_active) {
+      return;
+    }
+
+    if (!first) {
+      float* frag_c_ptr = reinterpret_cast<float*>(&frag_c);
+  #pragma unroll
+      for (int k = 0; k < th_size; k++) {
+        sh[threadIdx.x] =
+            C_tmp[c_cur_offset + active_threads * k + threadIdx.x];
+
+        float* sh_c_ptr = reinterpret_cast<float*>(&sh[threadIdx.x]);
+  #pragma unroll
+        for (int f = 0; f < 4; f++) {
+          frag_c_ptr[k * 4 + f] += sh_c_ptr[f];
+        }
+      }
+    }
+
+    if (!last) {
+      int4* frag_c_ptr = reinterpret_cast<int4*>(&frag_c);
+  #pragma unroll
+      for (int k = 0; k < th_size; k++) {
+        C_tmp[c_cur_offset + active_threads * k + threadIdx.x] = frag_c_ptr[k];
+      }
+    }
+  };
+
   // Write out the reduce final result in the correct layout. We only actually
   // reshuffle matrix fragments in this step, the reduction above is performed
   // in fragment layout.
@@ -1606,7 +1661,11 @@ __global__ void Marlin(
       if (slice_count > 1) {  // only globally reduce if there is more than one
                               // block in a slice
         barrier_acquire(&locks[slice_col], slice_idx);
-        global_reduce(slice_idx == 0, last);
+        if (use_fp32_reduce) {
+          global_reduce_fp32(slice_idx == 0, last);
+        } else {
+          global_reduce_fp16(slice_idx == 0, last);
+        }
         barrier_release(&locks[slice_col], last);
       }
       if (last)  // only the last block in a slice actually writes the result
@@ -1661,8 +1720,8 @@ __global__ void Marlin(
              THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER,  \
              HAS_ZP, GROUP_BLOCKS>                                          \
           <<<blocks, NUM_THREADS, max_shared_mem, stream>>>(                \
-              A_ptr, B_ptr, C_ptr, s_ptr, zp_ptr, g_idx_ptr, num_groups,    \
-              prob_m, prob_n, prob_k, locks);                               \
+              A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr,     \
+              num_groups, prob_m, prob_n, prob_k, locks, use_fp32_reduce);  \
   }
 
 typedef struct {
@@ -1801,6 +1860,27 @@ bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
   return true;
 }
 
+int determine_reduce_max_m(int prob_m, int max_par) {
+  constexpr int tile_m_size = 16;
+
+  if (prob_m <= tile_m_size) {
+    return tile_m_size;
+
+  } else if (prob_m <= tile_m_size * 2) {
+    return tile_m_size * 2;
+
+  } else if (prob_m <= tile_m_size * 3) {
+    return tile_m_size * 3;
+
+  } else if (prob_m <= tile_m_size * 4) {
+    return tile_m_size * 4;
+
+  } else {
+    int cur_par = min(div_ceil(prob_m, tile_m_size * 4), max_par);
+    return tile_m_size * 4 * cur_par;
+  }
+}
+
 exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
                                       int num_bits, int group_size,
                                       bool has_act_order, bool is_k_full,
@@ -1880,13 +1960,13 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
   __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)
 
 template <typename scalar_t>
-void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, void* zp,
-                     void* g_idx, void* perm, void* a_tmp, int prob_m,
-                     int prob_n, int prob_k, void* workspace, int num_bits,
-                     bool has_act_order, bool is_k_full, bool has_zp,
-                     int num_groups, int group_size, int dev,
+void marlin_mm_f16i4(const void* A, const void* B, void* C, void* C_tmp,
+                     void* s, void* zp, void* g_idx, void* perm, void* a_tmp,
+                     int prob_m, int prob_n, int prob_k, void* workspace,
+                     int num_bits, bool has_act_order, bool is_k_full,
+                     bool has_zp, int num_groups, int group_size, int dev,
                      cudaStream_t stream, int thread_k, int thread_n, int sms,
-                     int max_par) {
+                     int max_par, bool use_fp32_reduce) {
   TORCH_CHECK(num_bits == 4 || num_bits == 8,
               "num_bits must be 4 or 8. Got = ", num_bits);
   TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
@@ -1970,6 +2050,7 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, void* zp,
   const int4* A_ptr = (const int4*)A;
   const int4* B_ptr = (const int4*)B;
   int4* C_ptr = (int4*)C;
+  int4* C_tmp_ptr = (int4*)C_tmp;
   const int4* s_ptr = (const int4*)s;
   const int4* zp_ptr = (const int4*)zp;
   const int* g_idx_ptr = (const int*)g_idx;
@@ -2049,7 +2130,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                torch::Tensor& g_idx, torch::Tensor& perm,
                                torch::Tensor& workspace, int64_t num_bits,
                                int64_t size_m, int64_t size_n, int64_t size_k,
-                               bool is_k_full, bool has_zp) {
+                               bool is_k_full, bool has_zp,
+                               bool use_fp32_reduce) {
   // Verify num_bits
   TORCH_CHECK(num_bits == 4 || num_bits == 8,
               "num_bits must be 4 or 8. Got = ", num_bits);
@@ -2099,6 +2181,17 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   torch::Tensor c = torch::empty({size_m, size_n}, options);
   torch::Tensor a_tmp = torch::empty({size_m, size_k}, options);
 
+  // Alloc C tmp buffer that is going to be used for the global reduce
+  int reduce_max_m = marlin::determine_reduce_max_m(size_m, marlin::max_par);
+  int reduce_n = size_n;
+  auto options_fp32 =
+      torch::TensorOptions().dtype(at::kFloat).device(a.device());
+  if (!use_fp32_reduce) {
+    reduce_max_m = 0;
+    reduce_n = 0;
+  }
+  torch::Tensor c_tmp = torch::empty({reduce_max_m, reduce_n}, options_fp32);
+
   // thread_k: `k` size of a thread_tile in `weights` (can usually be left as
   // auto -1)
   int thread_k = -1;
@@ -2171,20 +2264,21 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   if (a.scalar_type() == at::ScalarType::Half) {
     marlin::marlin_mm_f16i4<half>(
         a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
-        b_scales.data_ptr<at::Half>(), b_zeros.data_ptr(), g_idx.data_ptr(),
-        perm.data_ptr(), a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
+        c_tmp.data_ptr<float>(), b_scales.data_ptr<at::Half>(),
+        b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
+        a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
         workspace.data_ptr(), num_bits, has_act_order, is_k_full, has_zp,
         num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
-        thread_k, thread_n, sms, marlin::max_par);
+        thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce);
   } else if (a.scalar_type() == at::ScalarType::BFloat16) {
     marlin::marlin_mm_f16i4<nv_bfloat16>(
         a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
-        c.data_ptr<at::BFloat16>(), b_scales.data_ptr<at::BFloat16>(),
-        b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
-        a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k,
+        c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
+        b_scales.data_ptr<at::BFloat16>(), b_zeros.data_ptr(), g_idx.data_ptr(),
+        perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k,
         workspace.data_ptr(), num_bits, has_act_order, is_k_full, has_zp,
         num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
-        thread_k, thread_n, sms, marlin::max_par);
+        thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce);
   } else {
     TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16");
   }
diff --git a/tests/kernels/test_marlin_gemm.py b/tests/kernels/test_marlin_gemm.py
index 42087fdc..bd35ef2e 100644
--- a/tests/kernels/test_marlin_gemm.py
+++ b/tests/kernels/test_marlin_gemm.py
@@ -27,6 +27,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
 
 ACT_ORDER_OPTS = [False, True]
 K_FULL_OPTS = [False, True]
+USE_FP32_REDUCE_OPTS = [False, True]
 
 MARLIN_K_CHUNKS = [128]
 MARLIN_N_CHUNKS = [64, 128, 256]
@@ -175,6 +176,7 @@ def test_awq_marlin_repack(k_chunk, n_chunk, num_bits, group_size,
 @pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
 @pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
 @pytest.mark.parametrize("is_k_full", K_FULL_OPTS)
+@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
 def test_gptq_marlin_gemm(
     k_chunk,
     n_chunk,
@@ -183,6 +185,7 @@ def test_gptq_marlin_gemm(
     mnk_factors,
     act_order,
     is_k_full,
+    use_fp32_reduce,
 ):
     m_factor, n_factor, k_factor = mnk_factors
 
@@ -222,8 +225,9 @@ def test_gptq_marlin_gemm(
         a_input.shape[0],
         b_weight.shape[1],
         a_input.shape[1],
-        is_k_full,
+        is_k_full=is_k_full,
         has_zp=False,
+        use_fp32_reduce=use_fp32_reduce,
     )
     output_ref = torch.matmul(a_input, w_ref)
 
@@ -365,12 +369,14 @@ def test_fp8_marlin_gemm(
 @pytest.mark.parametrize("num_bits", MARLIN_SUPPORTED_NUM_BITS)
 @pytest.mark.parametrize("group_size",
MARLIN_SUPPORTED_GROUP_SIZES) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS) +@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS) def test_awq_marlin_gemm( k_chunk, n_chunk, num_bits, group_size, mnk_factors, + use_fp32_reduce, ): m_factor, n_factor, k_factor = mnk_factors @@ -407,8 +413,9 @@ def test_awq_marlin_gemm( a_input.shape[0], b_weight.shape[1], a_input.shape[1], - is_k_full, - has_zp, + is_k_full=is_k_full, + has_zp=has_zp, + use_fp32_reduce=use_fp32_reduce, ) output_ref = torch.matmul(a_input, w_ref) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 01865946..ad9f01be 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -286,12 +286,12 @@ def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, b_scales: torch.Tensor, b_zeros: torch.Tensor, g_idx: torch.Tensor, perm: torch.Tensor, workspace: torch.Tensor, num_bits: int, size_m: int, - size_n: int, size_k: int, is_k_full: bool, - has_zp: bool) -> torch.Tensor: + size_n: int, size_k: int, is_k_full: bool, has_zp: bool, + use_fp32_reduce: bool) -> torch.Tensor: return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros, g_idx, perm, workspace, num_bits, size_m, size_n, size_k, is_k_full, - has_zp) + has_zp, use_fp32_reduce) # fp8 marlin diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 25a7cd7b..b789ca20 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -16,6 +16,11 @@ GPTQ_MARLIN_MAX_PARALLEL = 16 MARLIN_SUPPORTED_NUM_BITS = [4, 8] MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] +# In case there is a performance issue with Marlin, the variable below can be +# changed to False, which allows Marlin to perform global reductions in fp16 +# precision (instead of fp32), and therefore, save on some memory movements. +USE_FP32_REDUCE_DEFAULT = True + def _check_marlin_supported(num_bits: int, group_size: int, is_sym: bool, min_capability: Optional[int], @@ -244,7 +249,8 @@ def apply_gptq_marlin_linear( output_size_per_partition: int, input_size_per_partition: int, is_k_full: bool, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: reshaped_x = input.reshape(-1, input.shape[-1]) out_shape = input.shape[:-1] + (output_size_per_partition, ) @@ -260,7 +266,8 @@ def apply_gptq_marlin_linear( size_n=output_size_per_partition, size_k=input_size_per_partition, is_k_full=is_k_full, - has_zp=False) + has_zp=False, + use_fp32_reduce=use_fp32_reduce) if bias is not None: output.add_(bias) # In-place add @@ -279,7 +286,8 @@ def apply_awq_marlin_linear( num_bits: int, output_size_per_partition: int, input_size_per_partition: int, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: reshaped_x = input.reshape(-1, input.shape[-1]) out_shape = input.shape[:-1] + (output_size_per_partition, ) @@ -295,7 +303,8 @@ def apply_awq_marlin_linear( size_n=output_size_per_partition, size_k=input_size_per_partition, is_k_full=True, - has_zp=True) + has_zp=True, + use_fp32_reduce=use_fp32_reduce) if bias is not None: output.add_(bias) # In-place add
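Usage note (illustrative, not part of the patch): after this change, Python callers pass the extra use_fp32_reduce flag through vllm._custom_ops.gptq_marlin_gemm, and the higher-level helpers in marlin_utils.py forward USE_FP32_REDUCE_DEFAULT (True) by default. A minimal sketch of a direct call, assuming the quantized tensors and workspace have already been prepared as in tests/kernels/test_marlin_gemm.py (tensor names below are placeholders):

    import torch
    from vllm import _custom_ops as ops

    # a: (size_m, size_k) fp16/bf16 activations; marlin_q_w, marlin_s, marlin_zp,
    # g_idx, sort_indices and marlin_workspace prepared via the marlin_utils helpers.
    output = ops.gptq_marlin_gemm(
        a, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices,
        marlin_workspace.scratch, num_bits,
        size_m=a.shape[0], size_n=marlin_s.shape[1], size_k=a.shape[1],
        is_k_full=True,
        has_zp=False,            # GPTQ path: symmetric quantization, no zero-points
        use_fp32_reduce=True,    # new flag: accumulate the global reduce in fp32
    )

Setting use_fp32_reduce=False restores the previous fp16 global reduce and skips allocating the temporary fp32 C buffer (reduce_max_m and reduce_n are set to 0), trading some output precision for fewer memory movements.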