diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index 9c6bff00..fdc0ebef 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -20,6 +20,11 @@ */ #include "gptq_marlin.cuh" +#include "gptq_marlin_dtypes.cuh" + +#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) static_assert(\ + std::is_same::value || std::is_same::value, \ + "only float16 and bfloat16 is supported"); template inline std::string str(T x) { return std::to_string(x); } @@ -32,7 +37,8 @@ __global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr, int4 *__restrict__ out_int4_ptr, int size_m, int size_k, int block_rows) {} -template ; -using FragB = Vec; -using FragC = Vec; -using FragS = Vec; // quantization scales - // m16n8k16 tensor core mma instruction with fp16 inputs and fp32 // output/accumulation. -__device__ inline void mma(const FragA &a_frag, const FragB &frag_b, - FragC &frag_c) { +template +__device__ inline void mma(const typename ScalarType::FragA &a_frag, + const typename ScalarType::FragB &frag_b, + typename ScalarType::FragC &frag_c) { const uint32_t *a = reinterpret_cast(&a_frag); const uint32_t *b = reinterpret_cast(&frag_b); float *c = reinterpret_cast(&frag_c); - asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), - "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + if constexpr (std::is_same::value) { + asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), + "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), + "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); + } } // Instruction for loading a full 16x16 matrix fragment of operand A from shared // memory, directly in tensor core layout. -__device__ inline void ldsm4(FragA &frag_a, const void *smem_ptr) { +template +__device__ inline void ldsm4(typename ScalarType::FragA &frag_a, const void *smem_ptr) { uint32_t *a = reinterpret_cast(&frag_a); uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" @@ -129,8 +140,15 @@ __device__ inline uint32_t prmt(uint32_t a) { // Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 // values. 
We mostly follow the strategy in the link below, with some small // changes: -// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h -__device__ inline FragB dequant_4bit(int q) { +// - FP16: https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 +// - BF16: https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385 +template +__device__ inline typename ScalarType::FragB dequant_4bit(int q) { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); +} + +template <> +__device__ inline typename ScalarType::FragB dequant_4bit(int q) { const int LO = 0x000f000f; const int HI = 0x00f000f0; const int EX = 0x64006400; @@ -142,7 +160,7 @@ __device__ inline FragB dequant_4bit(int q) { const int SUB = 0x64086408; const int MUL = 0x2c002c00; const int ADD = 0xd480d480; - FragB frag_b; + typename ScalarType::FragB frag_b; frag_b[0] = __hsub2(*reinterpret_cast(&lo), *reinterpret_cast(&SUB)); frag_b[1] = __hfma2(*reinterpret_cast(&hi), @@ -151,7 +169,41 @@ __device__ inline FragB dequant_4bit(int q) { return frag_b; } -__device__ inline FragB dequant_8bit(int q) { +template <> +__device__ inline typename ScalarType::FragB dequant_4bit(int q) { + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t EX = 0x43004300; + + // Guarantee that the `(a & b) | c` operations are LOP3s. + + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + q >>= 4; + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + + typename ScalarType::FragB frag_b; + static constexpr uint32_t MUL = 0x3F803F80; + static constexpr uint32_t ADD = 0xC308C308; + + frag_b[0] = __hfma2(*reinterpret_cast(&lo), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + return frag_b; +} + +// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or bf16 +// Reference: +// - FP16: https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 +// - BF16: https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175 +template +__device__ inline typename ScalarType::FragB dequant_8bit(int q) { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); +} + +template <> +__device__ inline typename ScalarType::FragB dequant_8bit(int q) { static constexpr uint32_t mask_for_elt_01 = 0x5250; static constexpr uint32_t mask_for_elt_23 = 0x5351; static constexpr uint32_t start_byte_for_fp16 = 0x64646464; @@ -161,7 +213,7 @@ __device__ inline FragB dequant_8bit(int q) { static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; - FragB frag_b; + typename ScalarType::FragB frag_b; frag_b[0] = __hsub2(*reinterpret_cast(&lo), *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); frag_b[1] = __hsub2(*reinterpret_cast(&hi), @@ -169,34 +221,69 @@ __device__ inline FragB dequant_8bit(int q) { return frag_b; } +template <> +__device__ inline typename ScalarType::FragB dequant_8bit(int q) { + typename ScalarType::FragB frag_b; + + float fp32_intermediates[4]; + uint32_t * fp32_intermediates_casted = 
reinterpret_cast(fp32_intermediates); + + static constexpr uint32_t fp32_base = 0x4B000000; + fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); + fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); + + fp32_intermediates[0] -= 8388736.f; + fp32_intermediates[1] -= 8388736.f; + fp32_intermediates[2] -= 8388736.f; + fp32_intermediates[3] -= 8388736.f; + + uint32_t* bf16_result_ptr = reinterpret_cast(&frag_b); + bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632); + bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632); + + return frag_b; +} + // Multiply dequantized values by the corresponding quantization scale; used // only for grouped quantization. -__device__ inline void scale(FragB &frag_b, FragS &frag_s, int i) { - half2 s = __half2half2(reinterpret_cast<__half *>(&frag_s)[i]); +template +__device__ inline void scale(typename ScalarType::FragB &frag_b, + typename ScalarType::FragS &frag_s, int i) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 s = ScalarType::num2num2(reinterpret_cast(&frag_s)[i]); frag_b[0] = __hmul2(frag_b[0], s); frag_b[1] = __hmul2(frag_b[1], s); } // Same as above, but for act_order (each K is multiplied individually) -__device__ inline void scale4(FragB &frag_b, FragS &frag_s_1, FragS &frag_s_2, - FragS &frag_s_3, FragS &frag_s_4, int i) { - __half2 s_val_1_2; - s_val_1_2.x = reinterpret_cast<__half *>(&frag_s_1)[i]; - s_val_1_2.y = reinterpret_cast<__half *>(&frag_s_2)[i]; +template +__device__ inline void scale4(typename ScalarType::FragB &frag_b, + typename ScalarType::FragS &frag_s_1, + typename ScalarType::FragS &frag_s_2, + typename ScalarType::FragS &frag_s_3, + typename ScalarType::FragS &frag_s_4, + int i) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 s_val_1_2; + s_val_1_2.x = reinterpret_cast(&frag_s_1)[i]; + s_val_1_2.y = reinterpret_cast(&frag_s_2)[i]; - __half2 s_val_3_4; - s_val_3_4.x = reinterpret_cast<__half *>(&frag_s_3)[i]; - s_val_3_4.y = reinterpret_cast<__half *>(&frag_s_4)[i]; + scalar_t2 s_val_3_4; + s_val_3_4.x = reinterpret_cast(&frag_s_3)[i]; + s_val_3_4.y = reinterpret_cast(&frag_s_4)[i]; frag_b[0] = __hmul2(frag_b[0], s_val_1_2); frag_b[1] = __hmul2(frag_b[1], s_val_3_4); } // Given 2 floats multiply by 2 scales (halves) -__device__ inline void scale_float(float *c, FragS &s) { - __half *s_ptr = reinterpret_cast<__half *>(&s); - c[0] = __fmul_rn(c[0], __half2float(s_ptr[0])); - c[1] = __fmul_rn(c[1], __half2float(s_ptr[1])); +template +__device__ inline void scale_float(float *c, typename ScalarType::FragS &s) { + scalar_t *s_ptr = reinterpret_cast(&s); + c[0] = __fmul_rn(c[0], ScalarType::num2float(s_ptr[0])); + c[1] = __fmul_rn(c[1], ScalarType::num2float(s_ptr[1])); } // Wait until barrier reaches `count`, then lock for current threadblock. 
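
The magic constants in the dequantizers above are easiest to audit on the host. The sketch below is a plain C++ reference, not part of the kernel; it assumes the zero points of 8 (4-bit) and 128 (8-bit) that the SUB/MUL/ADD constants encode, and mirrors the arithmetic of the fp16 4-bit path, the bf16 4-bit path, and the fp32 detour used by the bf16 8-bit path.

// Host-side reference (illustration only): reproduce the arithmetic behind
// the magic constants in dequant_4bit/dequant_8bit. Assumed zero points:
// 8 for 4-bit values, 128 for 8-bit values.
#include <cassert>
#include <cstdint>

// fp16, 4-bit, low nibble: 0x6400 | n is the half 1024 + n; SUB = 0x6408 is
// 1032, so the __hsub2 yields n - 8.
float fp16_4bit_lo(uint32_t n) { return (1024.0f + n) - 1032.0f; }

// fp16, 4-bit, high nibble: the nibble sits 4 bits higher, so the packed half
// is 1024 + 16*n; MUL = 0x2c00 (1/16) and ADD = 0xd480 (-72) rescale to n - 8.
float fp16_4bit_hi(uint32_t n) { return (1024.0f + 16.0f * n) * (1.0f / 16.0f) - 72.0f; }

// bf16, 4-bit: EX = 0x4300 is the bf16 value 128, so the packed value is
// 128 + n; the FMA with MUL = 0x3F80 (1.0) and ADD = 0xC308 (-136) yields
// n - 8 (the `q >>= 4` sends the high nibble through the same path).
float bf16_4bit(uint32_t n) { return (128.0f + n) * 1.0f - 136.0f; }

// bf16, 8-bit: __byte_perm drops each byte b into the low byte of the fp32
// pattern 0x4B000000 (= 2^23 = 8388608), so the float equals 8388608 + b;
// subtracting 8388736 (= 2^23 + 128) leaves b - 128, and the final
// __byte_perm keeps the upper 16 bits of each float, i.e. truncates to bf16.
float bf16_8bit(uint32_t b) { return (8388608.0f + b) - 8388736.0f; }

int main() {
  for (uint32_t n = 0; n < 16; ++n) {
    assert(fp16_4bit_lo(n) == float(n) - 8.0f);
    assert(fp16_4bit_hi(n) == float(n) - 8.0f);
    assert(bf16_4bit(n) == float(n) - 8.0f);
  }
  for (uint32_t b = 0; b < 256; ++b)
    assert(bf16_8bit(b) == float(b) - 128.0f);
  return 0;
}
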
@@ -287,7 +374,8 @@ __global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr, } } -template ; + using scalar_t2 = typename ScalarType::scalar_t2; + using FragA = typename ScalarType::FragA; + using FragB = typename ScalarType::FragB; + using FragC = typename ScalarType::FragC; + using FragS = typename ScalarType::FragS; constexpr int pack_factor = 32 / num_bits; @@ -691,7 +785,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk int4 *sh_a_stage = sh_a + a_sh_stage * pipe; #pragma unroll for (int i = 0; i < thread_m_blocks; i++) - ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); int4 *sh_b_stage = sh_b + b_sh_stage * pipe; #pragma unroll @@ -835,43 +929,43 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk int b_quant = frag_b_quant[k % 2][0][j]; int b_quant_shift = b_quant >> 8; - frag_b0 = dequant_4bit(b_quant); - frag_b1 = dequant_4bit(b_quant_shift); + frag_b0 = dequant_4bit(b_quant); + frag_b1 = dequant_4bit(b_quant_shift); } else { int *frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); int b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; int b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; - frag_b0 = dequant_8bit(b_quant_0); - frag_b1 = dequant_8bit(b_quant_1); + frag_b0 = dequant_8bit(b_quant_0); + frag_b1 = dequant_8bit(b_quant_1); } // Apply scale to frag_b0 if constexpr (has_act_order) { - scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], + scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0); } else { if constexpr (group_blocks != -1) { - scale(frag_b0, frag_s[k % 2][j], 0); + scale(frag_b0, frag_s[k % 2][j], 0); } } // Apply scale to frag_b1 if constexpr (has_act_order) { - scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], + scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1); } else { if constexpr (group_blocks != -1) { - scale(frag_b1, frag_s[k % 2][j], 1); + scale(frag_b1, frag_s[k % 2][j], 1); } } #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { - mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); - mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); + mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); } } }; @@ -979,15 +1073,15 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk for (int j = 0; j < 2 * 4; j++) { reinterpret_cast( &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += - __half2float(reinterpret_cast<__half *>(&c_red)[j]); + Dtype::num2float(reinterpret_cast(&c_red)[j]); } } if (!last) { int4 c; #pragma unroll for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast<__half *>(&c)[j] = - __float2half(reinterpret_cast( + reinterpret_cast(&c)[j] = + Dtype::float2num(reinterpret_cast( &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); } C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = @@ -1022,7 +1116,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // We first reorder in shared memory to guarantee the most efficient final // global write patterns auto write = [&](int idx, float c0, float c1, FragS &s) { - half2 res = __halves2half2(__float2half(c0), __float2half(c1)); + scalar_t2 res = Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); // For per-column quantization we finally apply the scale here (only for // 4-bit) @@ -1030,7 
+1124,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk res = __hmul2(res, s[0]); } - ((half2 *)sh)[idx] = res; + ((scalar_t2 *)sh)[idx] = res; }; if (threadIdx.x / 32 < thread_n_blocks / 4) { @@ -1192,14 +1286,14 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk for (int i = 0; i < thread_m_blocks; i++) { #pragma unroll for (int j = 0; j < 4; j++) { - scale_float(reinterpret_cast(&frag_c[i][j][0][0]), + scale_float(reinterpret_cast(&frag_c[i][j][0][0]), frag_s[j / 2][2 * (j % 2) + 0]); - scale_float(reinterpret_cast(&frag_c[i][j][0][2]), + scale_float(reinterpret_cast(&frag_c[i][j][0][2]), frag_s[j / 2][2 * (j % 2) + 0]); - scale_float(reinterpret_cast(&frag_c[i][j][1][0]), + scale_float(reinterpret_cast(&frag_c[i][j][1][0]), frag_s[j / 2][2 * (j % 2) + 1]); - scale_float(reinterpret_cast(&frag_c[i][j][1][2]), + scale_float(reinterpret_cast(&frag_c[i][j][1][2]), frag_s[j / 2][2 * (j % 2) + 1]); } } @@ -1255,10 +1349,10 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \ num_threads == NUM_THREADS) { \ cudaFuncSetAttribute( \ - Marlin, \ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - Marlin \ <<>>( \ A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n, \ @@ -1462,6 +1556,7 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) +template void marlin_mm_f16i4(const void *A, const void *B, void *C, void *s, void *g_idx, void *perm, void *a_tmp, int prob_m, int prob_n, int prob_k, void *workspace, int num_bits, @@ -1731,14 +1826,25 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, " is below min_workspace_size = ", min_workspace_size); int dev = a.get_device(); - gptq_marlin::marlin_mm_f16i4( - a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), b_scales.data_ptr(), - g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, - size_k, workspace.data_ptr(), num_bits, has_act_order, is_k_full, - num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), - thread_k, thread_n, sms, gptq_marlin::max_par); + if (a.scalar_type() == at::ScalarType::Half) { + gptq_marlin::marlin_mm_f16i4( + a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), b_scales.data_ptr(), + g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, + size_k, workspace.data_ptr(), num_bits, has_act_order, is_k_full, + num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), + thread_k, thread_n, sms, gptq_marlin::max_par); + } else if (a.scalar_type() == at::ScalarType::BFloat16) { + gptq_marlin::marlin_mm_f16i4( + a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), b_scales.data_ptr(), + g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, + size_k, workspace.data_ptr(), num_bits, has_act_order, is_k_full, + num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), + thread_k, thread_n, sms, gptq_marlin::max_par); + } else { + TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16"); + } return c; } -#endif +#endif \ No newline at end of file diff --git a/csrc/quantization/gptq_marlin/gptq_marlin_dtypes.cuh b/csrc/quantization/gptq_marlin/gptq_marlin_dtypes.cuh new file mode 100644 index 00000000..7881abbe --- /dev/null +++ b/csrc/quantization/gptq_marlin/gptq_marlin_dtypes.cuh @@ -0,0 +1,62 @@ + +#ifndef 
_data_types_cuh +#define _data_types_cuh +#include "gptq_marlin.cuh" +#include +#include + + +namespace gptq_marlin { + +template +class ScalarType { +}; + +template <> +class ScalarType { +public: + using scalar_t = half; + using scalar_t2 = half2; + + // Matrix fragments for tensor core instructions; their precise layout is + // documented here: + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type + using FragA = Vec; + using FragB = Vec; + using FragC = Vec; + using FragS = Vec; + + static __device__ float inline num2float(const half x) { return __half2float(x); } + + static __device__ half2 inline num2num2(const half x) { return __half2half2(x); } + + static __device__ half2 inline nums2num2(const half x1, const half x2) { return __halves2half2(x1, x2); } + + static __host__ __device__ half inline float2num(const float x) { return __float2half(x); } +}; + +template <> +class ScalarType { +public: + using scalar_t = nv_bfloat16; + using scalar_t2 = nv_bfloat162; + + using FragA = Vec; + using FragB = Vec; + using FragC = Vec; + using FragS = Vec; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + static __device__ float inline num2float(const nv_bfloat16 x) { return __bfloat162float(x); } + + static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) { return __bfloat162bfloat162(x); } + + static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1, const nv_bfloat16 x2) { return __halves2bfloat162(x1, x2); } + + static __host__ __device__ nv_bfloat16 inline float2num(const float x) { return __float2bfloat16(x); } +#endif +}; + +} + +#endif diff --git a/tests/models/test_gptq_marlin.py b/tests/models/test_gptq_marlin.py index db55d448..1fc0b3f2 100644 --- a/tests/models/test_gptq_marlin.py +++ b/tests/models/test_gptq_marlin.py @@ -14,6 +14,7 @@ import pytest import torch from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS +from vllm.model_executor.layers.rotary_embedding import _ROPE_DICT from .utils import check_logprobs_close @@ -52,7 +53,7 @@ MODELS = [ @pytest.mark.skipif(gptq_marlin_not_supported, reason="gptq_marlin is not supported on this GPU type.") @pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("dtype", ["half", "bfloat16"]) @pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("num_logprobs", [5]) def test_models( @@ -76,11 +77,15 @@ def test_models( gptq_marlin_outputs = gptq_marlin_model.generate_greedy_logprobs( example_prompts[:-1], max_tokens, num_logprobs) del gptq_marlin_model + _ROPE_DICT.clear() # clear rope cache to avoid rope dtype error # Run gptq. + # The naive gptq kernel doesn't support bf16 yet. + # Here we always compare fp16/bf16 gpt marlin kernel + # to fp16 gptq kernel. 
gptq_model = vllm_runner(model_name=model_name, revision=revision, - dtype=dtype, + dtype="half", quantization="gptq", max_model_len=MAX_MODEL_LEN, tensor_parallel_size=1) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index e2464008..354bb55d 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -99,7 +99,7 @@ class GPTQMarlinConfig(QuantizationConfig): @classmethod def get_supported_act_dtypes(cls) -> List[torch.dtype]: - return [torch.half] + return [torch.half, torch.bfloat16] @classmethod def get_min_capability(cls) -> int: @@ -186,9 +186,9 @@ class GPTQMarlinLinearMethod(LinearMethodBase): group_size = input_size # Validate dtype - if params_dtype != torch.float16: - raise ValueError( - f"The params dtype must be float16, but got {params_dtype}") + if params_dtype not in [torch.float16, torch.bfloat16]: + raise ValueError(f"The params dtype must be float16 " + f"or bfloat16, but got {params_dtype}") # Validate output_size_per_partition output_size_per_partition = sum(output_partition_sizes)
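
As a closing illustration of the structure this diff introduces: the ScalarType traits header is what lets a single templated kernel body serve both dtypes, with the host picking the instantiation from the tensor's scalar type. The toy kernel below is not part of the diff (its name and signature are made up) and assumes sm_80 or newer for the nv_bfloat16 instantiation, matching the __CUDA_ARCH__ >= 800 guard in gptq_marlin_dtypes.cuh.

// Toy consumer of the ScalarType traits (illustration only, not in the diff).
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include "gptq_marlin_dtypes.cuh"

template <typename scalar_t>
__global__ void scale_pairs(typename gptq_marlin::ScalarType<scalar_t>::scalar_t2 *x,
                            scalar_t s, int n_pairs) {
  using Dtype = gptq_marlin::ScalarType<scalar_t>;
  // Broadcast the scalar scale into a packed pair, exactly what scale() and
  // scale4() do via num2num2() in the kernel above.
  typename Dtype::scalar_t2 s2 = Dtype::num2num2(s);
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n_pairs) x[i] = __hmul2(x[i], s2);
}

// The two instantiations the dtype branch in gptq_marlin_gemm() chooses
// between; a launch names the dtype explicitly, e.g.
//   scale_pairs<half><<<grid, block>>>(ptr, s, n);
template __global__ void scale_pairs<half>(half2 *, half, int);
template __global__ void scale_pairs<nv_bfloat16>(nv_bfloat162 *, nv_bfloat16, int);
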