From 6aa33cb2ddd769e764a3312627cab5bffaa383cc Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Mon, 12 Aug 2024 14:40:13 -0400 Subject: [PATCH] [Misc] Use scalar type to dispatch to different `gptq_marlin` kernels (#7323) --- csrc/core/scalar_type.hpp | 197 +++++++++-- csrc/quantization/gptq_marlin/gptq_marlin.cu | 353 +++++++++---------- 2 files changed, 332 insertions(+), 218 deletions(-) diff --git a/csrc/core/scalar_type.hpp b/csrc/core/scalar_type.hpp index 9f78402e..e0f4d9f3 100644 --- a/csrc/core/scalar_type.hpp +++ b/csrc/core/scalar_type.hpp @@ -20,7 +20,7 @@ namespace vllm { // class ScalarType { public: - enum NanRepr : int64_t { + enum NanRepr : uint8_t { NAN_NONE = 0, // nans are not supported NAN_IEEE_754 = 1, // nans are: exp all 1s, mantissa not all 0s NAN_EXTD_RANGE_MAX_MIN = 2, // nans are: exp all 1s, mantissa all 1s @@ -28,33 +28,33 @@ class ScalarType { NAN_REPR_ID_MAX }; - constexpr ScalarType(bool signed_, int64_t exponent, int64_t mantissa, - int64_t bias, bool finite_values_only = false, + constexpr ScalarType(uint8_t exponent, uint8_t mantissa, bool signed_, + int32_t bias, bool finite_values_only = false, NanRepr nan_repr = NAN_IEEE_754) : exponent(exponent), mantissa(mantissa), - bias(bias), signed_(signed_), + bias(bias), finite_values_only(finite_values_only), nan_repr(nan_repr){}; - static constexpr ScalarType int_(int64_t size_bits, int64_t bias = 0) { - return ScalarType(true, 0, size_bits - 1, bias); + static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) { + return ScalarType(0, size_bits - 1, true, bias); } - static constexpr ScalarType uint(int64_t size_bits, int64_t bias = 0) { - return ScalarType(false, 0, size_bits, bias); + static constexpr ScalarType uint(uint8_t size_bits, int32_t bias = 0) { + return ScalarType(0, size_bits, false, bias); } // IEEE 754 compliant floating point type - static constexpr ScalarType float_IEEE754(int64_t exponent, - int64_t mantissa) { + static constexpr ScalarType float_IEEE754(uint8_t exponent, + uint8_t mantissa) { TORCH_CHECK(mantissa > 0 && exponent > 0); - return ScalarType(true, exponent, mantissa, 0, false, NAN_IEEE_754); + return ScalarType(exponent, mantissa, true, 0, false, NAN_IEEE_754); } // IEEE 754 non-compliant floating point type - static constexpr ScalarType float_(int64_t exponent, int64_t mantissa, + static constexpr ScalarType float_(uint8_t exponent, uint8_t mantissa, bool finite_values_only, NanRepr nan_repr) { TORCH_CHECK(nan_repr < NAN_REPR_ID_MAX, "Invalid NanRepr"); @@ -62,36 +62,121 @@ class ScalarType { TORCH_CHECK(nan_repr != NAN_IEEE_754, "use `float_IEEE754` constructor for floating point types that " "follow IEEE 754 conventions"); - return ScalarType(true, exponent, mantissa, 0, finite_values_only, + return ScalarType(exponent, mantissa, true, 0, finite_values_only, nan_repr); } - int64_t const exponent; // size of the exponent field (0 for integer types) - int64_t const mantissa; // size of the mantissa field (size of the integer + uint8_t const exponent; // size of the exponent field (0 for integer types) + uint8_t const mantissa; // size of the mantissa field (size of the integer // excluding the sign bit for integer types) - int64_t const bias; // stored values equal value + bias, - // used for quantized type bool const signed_; // flag if the type supports negative numbers (i.e. has a // sign bit) + int32_t const bias; // stored values equal value + bias, + // used for quantized type // Extra Floating point info bool const finite_values_only; // i.e. 
no +/-inf if true NanRepr const nan_repr; // how NaNs are represented // (not applicable for integer types) - int64_t size_bits() const { return mantissa + exponent + is_signed(); } - bool is_signed() const { return signed_; } - bool is_integer() const { return exponent == 0; } - bool is_floating_point() const { return exponent > 0; } - bool is_ieee_754() const { + using Id = int64_t; + + private: + // Field size in id + template + static constexpr size_t member_id_field_width() { + using T = std::decay_t; + return std::is_same_v ? 1 : sizeof(T) * 8; + } + + template + static constexpr auto reduce_members_helper(Fn f, Init val, Member member, + Rest... rest) { + auto new_val = f(val, member); + if constexpr (sizeof...(rest) > 0) { + return reduce_members_helper(f, new_val, rest...); + } else { + return new_val; + }; + } + + template + constexpr auto reduce_members(Fn f, Init init) const { + // Should be in constructor order for `from_id` + return reduce_members_helper(f, init, exponent, mantissa, signed_, bias, + finite_values_only, nan_repr); + }; + + template + static constexpr auto reduce_member_types(Fn f, Init init) { + constexpr auto dummy_type = ScalarType(0, 0, false, 0, false, NAN_NONE); + return dummy_type.reduce_members(f, init); + }; + + static constexpr auto id_size_bits() { + return reduce_member_types( + [](int acc, auto member) -> int { + return acc + member_id_field_width(); + }, + 0); + } + + public: + // unique id for this scalar type that can be computed at compile time for + // c++17 template specialization this is not needed once we migrate to + // c++20 and can pass literal classes as template parameters + constexpr Id id() const { + static_assert(id_size_bits() <= sizeof(Id) * 8, + "ScalarType id is too large to be stored"); + + auto or_and_advance = [](std::pair result, + auto member) -> std::pair { + auto [id, bit_offset] = result; + auto constexpr bits = member_id_field_width(); + return {id | (int64_t(member) & ((uint64_t(1) << bits) - 1)) + << bit_offset, + bit_offset + bits}; + }; + return reduce_members(or_and_advance, std::pair{}).first; + } + + // create a ScalarType from an id, for c++17 template specialization, + // this is not needed once we migrate to c++20 and can pass literal + // classes as template parameters + static constexpr ScalarType from_id(Id id) { + auto extract_and_advance = [id](auto result, auto member) { + using T = decltype(member); + auto [tuple, bit_offset] = result; + auto constexpr bits = member_id_field_width(); + auto extracted_val = static_cast((int64_t(id) >> bit_offset) & + ((uint64_t(1) << bits) - 1)); + auto new_tuple = std::tuple_cat(tuple, std::make_tuple(extracted_val)); + return std::pair{new_tuple, bit_offset + bits}; + }; + + auto [tuple_args, _] = reduce_member_types(extract_and_advance, + std::pair, int>{}); + return std::apply([](auto... 
args) { return ScalarType(args...); }, + tuple_args); + } + + constexpr int64_t size_bits() const { + return mantissa + exponent + is_signed(); + } + constexpr bool is_signed() const { return signed_; } + constexpr bool is_integer() const { return exponent == 0; } + constexpr bool is_floating_point() const { return exponent > 0; } + constexpr bool is_ieee_754() const { return is_floating_point() && finite_values_only == false && nan_repr == NAN_IEEE_754; } - bool has_nans() const { return is_floating_point() && nan_repr != NAN_NONE; } - bool has_infs() const { + constexpr bool has_nans() const { + return is_floating_point() && nan_repr != NAN_NONE; + } + constexpr bool has_infs() const { return is_floating_point() && finite_values_only == false; } - bool has_bias() const { return bias != 0; } + constexpr bool has_bias() const { return bias != 0; } private: double _floating_point_max() const { @@ -131,7 +216,7 @@ class ScalarType { return *reinterpret_cast(&double_raw); } - std::variant _raw_max() const { + constexpr std::variant _raw_max() const { if (is_floating_point()) { return {_floating_point_max()}; } else { @@ -141,7 +226,7 @@ class ScalarType { } } - std::variant _raw_min() const { + constexpr std::variant _raw_min() const { if (is_floating_point()) { TORCH_CHECK(is_signed(), "We currently assume all floating point types are signed"); @@ -168,7 +253,7 @@ class ScalarType { public: // Max representable value for this scalar type. // (accounting for bias if there is one) - std::variant max() const { + constexpr std::variant max() const { return std::visit( [this](auto x) -> std::variant { return {x - bias}; }, _raw_max()); @@ -176,7 +261,7 @@ class ScalarType { // Min representable value for this scalar type. // (accounting for bias if there is one) - std::variant min() const { + constexpr std::variant min() const { return std::visit( [this](auto x) -> std::variant { return {x - bias}; }, _raw_min()); @@ -215,7 +300,7 @@ class ScalarType { } } - bool operator==(ScalarType const& other) const { + constexpr bool operator==(ScalarType const& other) const { return mantissa == other.mantissa && exponent == other.exponent && bias == other.bias && signed_ == other.signed_ && finite_values_only == other.finite_values_only && @@ -240,23 +325,59 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType { using Self = ScalarTypeTorch; using SelfPtr = c10::intrusive_ptr; + static void check_size_bits(int64_t size_bits, bool signed_) { + TORCH_CHECK( + size_bits <= + std::numeric_limits().mantissa)>::max(), + "size_bits bit width is too large to be represented"); + } + + static void check_bias(int64_t bias) { + using Bias = decltype(std::declval().bias); + TORCH_CHECK(bias <= std::numeric_limits::max() && + bias >= std::numeric_limits::min(), + "bias too large or small to be represented"); + } + + static void check_exponent(int64_t exponent) { + TORCH_CHECK( + exponent <= + std::numeric_limits().exponent)>::max(), + "exponent bit width is too large to be represented"); + } + + static void check_mantissa(int64_t mantissa) { + TORCH_CHECK( + mantissa <= + std::numeric_limits().mantissa)>::max(), + "mantissa bit width is too large to be represented"); + } + static SelfPtr int_(int64_t size_bits, c10::optional bias) { + check_size_bits(size_bits, true); + check_bias(bias.value_or(0)); return c10::make_intrusive( ScalarType::int_(size_bits, bias.value_or(0))); } static SelfPtr uint(int64_t size_bits, c10::optional bias) { + check_size_bits(size_bits, true); + 
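// Round-trip sketch for the id()/from_id() pair defined above: id() packs the
// members into fixed-width bit fields (1 bit for bools, 8 * sizeof(T) bits
// otherwise, in constructor order), so from_id(x.id()) rebuilds a ScalarType
// that compares equal to x, all at compile time. The include path is an
// assumption for illustration.
#include "core/scalar_type.hpp"

static_assert(vllm::kFloat16Id == vllm::kFloat16.id());
static_assert(vllm::ScalarType::from_id(vllm::kFloat16Id) == vllm::kFloat16);
static_assert(vllm::ScalarType::from_id(vllm::kU4B8.id()) == vllm::kU4B8);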
check_bias(bias.value_or(0)); return c10::make_intrusive( ScalarType::uint(size_bits, bias.value_or(0))); } static SelfPtr float_IEEE754(int64_t exponent, int64_t mantissa) { + check_mantissa(mantissa); + check_exponent(exponent); return c10::make_intrusive( ScalarType::float_IEEE754(exponent, mantissa)); } static SelfPtr float_(int64_t exponent, int64_t mantissa, bool finite_values_only, int64_t nan_repr) { + check_mantissa(mantissa); + check_exponent(exponent); return c10::make_intrusive(ScalarType::float_( exponent, mantissa, finite_values_only, NanRepr(nan_repr))); } @@ -264,7 +385,7 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType { template static void bind_readonly_property(torch::class_& cls, std::string const& name, T Base::*field) { - auto getter_func = [field = std::move(field)](SelfPtr const& self) { + auto getter_func_helper = [field = std::move(field)](SelfPtr const& self) { if constexpr (std::is_member_function_pointer_v) { return (self.get()->*field)(); } else { @@ -272,6 +393,18 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType { } }; + auto getter_func = [field = std::move(field), + getter_func_helper = std::move(getter_func_helper)]( + SelfPtr const& self) { + auto val = getter_func_helper(self); + // upconvert uint8_t, int32_t etc. to int64_t for python + if constexpr (std::is_integral_v) { + return static_cast(val); + } else { + return val; + } + }; + cls.def_property(name, getter_func); } @@ -340,6 +473,7 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType { } }; +using ScalarTypeId = int64_t; using ScalarTypeTorchPtr = c10::intrusive_ptr; // "rust style" names generally following: @@ -379,4 +513,5 @@ static inline constexpr auto kHalf = kFE5M10; static inline constexpr auto kFloat16 = kHalf; static inline constexpr auto kBFloat16 = kFE8M7; +static inline constexpr auto kFloat16Id = kFloat16.id(); }; // namespace vllm diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index e2b0f2b0..9b4a6a51 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -42,8 +42,8 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, int4* __restrict__ out_int4_ptr, int size_m, int size_k, int block_rows) {} -template +__device__ inline typename ScalarType::FragB dequant(int q); + +// +// Efficiently dequantize 4bit values packed in an int32 value into a full +// B-fragment of 4 fp16 values. 
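// Sketch of the combined dequantizer's declaration (the template parameter
// list shown here is an assumption): a single `dequant` primary template keyed
// on the compute type and the weight ScalarType id replaces the former
// dequant_4bit / dequant_8bit / *_zp variants. Note that ScalarType<scalar_t>
// below is the kernel-side dtype-traits class providing FragB, not
// vllm::ScalarType.
template <typename scalar_t, vllm::ScalarTypeId w_type_id>
__device__ inline typename ScalarType<scalar_t>::FragB dequant(int q);
// Each supported (compute type, weight type) pair is then a full
// specialization, e.g. <half, vllm::kU4B8.id()> for the 4-bit GPTQ layout, and
// the call sites in the kernel are assumed to select it as
// dequant<scalar_t, w_type_id>(q).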
We mostly follow the strategy in the link below, +// with some small changes: // - FP16: // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 // - BF16: // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385 -template -__device__ inline typename ScalarType::FragB dequant_4bit(int q) { - STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); -} - +// template <> -__device__ inline typename ScalarType::FragB dequant_4bit(int q) { +__device__ inline typename ScalarType::FragB +dequant(int q) { const int LO = 0x000f000f; const int HI = 0x00f000f0; const int EX = 0x64006400; @@ -187,7 +188,7 @@ __device__ inline typename ScalarType::FragB dequant_4bit(int q) { template <> __device__ inline typename ScalarType::FragB -dequant_4bit(int q) { +dequant(int q) { static constexpr uint32_t MASK = 0x000f000f; static constexpr uint32_t EX = 0x43004300; @@ -210,19 +211,64 @@ dequant_4bit(int q) { return frag_b; } +template <> +__device__ inline typename ScalarType::FragB +dequant(int q) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + + const int SUB = 0x64006400; + const int MUL = 0x2c002c00; + const int ADD = 0xd400d400; + typename ScalarType::FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + return frag_b; +} + +template <> +__device__ inline typename ScalarType::FragB +dequant(int q) { + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t EX = 0x43004300; + + // Guarantee that the `(a & b) | c` operations are LOP3s. 
+ + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + q >>= 4; + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + + typename ScalarType::FragB frag_b; + static constexpr uint32_t MUL = 0x3F803F80; + static constexpr uint32_t ADD = 0xC300C300; + + frag_b[0] = __hfma2(*reinterpret_cast(&lo), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + return frag_b; +} + +// // Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or // bf16 Reference: // - FP16: // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 // - BF16: // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175 -template -__device__ inline typename ScalarType::FragB dequant_8bit(int q) { - STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); -} - +// template <> -__device__ inline typename ScalarType::FragB dequant_8bit(int q) { +__device__ inline typename ScalarType::FragB +dequant(int q) { static constexpr uint32_t mask_for_elt_01 = 0x5250; static constexpr uint32_t mask_for_elt_23 = 0x5351; static constexpr uint32_t start_byte_for_fp16 = 0x64646464; @@ -242,7 +288,7 @@ __device__ inline typename ScalarType::FragB dequant_8bit(int q) { template <> __device__ inline typename ScalarType::FragB -dequant_8bit(int q) { +dequant(int q) { typename ScalarType::FragB frag_b; float fp32_intermediates[4]; @@ -269,68 +315,9 @@ dequant_8bit(int q) { return frag_b; } -// Zero-point dequantizers - -template -__device__ inline typename ScalarType::FragB dequant_4bit_zp(int q) { - STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); -} - template <> -__device__ inline typename ScalarType::FragB dequant_4bit_zp( - int q) { - const int LO = 0x000f000f; - const int HI = 0x00f000f0; - const int EX = 0x64006400; - // Guarantee that the `(a & b) | c` operations are LOP3s. - int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); - int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); - - const int SUB = 0x64006400; - const int MUL = 0x2c002c00; - const int ADD = 0xd400d400; - typename ScalarType::FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - return frag_b; -} - -template <> -__device__ inline typename ScalarType::FragB -dequant_4bit_zp(int q) { - static constexpr uint32_t MASK = 0x000f000f; - static constexpr uint32_t EX = 0x43004300; - - // Guarantee that the `(a & b) | c` operations are LOP3s. 
- - int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); - q >>= 4; - int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); - - typename ScalarType::FragB frag_b; - static constexpr uint32_t MUL = 0x3F803F80; - static constexpr uint32_t ADD = 0xC300C300; - - frag_b[0] = __hfma2(*reinterpret_cast(&lo), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - return frag_b; -} - -template -__device__ inline typename ScalarType::FragB dequant_8bit_zp(int q) { - STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); -} - -template <> -__device__ inline typename ScalarType::FragB dequant_8bit_zp( - int q) { +__device__ inline typename ScalarType::FragB +dequant(int q) { static constexpr uint32_t mask_for_elt_01 = 0x5250; static constexpr uint32_t mask_for_elt_23 = 0x5351; static constexpr uint32_t start_byte_for_fp16 = 0x64646464; @@ -350,7 +337,7 @@ __device__ inline typename ScalarType::FragB dequant_8bit_zp( template <> __device__ inline typename ScalarType::FragB -dequant_8bit_zp(int q) { +dequant(int q) { typename ScalarType::FragB frag_b; float fp32_intermediates[4]; @@ -517,8 +504,8 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, } } -template ::FragS; using FragZP = typename ScalarType::FragZP; - constexpr int pack_factor = 32 / num_bits; + static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); + + constexpr int pack_factor = 32 / w_type.size_bits(); // For larger GEMMs we run multiple batchsize 64 versions in parallel for a // better partitioning with less reductions @@ -670,7 +659,7 @@ __global__ void Marlin( // B sizes/strides int b_gl_stride = 16 * prob_n / (pack_factor * 4); constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; - constexpr int b_thread_vecs = num_bits == 4 ? 1 : 2; + constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 
1 : 2; constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; @@ -1186,19 +1175,20 @@ __global__ void Marlin( if constexpr (has_zp) { FragB frag_zp_0; FragB frag_zp_1; - if constexpr (num_bits == 4) { - int zp_quant = frag_qzp[k % 2][0]; - int zp_quant_shift = zp_quant >> 8; - frag_zp_0 = dequant_4bit_zp(zp_quant); - frag_zp_1 = dequant_4bit_zp(zp_quant_shift); + int zp_quant_0, zp_quant_1; + if constexpr (w_type.size_bits() == 4) { + zp_quant_0 = frag_qzp[k % 2][0]; + zp_quant_1 = zp_quant_0 >> 8; } else { - int zp_quant_0 = frag_qzp[k % 2][0]; - int zp_quant_1 = frag_qzp[k % 2][1]; - frag_zp_0 = dequant_8bit_zp(zp_quant_0); - frag_zp_1 = dequant_8bit_zp(zp_quant_1); + static_assert(w_type.size_bits() == 8); + zp_quant_0 = frag_qzp[k % 2][0]; + zp_quant_1 = frag_qzp[k % 2][1]; } + frag_zp_0 = dequant(zp_quant_0); + frag_zp_1 = dequant(zp_quant_1); + frag_zp[0] = frag_zp_0[0]; frag_zp[1] = frag_zp_0[1]; frag_zp[2] = frag_zp_1[0]; @@ -1211,33 +1201,21 @@ __global__ void Marlin( for (int j = 0; j < 4; j++) { FragB frag_b0; FragB frag_b1; - if constexpr (num_bits == 4) { - int b_quant = frag_b_quant[k % 2][0][j]; - int b_quant_shift = b_quant >> 8; - - if constexpr (has_zp) { - frag_b0 = dequant_4bit_zp(b_quant); - frag_b1 = dequant_4bit_zp(b_quant_shift); - - } else { - frag_b0 = dequant_4bit(b_quant); - frag_b1 = dequant_4bit(b_quant_shift); - } + int b_quant_0, b_quant_1; + if constexpr (w_type.size_bits() == 4) { + b_quant_0 = frag_b_quant[k % 2][0][j]; + b_quant_1 = b_quant_0 >> 8; } else { + static_assert(w_type.size_bits() == 8); int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); - int b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; - int b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; - - if constexpr (has_zp) { - frag_b0 = dequant_8bit_zp(b_quant_0); - frag_b1 = dequant_8bit_zp(b_quant_1); - } else { - frag_b0 = dequant_8bit(b_quant_0); - frag_b1 = dequant_8bit(b_quant_1); - } + b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; + b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; } + frag_b0 = dequant(b_quant_0); + frag_b1 = dequant(b_quant_1); + // Apply zero-point to frag_b0 if constexpr (has_zp) { sub_zp(frag_b0, frag_zp[j], 0); @@ -1477,7 +1455,8 @@ __global__ void Marlin( // For per-column quantization we finally apply the scale here (only for // 4-bit) - if constexpr (!has_act_order && group_blocks == -1 && num_bits == 4) { + if constexpr (!has_act_order && group_blocks == -1 && + w_type.size_bits() == 4) { res = __hmul2(res, s[0]); } @@ -1605,7 +1584,7 @@ __global__ void Marlin( // For per-column scales, we only fetch them here in the final step before // write-out if constexpr (!has_act_order && group_blocks == -1) { - if constexpr (num_bits == 8) { + if constexpr (w_type.size_bits() == 8) { if (s_sh_wr_pred) { cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); } @@ -1622,7 +1601,7 @@ __global__ void Marlin( thread_block_reduce(); if constexpr (!has_act_order && group_blocks == -1) { - if constexpr (num_bits == 8) { + if constexpr (w_type.size_bits() == 8) { cp_async_wait<0>(); __syncthreads(); if (threadIdx.x / 32 < thread_n_blocks / 4) { @@ -1645,7 +1624,8 @@ __global__ void Marlin( // For 8-bit channelwise, we apply the scale before the global reduction // that converts the fp32 results to fp16 (so that we avoid possible // overflow in fp16) - if constexpr (!has_act_order && group_blocks == -1 && num_bits == 8) { + if constexpr (!has_act_order && group_blocks == -1 && + w_type.size_bits() == 8) { if (threadIdx.x / 32 < 
thread_n_blocks / 4) { #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { @@ -1714,20 +1694,19 @@ __global__ void Marlin( } } - #define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ - THREAD_K_BLOCKS, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \ - NUM_THREADS) \ - else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \ + #define __CALL_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ + HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, NUM_THREADS) \ + else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ thread_n_blocks == THREAD_N_BLOCKS && \ thread_k_blocks == THREAD_K_BLOCKS && \ has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \ group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ cudaFuncSetAttribute( \ - Marlin, \ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - Marlin \ <<>>( \ @@ -1923,52 +1902,52 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, return exec_config_t{0, {-1, -1, -1}}; } - #define GPTQ_CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) + #define GPTQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ + __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ + __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ + __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ + \ + __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ + __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ + __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ + __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ + \ + __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ + __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ + __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ + __CALL_IF(W_TYPE, 
2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ + \ + __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ + __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ + __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ + __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ + \ + __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ + __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ + __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ + __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) - #define AWQ_CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) + #define AWQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ + __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ + __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ + __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ + \ + __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ + __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ + __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ + __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ + \ + __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ + __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ + __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ + __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ + \ + __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ + __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ + __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ + __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) template void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, @@ -2113,23 +2092,23 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, if (false) { } - GPTQ_CALL_IF(4, 16, 4, 256) - GPTQ_CALL_IF(4, 8, 8, 256) - GPTQ_CALL_IF(4, 8, 4, 128) - GPTQ_CALL_IF(4, 4, 8, 128) - 
GPTQ_CALL_IF(8, 16, 4, 256)
-  GPTQ_CALL_IF(8, 8, 8, 256)
-  GPTQ_CALL_IF(8, 8, 4, 128)
-  GPTQ_CALL_IF(8, 4, 8, 128)
+  GPTQ_CALL_IF(vllm::kU4B8, 16, 4, 256)
+  GPTQ_CALL_IF(vllm::kU4B8, 8, 8, 256)
+  GPTQ_CALL_IF(vllm::kU4B8, 8, 4, 128)
+  GPTQ_CALL_IF(vllm::kU4B8, 4, 8, 128)
+  GPTQ_CALL_IF(vllm::kU8B128, 16, 4, 256)
+  GPTQ_CALL_IF(vllm::kU8B128, 8, 8, 256)
+  GPTQ_CALL_IF(vllm::kU8B128, 8, 4, 128)
+  GPTQ_CALL_IF(vllm::kU8B128, 4, 8, 128)
-  AWQ_CALL_IF(4, 16, 4, 256)
-  AWQ_CALL_IF(4, 8, 8, 256)
-  AWQ_CALL_IF(4, 8, 4, 128)
-  AWQ_CALL_IF(4, 4, 8, 128)
-  AWQ_CALL_IF(8, 16, 4, 256)
-  AWQ_CALL_IF(8, 8, 8, 256)
-  AWQ_CALL_IF(8, 8, 4, 128)
-  AWQ_CALL_IF(8, 4, 8, 128)
+  AWQ_CALL_IF(vllm::kU4, 16, 4, 256)
+  AWQ_CALL_IF(vllm::kU4, 8, 8, 256)
+  AWQ_CALL_IF(vllm::kU4, 8, 4, 128)
+  AWQ_CALL_IF(vllm::kU4, 4, 8, 128)
+  AWQ_CALL_IF(vllm::kU8, 16, 4, 256)
+  AWQ_CALL_IF(vllm::kU8, 8, 8, 256)
+  AWQ_CALL_IF(vllm::kU8, 8, 4, 128)
+  AWQ_CALL_IF(vllm::kU8, 4, 8, 128)
   else {
     TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
                 ", ", prob_k, "]", ", has_act_order = ", has_act_order,
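// Condensed sketch of the dispatch idiom that GPTQ_CALL_IF / AWQ_CALL_IF
// expand to above: the runtime weight ScalarType is compared against each
// supported type and, on a match, the kernel specialization carrying that
// type's compile-time id is launched. The helper name `launch_marlin`, its
// simplified signature, and the include path are assumptions for
// illustration; the real __CALL_IF also matches the thread/group
// configuration and launches the Marlin kernel directly.
#include "core/scalar_type.hpp"  // assumed include path for the header above

template <vllm::ScalarTypeId w_type_id>
void launch_marlin(/* A, B, C, scales, ... */) {
  // Stand-in for the real Marlin<scalar_t, w_type_id, ...><<<...>>> launch.
}

void dispatch_by_weight_type(vllm::ScalarType const& q_type) {
  if (false) {
  } else if (q_type == vllm::kU4B8) {    // 4-bit GPTQ-style (bias 8)
    launch_marlin<vllm::kU4B8.id()>();
  } else if (q_type == vllm::kU8B128) {  // 8-bit GPTQ-style (bias 128)
    launch_marlin<vllm::kU8B128.id()>();
  } else if (q_type == vllm::kU4) {      // 4-bit AWQ-style (zero-point)
    launch_marlin<vllm::kU4.id()>();
  } else if (q_type == vllm::kU8) {      // 8-bit AWQ-style (zero-point)
    launch_marlin<vllm::kU8.id()>();
  } else {
    TORCH_CHECK(false, "Unsupported weight type");
  }
}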