[Kernel] Enable 8-bit weights in Fused Marlin MoE (#8032)
Co-authored-by: Dipika <dipikasikka1@gmail.com>
commit a091e2da3e (parent fc990f9795)
@@ -25,6 +25,8 @@
 
 #include <iostream>
 
+#include "core/scalar_type.hpp"
+
 template <typename T>
 inline std::string str(T x) {
   return std::to_string(x);
@@ -131,11 +133,26 @@ __device__ inline int lop3(int a, int b, int c) {
   return res;
 }
 
-// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
-// values. We mostly follow the strategy in the link below, with some small
-// changes:
-// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
-__device__ inline FragB dequant(int q) {
+// Constructs destination register by taking bytes from 2 sources (based on
+// mask)
+template <int start_byte, int mask>
+__device__ inline uint32_t prmt(uint32_t a) {
+  uint32_t res;
+  asm volatile("prmt.b32 %0, %1, %2, %3;\n"
+               : "=r"(res)
+               : "r"(a), "n"(start_byte), "n"(mask));
+  return res;
+}
+
+template <vllm::ScalarTypeId w_type_id>
+__device__ inline FragB dequant(int q);
+
+// Efficiently dequantize 4bit values packed in an int32 value into a full
+// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below,
+// with some small changes:
+// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
+template <>
+__device__ inline FragB dequant<vllm::kU4B8.id()>(int q) {
   const int LO = 0x000f000f;
   const int HI = 0x00f000f0;
   const int EX = 0x64006400;
@@ -156,6 +173,28 @@ __device__ inline FragB dequant(int q) {
   return frag_b;
 }
 
+// Fast Int8ToFp16: Efficiently dequantize 8bit int values to fp16
+// Reference:
+// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
+template <>
+__device__ inline FragB dequant<vllm::kU8B128.id()>(int q) {
+  static constexpr uint32_t mask_for_elt_01 = 0x5250;
+  static constexpr uint32_t mask_for_elt_23 = 0x5351;
+  static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
+
+  uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
+  uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
+
+  static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
+
+  FragB frag_b;
+  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
+                      *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
+  frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
+                      *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
+  return frag_b;
+}
+
 // Multiply dequantized values by the corresponding quantization scale; used
 // only for grouped quantization.
 __device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
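Note on the new kU8B128 path above: prmt places each source byte b into the low byte of a half whose exponent bits are 0x64, and the fp16 bit pattern 0x6400 | b encodes exactly 1024 + b; subtracting the magic constant 0x6480 (1152.0 = 1024 + 128) then yields b - 128, the signed value of a weight stored with zero point 128. The selectors 0x5250 and 0x5351 pick packed elements {0, 2} and {1, 3}, so each half2 holds two dequantized values. A standalone host-side sketch (not part of this commit; assumes a little-endian host and numpy) that checks the arithmetic:

import numpy as np

def dequant_u8b128_reference(packed):
    """Reproduce the 0x64xx / 0x6480 magic-number trick for one packed int32."""
    magic = np.array([0x6480], dtype=np.uint16).view(np.float16)[0]  # 1152.0 = 1024 + 128
    out = []
    for byte in np.uint32(packed).tobytes():  # four uint8 values, little-endian order
        # fp16 bit pattern 0x64xx encodes exactly 1024 + byte
        h = np.array([0x6400 | byte], dtype=np.uint16).view(np.float16)[0]
        out.append(float(h) - float(magic))   # recovers byte - 128
    return out

# bytes 0x00, 0x01, 0x42, 0x83 map to -128, -127, -62, 3
assert dequant_u8b128_reference(0x83420100) == [-128.0, -127.0, -62.0, 3.0]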
@@ -296,7 +335,8 @@ __global__ void compute_expert_offsets(int const* __restrict__ topk_ids,
   __syncthreads();
 }
 
-template <const int threads,          // number of threads in a threadblock
+template <const vllm::ScalarTypeId w_type_id,  // weight ScalarType id
+          const int threads,          // number of threads in a threadblock
           const int thread_m_blocks,  // number of 16x16 blocks in the m
                                       // dimension (batchsize) of the
                                       // threadblock
@@ -331,6 +371,9 @@ __device__ inline void MarlinMoESingle(
     bool apply_weights,      // apply weights to output
     int current_m_block      // current m block to start kernel computation from
 ) {
+  static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id);
+  constexpr int pack_factor = 32 / w_type.size_bits();
+
   // For larger GEMMs we run multiple batchsize 64 versions in parallel for a
   // better partitioning with less reductions
   int parallel = 1;
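pack_factor introduced above is the number of quantized values held by one 32-bit word: 8 for 4-bit weights and 4 for 8-bit weights. The B strides in the next hunk are expressed in int4 (16-byte) vectors, so they grow by 2x for 8-bit weights, and b_thread_vecs = 2 lets each thread fetch two such vectors to cover the same number of logical weights per tile. A small host-side sketch (not part of this commit) of that arithmetic:

def b_strides(prob_n, thread_n_blocks, size_bits):
    """Mirror the B stride math from MarlinMoESingle (host-side check only)."""
    pack_factor = 32 // size_bits                  # quantized values per int32
    # global stride of one 16-row slab of B, measured in int4 (4 x int32) vectors
    b_gl_stride = 16 * prob_n // (pack_factor * 4)
    # shared-memory stride of one threadblock tile, also in int4 vectors
    b_sh_stride = ((thread_n_blocks * 16) * 16 // pack_factor) // 4
    b_thread_vecs = 1 if size_bits == 4 else 2     # int4 vectors loaded per thread
    return b_gl_stride, b_sh_stride, b_thread_vecs

# 8-bit weights occupy twice the bytes, so both strides double and each thread
# issues two 16-byte loads instead of one.
assert b_strides(4096, 16, 8)[0] == 2 * b_strides(4096, 16, 4)[0]
assert b_strides(4096, 16, 8)[1] == 2 * b_strides(4096, 16, 4)[1]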
@@ -423,19 +466,23 @@ __device__ inline void MarlinMoESingle(
   constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta);
 
   // B sizes/strides
-  int b_gl_stride = 16 * prob_n / 32;
-  constexpr int b_sh_stride = 32 * thread_n_blocks / 4;
+  int b_gl_stride = 16 * prob_n / (pack_factor * 4);
+  constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4;
+  constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2;
+  constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs;
+
   int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
-  int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride);
-  constexpr int b_sh_wr_delta = threads;
-  constexpr int b_sh_rd_delta = threads;
+  int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads);
+  constexpr int b_sh_wr_delta = threads * b_thread_vecs;
+  constexpr int b_sh_rd_delta = threads * b_thread_vecs;
   constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
   constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
 
   // Scale sizes/strides without act_order
   int s_gl_stride = prob_n / 8;
   constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
-  constexpr int s_tb_groups = !has_act_order && group_blocks < thread_k_blocks
+  constexpr int s_tb_groups =
+      !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
           ? thread_k_blocks / group_blocks
           : 1;
   constexpr int s_sh_stage = s_tb_groups * s_sh_stride;
@@ -465,12 +512,12 @@ __device__ inline void MarlinMoESingle(
       a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16;
   a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4));
 
-  int b_gl_rd =
-      b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride);
+  int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) +
+                (threadIdx.x % b_sh_stride_threads) * b_thread_vecs;
   b_gl_rd += b_sh_stride * slice_col;
   b_gl_rd += b_gl_rd_delta_o * slice_row;
-  int b_sh_wr = threadIdx.x;
-  int b_sh_rd = threadIdx.x;
+  int b_sh_wr = threadIdx.x * b_thread_vecs;
+  int b_sh_rd = threadIdx.x * b_thread_vecs;
 
   // For act_order
   constexpr int k_iter_size = tb_k / b_sh_wr_iters;
@@ -481,12 +528,14 @@ __device__ inline void MarlinMoESingle(
 
   // No act_order
   int s_gl_rd;
-  if constexpr (group_blocks == -1 || group_blocks == 0) {
-    s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
-  } else {
-    s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
-              s_sh_stride * slice_col + threadIdx.x;
-  }
+  if constexpr (!has_act_order) {
+    if constexpr (group_blocks == -1) {
+      s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
+    } else {
+      s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
+                s_sh_stride * slice_col + threadIdx.x;
+    }
+  }
   int s_sh_wr = threadIdx.x;
   bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
 
@@ -571,7 +620,7 @@ __device__ inline void MarlinMoESingle(
 
   // Register storage for double buffer of shared memory reads.
   FragA frag_a[2][thread_m_blocks];
-  I4 frag_b_quant[2];
+  I4 frag_b_quant[2][b_thread_vecs];
   FragC frag_c[thread_m_blocks][4][2];
   FragS frag_s[2][4];         // No act-order
   FragS act_frag_s[2][4][4];  // For act-order
@@ -637,7 +686,10 @@ __device__ inline void MarlinMoESingle(
       int4* sh_b_stage = sh_b + b_sh_stage * pipe;
 #pragma unroll
       for (int i = 0; i < b_sh_wr_iters; i++) {
-        cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]);
+#pragma unroll
+        for (int j = 0; j < b_thread_vecs; j++) {
+          cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j);
+        }
         B_ptr[i] += b_gl_rd_delta_o;
       }
 
@@ -715,14 +767,24 @@ __device__ inline void MarlinMoESingle(
     for (int i = 0; i < thread_m_blocks; i++)
       ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
     int4* sh_b_stage = sh_b + b_sh_stage * pipe;
-    frag_b_quant[k % 2] = *reinterpret_cast<I4*>(
-        &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]);
+
+#pragma unroll
+    for (int i = 0; i < b_thread_vecs; i++) {
+      frag_b_quant[k % 2][i] = *reinterpret_cast<I4*>(
+          &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]);
+    }
   };
 
   bool is_same_group[stages];
   int same_group_id[stages];
 
   auto init_same_group = [&](int pipe) {
+    if constexpr (!has_act_order) {
+      is_same_group[pipe] = false;
+      same_group_id[pipe] = 0;
+      return;
+    }
+
     int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;
     int* sh_g_idx_int_ptr = reinterpret_cast<int*>(sh_g_idx_stage);
 
@@ -840,10 +902,19 @@ __device__ inline void MarlinMoESingle(
     // dequantization and matmul operations.
 #pragma unroll
     for (int j = 0; j < 4; j++) {
-      int b_quant = frag_b_quant[k % 2][j];
-      int b_quant_shift = b_quant >> 8;
+      int b_quant_0, b_quant_1;
+      if constexpr (w_type.size_bits() == 4) {
+        b_quant_0 = frag_b_quant[k % 2][0][j];
+        b_quant_1 = b_quant_0 >> 8;
+      } else {
+        static_assert(w_type.size_bits() == 8);
+        int* frag_b_quant_ptr = reinterpret_cast<int*>(frag_b_quant[k % 2]);
+        b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
+        b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
+      }
 
-      FragB frag_b0 = dequant(b_quant);
+      FragB frag_b0 = dequant<w_type_id>(b_quant_0);
+      FragB frag_b1 = dequant<w_type_id>(b_quant_1);
 
       // Apply scale to frag_b0
       if constexpr (has_act_order) {
@@ -855,8 +926,6 @@ __device__ inline void MarlinMoESingle(
         }
       }
 
-      FragB frag_b1 = dequant(b_quant_shift);
-
       // Apply scale to frag_b1
       if constexpr (has_act_order) {
         scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j],
@@ -881,13 +950,13 @@ __device__ inline void MarlinMoESingle(
   // multiple warps that accumulate their partial sums of the same output
   // location; which we have to reduce over in the end. We do in shared memory.
   auto thread_block_reduce = [&]() {
-    constexpr int red_off = threads / b_sh_stride / 2;
+    constexpr int red_off = threads / b_sh_stride_threads / 2;
     if (red_off >= 1) {
-      int red_idx = threadIdx.x / b_sh_stride;
-      constexpr int red_sh_stride = b_sh_stride * 4 * 2;
-      constexpr int red_sh_delta = b_sh_stride;
-      int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) +
-                      (threadIdx.x % b_sh_stride);
+      int red_idx = threadIdx.x / b_sh_stride_threads;
+      constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;
+      constexpr int red_sh_delta = b_sh_stride_threads;
+      int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
+                      (threadIdx.x % b_sh_stride_threads);
 
       // Parallel logarithmic shared memory reduction. We make sure to avoid any
       // unnecessary read or write iterations, e.g., for two warps we write only
@@ -1035,8 +1104,10 @@ __device__ inline void MarlinMoESingle(
   auto write = [&](int idx, float c0, float c1, FragS& s) {
     half2 res = __halves2half2(__float2half(c0), __float2half(c1));
 
-    // For per-column quantization we finally apply the scale here
-    if constexpr (!has_act_order && group_blocks == -1) {
+    // For per-column quantization we finally apply the scale here (only for
+    // 4-bit)
+    if constexpr (!has_act_order && group_blocks == -1 &&
+                  w_type.size_bits() == 4) {
       res = __hmul2(res, s[0]);
     }
 
@@ -1090,7 +1161,7 @@ __device__ inline void MarlinMoESingle(
   auto start_pipes = [&]() {
     // TODO re-enable after fixing this function
     // fetch_sorted_ids_to_shared();
-    __syncthreads();
+    // __syncthreads();
 
 #pragma unroll
     for (int i = 0; i < stages - 1; i++) {
@@ -1166,9 +1237,15 @@ __device__ inline void MarlinMoESingle(
     if (slice_iters == 0) {
       cp_async_wait<0>();
       bool last = slice_idx == slice_count - 1;
-      // For per-column scales, we only fetch them here in the final step before
-      // write-out
       if constexpr (!has_act_order && group_blocks == -1) {
+        if constexpr (w_type.size_bits() == 8) {
+          if (s_sh_wr_pred) {
+            cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
+          }
+          cp_async_fence();
+        } else {
+          // For 4-bit per-column scales, we only fetch them here in the
+          // final step before write-out
           if (last) {
             if (s_sh_wr_pred) {
               cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
@@ -1176,9 +1253,19 @@ __device__ inline void MarlinMoESingle(
             cp_async_fence();
           }
         }
+      }
+
       thread_block_reduce();
       if constexpr (!has_act_order && group_blocks == -1) {
+        if constexpr (w_type.size_bits() == 8) {
+          cp_async_wait<0>();
+          __syncthreads();
+          if (threadIdx.x / 32 < thread_n_blocks / 4) {
+            reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
+            reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
+          }
+
+        } else {
           if (last) {
             cp_async_wait<0>();
             __syncthreads();
@@ -1188,6 +1275,32 @@ __device__ inline void MarlinMoESingle(
           }
         }
       }
+    }
+
+    // For 8-bit channelwise, we apply the scale before the global reduction
+    // that converts the fp32 results to fp16 (so that we avoid possible
+    // overflow in fp16)
+    if constexpr (!has_act_order && group_blocks == -1 &&
+                  w_type.size_bits() == 8) {
+      if (threadIdx.x / 32 < thread_n_blocks / 4) {
+#pragma unroll
+        for (int i = 0; i < thread_m_blocks; i++) {
+#pragma unroll
+          for (int j = 0; j < 4; j++) {
+            scale_float(reinterpret_cast<float*>(&frag_c[i][j][0][0]),
+                        frag_s[j / 2][2 * (j % 2) + 0]);
+            scale_float(reinterpret_cast<float*>(&frag_c[i][j][0][2]),
+                        frag_s[j / 2][2 * (j % 2) + 0]);
+
+            scale_float(reinterpret_cast<float*>(&frag_c[i][j][1][0]),
+                        frag_s[j / 2][2 * (j % 2) + 1]);
+            scale_float(reinterpret_cast<float*>(&frag_c[i][j][1][2]),
+                        frag_s[j / 2][2 * (j % 2) + 1]);
+          }
+        }
+      }
+    }
+
     if (slice_count > 1) {  // only globally reduce if there is more than one
                             // block in a slice
       barrier_acquire(&locks[slice_col], slice_idx);
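The block added above exists because the inter-block (global) reduction exchanges partial results as fp16, whose largest finite value is 65504; with 8-bit weights the unscaled fp32 accumulators can exceed that, so the channelwise scale is folded in while the values are still fp32. A standalone numpy illustration of the failure mode being avoided (not part of this commit; values are illustrative only):

import numpy as np

partial = np.float32(1.2e5)      # plausible unscaled fp32 partial sum with 8-bit-range weights
scale = np.float32(1.0 / 512.0)  # illustrative per-channel dequantization scale

assert np.isinf(np.float16(partial))             # casting the unscaled sum to fp16 overflows
assert np.isfinite(np.float16(partial * scale))  # scaling first stays well inside fp16 range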
@@ -1227,7 +1340,8 @@ __device__ inline void MarlinMoESingle(
   }
 }
 
-template <const int threads,          // number of threads in a threadblock
+template <const vllm::ScalarTypeId w_type_id,  // weight ScalarType id
+          const int threads,          // number of threads in a threadblock
           const int thread_m_blocks,  // number of 16x16 blocks in the m
                                       // dimension (batchsize) of the
                                       // threadblock
@@ -1261,7 +1375,8 @@ __global__ void MarlinMoE(
     bool replicate_input,    // do we use the same input for each expert?
     bool apply_weights,      // apply weights to output
     int current_m_block,     // current m block to start kernel computation from
-    int max_par              // maximum parallelism
+    int max_par,             // maximum parallelism
+    int cfg_max_m_blocks     // upper bound on m blocks
 ) {
   int m_block_ctr = current_m_block;
 
@@ -1282,40 +1397,40 @@ __global__ void MarlinMoE(
     prob_m = tot_its - 16 * m_block_ctr;
 
     int par = 1;
-    if (max_block > 4) {
+    if (max_block > cfg_max_m_blocks) {
       // Note that parallel > 1 currently only works for inputs without any
       // padding
-      par = (16 * max_block - pad) / 64;
-      par = min((16 * max_block - pad) / 64, max_par);
-      prob_m = 64 * par;
-      m_block_ctr += 4 * (par - 1);
-      max_block = 4;
+      par = (16 * max_block - pad) / (16 * cfg_max_m_blocks);
+      if (par > max_par) par = max_par;
+      prob_m = (16 * cfg_max_m_blocks) * par;
+      m_block_ctr += cfg_max_m_blocks * (par - 1);
+      max_block = cfg_max_m_blocks;
     }
 
     if (max_block == 1) {
-      MarlinMoESingle<threads, 1, thread_n_blocks, thread_k_blocks, stages,
-                      has_act_order, group_blocks>(
+      MarlinMoESingle<w_type_id, threads, 1, thread_n_blocks, thread_k_blocks,
+                      stages, has_act_order, group_blocks>(
          A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx,
          expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
          prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
          current_m_block);
    } else if (max_block == 2) {
-      MarlinMoESingle<threads, 2, thread_n_blocks, thread_k_blocks, stages,
-                      has_act_order, group_blocks>(
+      MarlinMoESingle<w_type_id, threads, 2, thread_n_blocks, thread_k_blocks,
+                      stages, has_act_order, group_blocks>(
          A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx,
          expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
          prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
          current_m_block);
    } else if (max_block == 3) {
-      MarlinMoESingle<threads, 3, thread_n_blocks, thread_k_blocks, stages,
-                      has_act_order, group_blocks>(
+      MarlinMoESingle<w_type_id, threads, 3, thread_n_blocks, thread_k_blocks,
+                      stages, has_act_order, group_blocks>(
          A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx,
          expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
          prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
          current_m_block);
    } else {
-      MarlinMoESingle<threads, 4, thread_n_blocks, thread_k_blocks, stages,
-                      has_act_order, group_blocks>(
+      MarlinMoESingle<w_type_id, threads, 4, thread_n_blocks, thread_k_blocks,
+                      stages, has_act_order, group_blocks>(
          A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx,
          expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
          prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
@@ -1342,7 +1457,8 @@ __global__ void compute_expert_offsets(int const* __restrict__ topk_ids,
   return;
 }
 
-template <const int threads,          // number of threads in a threadblock
+template <const vllm::ScalarTypeId w_type_id,  // weight ScalarType id
+          const int threads,          // number of threads in a threadblock
           const int thread_m_blocks,  // number of 16x16 blocks in the m
                                       // dimension (batchsize) of the
                                       // threadblock
@@ -1376,7 +1492,9 @@ __global__ void MarlinMoE(
     bool replicate_input,    // do we use the same input for each expert?
     bool apply_weights,      // apply weights to output
     int current_m_block,     // current m block to start kernel computation from
-    int max_par              // maximum parallelism
+    int max_par,             // maximum parallelism
+    int cfg_max_m_blocks     // upper bound on m blocks
+
 ) {
   // Marlin is not implemented yet for SM < 8.0
   assert(false);
@@ -1397,24 +1515,26 @@ const int STAGES = 4;  // 4 pipeline stages fit into shared memory
 static constexpr int min_thread_n = 64;
 static constexpr int min_thread_k = 64;
 
-#define __CALL_IF_MOE(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS,      \
-                      HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS)               \
-  else if (thread_m_blocks == THREAD_M_BLOCKS &&                              \
+#define __CALL_IF_MOE(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS,               \
+                      THREAD_K_BLOCKS, HAS_ACT_ORDER, GROUP_BLOCKS,           \
+                      NUM_THREADS)                                            \
+  else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS &&          \
           thread_n_blocks == THREAD_N_BLOCKS &&                               \
           thread_k_blocks == THREAD_K_BLOCKS &&                               \
           has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS &&   \
           num_threads == NUM_THREADS) {                                       \
    cudaFuncSetAttribute(                                                      \
-        MarlinMoE<NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS,              \
+        MarlinMoE<W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
                  THREAD_K_BLOCKS, STAGES, HAS_ACT_ORDER, GROUP_BLOCKS>,       \
        cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);          \
-    MarlinMoE<NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
-              STAGES, HAS_ACT_ORDER, GROUP_BLOCKS>                            \
+    MarlinMoE<W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS,     \
+              THREAD_K_BLOCKS, STAGES, HAS_ACT_ORDER, GROUP_BLOCKS>           \
        <<<blocks, NUM_THREADS, max_shared_mem, stream>>>(                     \
            A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr,      \
            g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx,             \
            num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks,           \
-            replicate_input, apply_weights, m_block, max_par);                \
+            replicate_input, apply_weights, m_block, max_par,                 \
+            exec_cfg.max_m_blocks);                                           \
  }
 
 typedef struct {
@@ -1423,6 +1543,11 @@ typedef struct {
   int num_threads;
 } thread_config_t;
 
+typedef struct {
+  int max_m_blocks;
+  thread_config_t tb_cfg;
+} exec_config_t;
+
 thread_config_t small_batch_thread_configs[] = {
     // Ordered by priority
 
@@ -1443,8 +1568,77 @@ thread_config_t large_batch_thread_configs[] = {
     {128, 64, 128},  // Reduce N 4X, increase K 2X
 };
 
-bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n,
-                     int prob_k) {
+int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
+                          int prob_n, int prob_k, int num_bits, int group_size,
+                          bool has_act_order, bool is_k_full) {
+  bool cache_scales_chunk = has_act_order && !is_k_full;
+
+  int tb_n = th_config.thread_n;
+  int tb_k = th_config.thread_k;
+
+  // Get max scale groups per thread-block
+  int tb_groups;
+  if (group_size == -1) {
+    tb_groups = 1;
+  } else if (group_size == 0) {
+    tb_groups = ceildiv(tb_k, 32);  // Worst case is 32 group size
+  } else {
+    tb_groups = ceildiv(tb_k, group_size);
+  }
+
+  if (cache_scales_chunk) {
+    int load_groups =
+        tb_groups * STAGES * 2;          // Chunk size is 2x pipeline over dim K
+    load_groups = max(load_groups, 32);  // We load at least 32 scale groups
+    return load_groups * tb_n * 2;
+
+  } else {
+    int tb_scales = tb_groups * tb_n * 2;
+
+    return tb_scales * STAGES;
+  }
+}
+
+bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,
+                         int prob_m, int prob_n, int prob_k, int num_bits,
+                         int scales_cache_size, int max_shared_mem) {
+  int pack_factor = 32 / num_bits;
+
+  // Get B size
+  int tb_k = th_config.thread_k;
+  int tb_n = th_config.thread_n;
+
+  int b_size = (tb_k * tb_n / pack_factor) * 4;
+
+  // Get A size
+  int m_blocks = ceildiv(prob_m, 16);
+  int tb_max_m = 16;
+
+  while (true) {
+    if (m_blocks >= max_m_blocks) {
+      tb_max_m *= max_m_blocks;
+      break;
+    }
+
+    max_m_blocks--;
+    if (max_m_blocks == 0) {
+      TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks);
+    }
+  }
+
+  int a_size = (tb_max_m * tb_k) * 2;
+
+  float pipe_size = (a_size + b_size) * STAGES;
+
+  TORCH_CHECK(max_shared_mem / 2 > scales_cache_size);  // Sanity
+
+  return pipe_size < 0.95f * (max_shared_mem - scales_cache_size);
+}
+
+bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
+                     int prob_m, int prob_n, int prob_k, int num_bits,
+                     int group_size, bool has_act_order, bool is_k_full,
+                     int max_shared_mem) {
   // Sanity
   if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
       th_config.num_threads == -1) {
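The helpers above feed a new search loop in determine_thread_config (next hunk): max_m_blocks starts at 4 and is decremented until the STAGES-deep pipeline of fp16 A tiles and packed int32 B tiles, plus the scale cache, fits in at most 95% of the opt-in shared memory. A host-side restatement of that budget check (not part of this commit; STAGES assumed to be 4, and the shared-memory and scale-cache numbers below are illustrative):

STAGES = 4  # pipeline depth, matching the kernel constant

def pipeline_fits(tb_k, tb_n, max_m_blocks, num_bits, scales_cache_size, max_shared_mem):
    """Approximate the is_valid_cache_size() shared-memory check on the host."""
    pack_factor = 32 // num_bits
    a_size = (16 * max_m_blocks) * tb_k * 2    # fp16 activations, 2 bytes each
    b_size = (tb_k * tb_n // pack_factor) * 4  # packed int32 weights, 4 bytes per word
    pipe_size = (a_size + b_size) * STAGES
    return pipe_size < 0.95 * (max_shared_mem - scales_cache_size)

# With 8-bit weights the B tile doubles in size, so a config that fits at
# max_m_blocks = 4 for 4-bit may force a smaller max_m_blocks for 8-bit.
assert pipeline_fits(64, 256, 4, 4, 2048, 96 * 1024)
assert not pipeline_fits(64, 256, 4, 8, 2048, 96 * 1024)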
@@ -1472,64 +1666,88 @@ bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n,
     return false;
   }
 
+  // Determine cache for scales
+  int scales_cache_size =
+      get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
+                            group_size, has_act_order, is_k_full);
+
+  // Check that pipeline fits into cache
+  if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k,
+                           num_bits, scales_cache_size, max_shared_mem)) {
+    return false;
+  }
+
   return true;
 }
 
-thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) {
+exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
+                                      int num_bits, int group_size,
+                                      bool has_act_order, bool is_k_full,
+                                      int max_shared_mem) {
+  int max_m_blocks = 4;
+  while (max_m_blocks > 0) {
     if (prob_m <= 16) {
       for (auto th_config : small_batch_thread_configs) {
-        if (is_valid_config(th_config, prob_m, prob_n, prob_k)) {
-          return th_config;
+        if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
+                            num_bits, group_size, has_act_order, is_k_full,
+                            max_shared_mem)) {
+          return exec_config_t{max_m_blocks, th_config};
         }
       }
+
     } else {
       for (auto th_config : large_batch_thread_configs) {
-        if (is_valid_config(th_config, prob_m, prob_n, prob_k)) {
-          return th_config;
+        if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
+                            num_bits, group_size, has_act_order, is_k_full,
+                            max_shared_mem)) {
+          return exec_config_t{max_m_blocks, th_config};
        }
      }
    }
 
-  return thread_config_t{-1, -1, -1};
-}
+    max_m_blocks--;  // Process less M blocks per invocation to reduce cache
+                     // usage
+  }
+
+  return exec_config_t{0, {-1, -1, -1}};
+}
 
-#define CALL_IF_MOE(N_BLOCKS, K_BLOCKS, NUM_THREADS)                    \
-  __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS)            \
-  __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS)            \
-  __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS)            \
-  __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS)            \
+#define CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS)            \
+  __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS)    \
+  __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS)    \
+  __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS)    \
+  __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS)    \
                                                                         \
-  __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS)          \
-  __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS)           \
-  __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS)           \
-  __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)           \
+  __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS)  \
+  __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS)   \
+  __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS)   \
+  __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)   \
                                                                         \
-  __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS)          \
-  __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS)           \
-  __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS)           \
-  __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)           \
+  __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS)  \
+  __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS)   \
+  __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS)   \
+  __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)   \
                                                                         \
-  __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS)          \
-  __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS)           \
-  __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS)           \
-  __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)           \
+  __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS)  \
+  __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS)   \
+  __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS)   \
+  __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)   \
                                                                         \
-  __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS)          \
-  __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS)           \
-  __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS)           \
-  __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)
+  __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS)  \
+  __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS)   \
+  __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS)   \
+  __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)
 
 void marlin_mm_moe_f16i4(const void* A, const void* B, void* C,
                          const void* sorted_ids, const void* topk_weights,
                          const void* topk_ids, const void* s, const void* g_idx,
                          const void* perm, void* a_tmp, void* expert_offsets,
                          int prob_m, int prob_n, int prob_k, void* workspace,
-                         bool has_act_order, bool is_k_full, int num_groups,
-                         int group_size, int num_experts, int topk,
-                         int moe_block_size, int dev, cudaStream_t stream,
-                         int thread_k, int thread_n, int sms, int max_par,
-                         bool replicate_input, bool apply_weights) {
+                         vllm::ScalarType const& q_type, bool has_act_order,
+                         bool is_k_full, int num_groups, int group_size,
+                         int num_experts, int topk, int moe_block_size, int dev,
+                         cudaStream_t stream, int thread_k, int thread_n,
+                         int sms, int max_par, bool replicate_input,
+                         bool apply_weights) {
   TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
               ", ", prob_n, ", ", prob_k, "]");
 
@@ -1537,26 +1755,42 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C,
     cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
   }
 
+  int max_shared_mem = 0;
+  cudaDeviceGetAttribute(&max_shared_mem,
+                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
+  TORCH_CHECK(max_shared_mem > 0);
+
+  int num_bits = q_type.size_bits();
+
   // Set thread config
-  thread_config_t th_config;
+  exec_config_t exec_cfg;
   if (thread_k != -1 && thread_n != -1) {
     // User-defined config
-    th_config = thread_config_t{thread_k, thread_n, USER_THREADS};
+    exec_cfg =
+        exec_config_t{4, thread_config_t{thread_k, thread_n, USER_THREADS}};
   } else {
     // Auto config
-    th_config = determine_thread_config(prob_m, prob_n, prob_k);
+    exec_cfg =
+        determine_thread_config(prob_m, prob_n, prob_k, num_bits, group_size,
+                                has_act_order, is_k_full, max_shared_mem);
   }
 
-  TORCH_CHECK(is_valid_config(th_config, prob_m, prob_n, prob_k),
-              "Invalid thread config: thread_k = " + str(th_config.thread_k) +
-                  ", thread_n = " + str(th_config.thread_n) +
-                  ", num_threads = " + str(th_config.num_threads) +
-                  " for MKN = [" + str(prob_m) + ", " + str(prob_k) + ", " +
-                  str(prob_n) + "]");
+  TORCH_CHECK(exec_cfg.max_m_blocks > 0 &&
+                  is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks,
+                                  prob_m, prob_n, prob_k, num_bits, group_size,
+                                  has_act_order, is_k_full, max_shared_mem),
+              "Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks,
+              ", thread_k = ", exec_cfg.tb_cfg.thread_k,
+              ", thread_n = ", exec_cfg.tb_cfg.thread_n,
+              ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [",
+              prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
+              ", group_size = ", group_size,
+              ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full,
+              ", max_shared_mem = ", max_shared_mem);
 
-  int num_threads = th_config.num_threads;
-  thread_k = th_config.thread_k;
-  thread_n = th_config.thread_n;
+  int num_threads = exec_cfg.tb_cfg.num_threads;
+  thread_k = exec_cfg.tb_cfg.thread_k;
+  thread_n = exec_cfg.tb_cfg.thread_n;
 
   int thread_k_blocks = thread_k / 16;
   int thread_n_blocks = thread_n / 16;
@@ -1590,11 +1824,6 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C,
     }
   }
 
-  int max_shared_mem = 0;
-  cudaDeviceGetAttribute(&max_shared_mem,
-                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
-  TORCH_CHECK(max_shared_mem > 0);
-
   int tot_m = prob_m;
 
   const int* topk_ids_ptr = (const int*)topk_ids;
@@ -1611,10 +1840,13 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C,
     has_act_order = false;
   }
 
+  int pack_factor = 32 / q_type.size_bits();
+
   for (int expert_idx = 0; expert_idx < num_experts; ++expert_idx) {
     const int4* A_ptr = (const int4*)A;
     int4* a_tmp_ptr = (int4*)a_tmp;
-    const int4* B_ptr = (const int4*)B + (prob_n * prob_k / 32) * expert_idx;
+    const int4* B_ptr =
+        (const int4*)B + (prob_n * prob_k / (pack_factor * 4)) * expert_idx;
     int4* C_ptr = (int4*)C;
     const float* topk_weights_ptr = (const float*)topk_weights;
     const int* sorted_ids_ptr = (const int*)sorted_ids;
@@ -1636,19 +1868,22 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C,
       A_ptr = a_tmp_ptr;
     }
 
-    int max_m_blocks = ceildiv(tot_m, 16);
-    for (int m_block = 0; m_block < max_m_blocks; m_block += 16) {
-      // Define kernel configurations
+    int tot_m_blocks = ceildiv(tot_m, 16);
+    for (int m_block = 0; m_block < tot_m_blocks;
+         m_block += 4 * exec_cfg.max_m_blocks) {
+
       // make it max possible value
-      int thread_m_blocks = 4;
+      int thread_m_blocks = exec_cfg.max_m_blocks;
 
       if (false) {
       }
-      CALL_IF_MOE(16, 4, 256)
-      CALL_IF_MOE(8, 8, 256)
-      CALL_IF_MOE(8, 4, 128)
-      CALL_IF_MOE(4, 8, 128)
+      CALL_IF_MOE(vllm::kU4B8, 16, 4, 256)
+      CALL_IF_MOE(vllm::kU4B8, 8, 8, 256)
+      CALL_IF_MOE(vllm::kU4B8, 8, 4, 128)
+      CALL_IF_MOE(vllm::kU4B8, 4, 8, 128)
+      CALL_IF_MOE(vllm::kU8B128, 16, 4, 256)
+      CALL_IF_MOE(vllm::kU8B128, 8, 8, 256)
+      CALL_IF_MOE(vllm::kU8B128, 8, 4, 128)
+      CALL_IF_MOE(vllm::kU8B128, 4, 8, 128)
       else {
         TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " +
                                str(prob_n) + ", " + str(prob_k) + "]" +
@@ -1670,9 +1905,15 @@ torch::Tensor marlin_gemm_moe(
     const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights,
     const torch::Tensor& topk_ids, const torch::Tensor& b_scales,
     const torch::Tensor& g_idx, const torch::Tensor& perm,
-    torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k,
-    bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size,
+    torch::Tensor& workspace, vllm::ScalarTypeTorchPtr const& b_q_type,
+    int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full,
+    int64_t num_experts, int64_t topk, int64_t moe_block_size,
     bool replicate_input, bool apply_weights) {
+  TORCH_CHECK(*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128,
+              "b_q_type must be uint4b8 or uint8b128. Got = ", b_q_type->str());
+
+  int pack_factor = 32 / b_q_type->size_bits();
+
   int max_par = 4;
 
   int dev = a.get_device();
@@ -1733,8 +1974,8 @@ torch::Tensor marlin_gemm_moe(
       topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(),
       g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(),
       expert_offsets.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(),
-      has_act_order, is_k_full, num_groups, group_size, num_experts, topk,
-      moe_block_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k,
+      *b_q_type, has_act_order, is_k_full, num_groups, group_size, num_experts,
+      topk, moe_block_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k,
       thread_n, sms, max_par, replicate_input, apply_weights);
   return c;
 }
@@ -2,11 +2,14 @@
 
 #include <torch/all.h>
 
+#include "core/scalar_type.hpp"
+
 torch::Tensor marlin_gemm_moe(
     const torch::Tensor& a, const torch::Tensor& b_q_weights,
     const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights,
     const torch::Tensor& topk_ids, const torch::Tensor& b_scales,
     const torch::Tensor& g_idx, const torch::Tensor& perm,
-    torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k,
-    bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size,
+    torch::Tensor& workspace, vllm::ScalarTypeTorchPtr const& b_q_type,
+    int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full,
+    int64_t num_experts, int64_t topk, int64_t moe_block_size,
     bool replicate_input, bool apply_weights);
@@ -13,9 +13,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
   m.def(
       "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
       "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
-      "g_idx, Tensor! perm, Tensor! workspace, int size_m, int size_n, int "
-      "size_k, bool is_k_full, int num_experts, int topk, int moe_block_size, "
-      "bool replicate_input, bool apply_weights) -> Tensor");
+      "g_idx, Tensor! perm, Tensor! workspace, "
+      "__torch__.torch.classes._core_C.ScalarType b_q_type, int size_m, "
+      "int size_n, int size_k, bool is_k_full, int num_experts, int topk, "
+      "int moe_block_size, bool replicate_input, bool apply_weights)"
+      " -> Tensor");
   m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe);
 #endif
 }
@@ -140,6 +140,7 @@ def compute_max_diff(output, output_ref):
 @pytest.mark.parametrize("topk", [2, 6])
 @pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
 @pytest.mark.parametrize("act_order", [True, False])
+@pytest.mark.parametrize("num_bits", [4, 8])
 def test_fused_marlin_moe(
     m: int,
     n: int,
@@ -148,6 +149,7 @@ def test_fused_marlin_moe(
     topk: int,
     group_size: int,
     act_order: bool,
+    num_bits: int,
 ):
     torch.manual_seed(7)
 
@@ -161,13 +163,12 @@ def test_fused_marlin_moe(
     if group_size in (k, n):
         return
 
-    quant_type = scalar_types.uint4b8
+    quant_type = (scalar_types.uint4b8
+                  if num_bits == 4 else scalar_types.uint8b128)
     dtype = torch.float16
     a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
     w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
     w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
-    for i in range(w2.shape[0]):
-        w2[0] = torch.eye(k, n, device="cuda", dtype=dtype)
 
     w_ref1_l = []
     qweight1_l = []
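scalar_types.uint4b8 and scalar_types.uint8b128 selected above are biased unsigned formats: the stored code is unsigned and the represented weight is the code minus a bias of 8 (4-bit) or 128 (8-bit), giving a symmetric signed range in both cases. A small standalone sketch of that mapping (not part of the test file):

def biased_unsigned_values(num_bits):
    """Values representable by the uint4b8 / uint8b128 formats: code - bias."""
    bias = 1 << (num_bits - 1)  # 8 for 4-bit, 128 for 8-bit
    return [code - bias for code in range(2 ** num_bits)]

assert biased_unsigned_values(4) == list(range(-8, 8))      # uint4b8
assert biased_unsigned_values(8) == list(range(-128, 128))  # uint8b128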
@@ -240,6 +241,7 @@ def test_fused_marlin_moe(
         topk_ids,
         w1_scale=scales1,
         w2_scale=scales2,
+        num_bits=num_bits,
     )
 
     assert compute_max_diff(marlin_output, triton_output) < 4e-2
@@ -254,7 +256,8 @@ def test_fused_marlin_moe(
 @pytest.mark.parametrize("topk", [2, 6])
 @pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
 @pytest.mark.parametrize("act_order", [True, False])
-def test_marlin_moe_mmm(
+@pytest.mark.parametrize("num_bits", [4, 8])
+def test_single_marlin_moe_multiply(
     m: int,
     n: int,
     k: int,
@@ -262,6 +265,7 @@ def test_marlin_moe_mmm(
     topk: int,
     group_size: int,
     act_order: bool,
+    num_bits: int,
 ):
     if topk > e:
         return
@@ -273,7 +277,8 @@ def test_marlin_moe_mmm(
     if group_size == k:
         return
 
-    quant_type = scalar_types.uint4b8
+    quant_type = (scalar_types.uint4b8
+                  if num_bits == 4 else scalar_types.uint8b128)
     dtype = torch.float16
     a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
     w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10
@@ -308,7 +313,8 @@ def test_marlin_moe_mmm(
         g_idx,
         sort_indices,
         topk,
-        renormalize=False)
+        renormalize=False,
+        num_bits=num_bits)
     torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)
 
     assert compute_max_diff(marlin_output, torch_output) < 1e-2
@@ -1,3 +1,4 @@
 compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
 compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
+compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W8A16-quantized, main
 gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main
tests/weight_loading/run_model_weight_loading_test.sh: Normal file → Executable file (mode change only, 0 lines changed)
@@ -559,7 +559,7 @@ def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
                            num_bits: int) -> torch.Tensor:
     num_experts = b_q_weight.shape[0]
     assert size_k % 16 == 0
-    output = torch.empty((num_experts, size_k // 16, size_n * 2),
+    output = torch.empty((num_experts, size_k // 16, size_n * (num_bits // 2)),
                          device=b_q_weight.device,
                          dtype=b_q_weight.dtype)
     for e in range(num_experts):
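The wider last dimension follows from int32 packing: each expert stores size_k * size_n quantized values with 32 // num_bits of them per int32 word, so the repacked tensor has size_k // 16 tile rows and size_n * (num_bits // 2) columns, twice as wide for 8-bit as for 4-bit. A quick standalone check of that arithmetic (not part of this commit):

def repacked_shape(size_k, size_n, num_bits):
    """Expected (rows, cols) of one expert's Marlin-repacked int32 weight."""
    total_int32 = size_k * size_n * num_bits // 32  # 32 // num_bits values per word
    rows = size_k // 16                             # one row per 16-row tile
    cols = total_int32 // rows
    assert cols == size_n * (num_bits // 2)
    return rows, cols

assert repacked_shape(4096, 14336, 4) == (256, 28672)
assert repacked_shape(4096, 14336, 8) == (256, 57344)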
@@ -7,6 +7,7 @@ import torch
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.fused_moe.fused_moe import (
     fused_topk, moe_align_block_size, try_get_optimal_moe_config)
+from vllm.scalar_type import scalar_types
 
 
 def single_marlin_moe(
@@ -18,7 +19,9 @@ def single_marlin_moe(
     perm: torch.Tensor,
     topk: int,
     renormalize: bool,
-    override_config: Optional[Dict[str, Any]] = None) -> torch.Tensor:
+    override_config: Optional[Dict[str, Any]] = None,
+    num_bits: int = 8,
+) -> torch.Tensor:
     """
     This function computes the multiplication of hidden_states with expert
     weights used in Marlin MoE, using weights w and top-k gating mechanism.
@@ -36,6 +39,7 @@ def single_marlin_moe(
     - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
     - override_config (Optional[Dict[str, Any]]): Optional override
         for the kernel configuration.
+    - num_bits (bool): The number of bits in expert weights quantization.
 
     Returns:
     - torch.Tensor: The output tensor after applying the MoE layer.
@@ -48,10 +52,11 @@ def single_marlin_moe(
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w.is_contiguous(), "Expert weights must be contiguous"
     assert hidden_states.dtype == torch.float16
+    assert num_bits in [4, 8]

     M, K = hidden_states.shape
     E = w.shape[0]
-    N = w.shape[2] // 2
+    N = w.shape[2] // (num_bits // 2)

     topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
                                         renormalize)
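The division undoes that same packing on the consumer side: w.shape[2] equals N * (num_bits // 2), so the logical output width is recovered by dividing the factor back out. A small standalone sketch (the function and tensor names are invented for the example):

import torch

def moe_problem_shape(hidden_states: torch.Tensor, w_packed: torch.Tensor,
                      num_bits: int):
    # (M, K) activations and (E, K // 16, N * num_bits // 2) packed weights.
    assert num_bits in (4, 8)
    m, k = hidden_states.shape
    e = w_packed.shape[0]
    n = w_packed.shape[2] // (num_bits // 2)   # undo the int32 packing along N
    return m, k, e, n

hs = torch.zeros((16, 256), dtype=torch.float16)
w8 = torch.zeros((4, 256 // 16, 512 * 4), dtype=torch.int32)   # 8-bit, N = 512
assert moe_problem_shape(hs, w8, num_bits=8) == (16, 256, 4, 512)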
@@ -76,10 +81,13 @@ def single_marlin_moe(
                             device="cuda",
                             requires_grad=False)

+    scalar_type = (scalar_types.uint4b8
+                   if num_bits == 4 else scalar_types.uint8b128)
+
     intermediate_cache = torch.ops._moe_C.marlin_gemm_moe(
         hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales,
-        g_idx, perm, workspace, M, N, K, True, E, topk, block_size_m, True,
-        False)
+        g_idx, perm, workspace, scalar_type, M, N, K, True, E, topk,
+        block_size_m, True, False)

     return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)

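The kernel now receives the weight element type explicitly instead of assuming 4-bit: uint4b8 is unsigned 4-bit with a bias (zero point) of 8, and uint8b128 is unsigned 8-bit with a bias of 128, matching the kU4B8 and kU8B128 dequant paths on the CUDA side. A minimal sketch of the dispatch, assuming only these two widths are accepted:

from vllm.scalar_type import scalar_types

def marlin_moe_scalar_type(num_bits: int):
    # GPTQ-style storage keeps unsigned values with a zero point of
    # 2 ** (num_bits - 1), which is exactly what these two types encode.
    if num_bits == 4:
        return scalar_types.uint4b8
    if num_bits == 8:
        return scalar_types.uint8b128
    raise ValueError(f"Unsupported Marlin MoE weight width: {num_bits}")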
@@ -98,6 +106,7 @@ def fused_marlin_moe(
     override_config: Optional[Dict[str, Any]] = None,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
+    num_bits: int = 8,
 ) -> torch.Tensor:
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -122,6 +131,7 @@ def fused_marlin_moe(
         w1.
     - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
         w2.
+    - num_bits (int): The number of bits in expert weights quantization.

     Returns:
     - torch.Tensor: The output tensor after applying the MoE layer.
@@ -131,13 +141,14 @@ def fused_marlin_moe(
         0], "Number of tokens mismatch"
     assert hidden_states.shape[
         1] == w1.shape[1] * 16, "Hidden size mismatch w1"
-    assert hidden_states.shape[
-        1] == w2.shape[2] // 2, "Hidden size mismatch w2"
+    assert hidden_states.shape[1] == w2.shape[2] // (
+        num_bits // 2), "Hidden size mismatch w2"
     assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.is_contiguous(), "Expert weights1 must be contiguous"
     assert w2.is_contiguous(), "Expert weights2 must be contiguous"
     assert hidden_states.dtype == torch.float16
+    assert num_bits in [4, 8]

     M, K = hidden_states.shape
     E = w1.shape[0]
@@ -165,6 +176,9 @@ def fused_marlin_moe(
                             device="cuda",
                             requires_grad=False)

+    scalar_type = (scalar_types.uint4b8
+                   if num_bits == 4 else scalar_types.uint8b128)
+
     intermediate_cache2 = torch.empty(
         (M * topk_ids.shape[1], N),
         device=hidden_states.device,
@@ -181,6 +195,7 @@ def fused_marlin_moe(
         g_idx1,
         perm1,
         workspace,
+        scalar_type,
         M,
         2 * N,
         K,
@@ -204,6 +219,7 @@ def fused_marlin_moe(
         g_idx2,
         perm2,
         workspace,
+        scalar_type,
         M,
         K,
         N,
@@ -445,7 +445,7 @@ def grouped_topk(hidden_states: torch.Tensor,
     if renormalize:
         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

-    return topk_weights, topk_ids.to(torch.int32)
+    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)


 def get_config_dtype_str(dtype: torch.dtype,
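Casting the routing weights to float32 makes grouped_topk match fused_topk, which already returns fp32 weights, so downstream MoE kernels can rely on an fp32 weight tensor regardless of the gating dtype (this reading is inferred from the surrounding code, not stated in the patch). A tiny standalone illustration of the convention:

import torch

gating = torch.randn(4, 8)                                   # gating logits
topk_weights, topk_ids = torch.topk(torch.softmax(gating, dim=-1), k=2, dim=-1)
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)  # renormalize

# Hand fp32 weights and int32 expert ids to the MoE kernels, as above.
topk_weights = topk_weights.to(torch.float32)
topk_ids = topk_ids.to(torch.int32)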
@@ -6,6 +6,8 @@ import torch

 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
+from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
+    WNA16_SUPPORTED_BITS)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     CompressionFormat)
 from vllm.model_executor.utils import set_weight_attrs
@@ -38,10 +40,11 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):

         if not (self.quant_config.quant_format
                 == CompressionFormat.pack_quantized.value
-                and self.num_bits == 4):
+                and self.num_bits in WNA16_SUPPORTED_BITS):
             raise ValueError("For Fused MoE layers, only ",
                              f"{CompressionFormat.pack_quantized.value} ",
-                             "is supported for 4 bits")
+                             "is supported for the following bits: ",
+                             f"{WNA16_SUPPORTED_BITS}")

     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size: int,
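The guard now accepts any width listed in WNA16_SUPPORTED_BITS rather than hard-coding 4, so pack-quantized checkpoints with 8-bit weights can take the fused path as well. A hedged sketch of the check in isolation (the literal tuple and the "pack-quantized" string are stand-ins for the imported constants):

SUPPORTED_MOE_WEIGHT_BITS = (4, 8)        # stand-in for WNA16_SUPPORTED_BITS

def validate_moe_quant(quant_format: str, num_bits: int) -> None:
    # Only pack-quantized checkpoints at a supported width may use fused MoE.
    if not (quant_format == "pack-quantized"
            and num_bits in SUPPORTED_MOE_WEIGHT_BITS):
        raise ValueError(
            "For Fused MoE layers, only pack-quantized is supported "
            f"for the following bits: {SUPPORTED_MOE_WEIGHT_BITS}")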
@@ -292,4 +295,5 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
             topk_ids,
             w1_scale=layer.w13_weight_scale,
             w2_scale=layer.w2_weight_scale,
+            num_bits=self.num_bits,
         )
@@ -611,4 +611,5 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
             topk_ids,
             w1_scale=layer.w13_scales,
             w2_scale=layer.w2_scales,
+            num_bits=self.quant_config.quant_type.size_bits,
         ).to(orig_dtype)
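Both back-ends forward a width they already track: the compressed-tensors method passes its own num_bits, and the GPTQ-Marlin method passes the size_bits of its quantization type. A small sketch of that relationship, assuming vllm's scalar types are available:

from vllm.scalar_type import scalar_types

# size_bits recovers the integer width the checkpoint was quantized to,
# which is what fused_marlin_moe's num_bits argument expects.
assert scalar_types.uint4b8.size_bits == 4
assert scalar_types.uint8b128.size_bits == 8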
@@ -23,13 +23,7 @@ def get_model_architecture(
     architectures = getattr(model_config.hf_config, "architectures", [])
     # Special handling for quantized Mixtral.
     # FIXME(woosuk): This is a temporary hack.
-    mixtral_supported = ["fp8", "compressed-tensors"]
-    # for gptq_marlin, only run fused MoE for int4
-    if model_config.quantization == "gptq_marlin":
-        hf_quant_config = getattr(model_config.hf_config,
-                                  "quantization_config", None)
-        if hf_quant_config and hf_quant_config.get("bits") == 4:
-            mixtral_supported.append("gptq_marlin")
+    mixtral_supported = ["fp8", "compressed-tensors", "gptq_marlin"]

     if (model_config.quantization is not None
             and model_config.quantization not in mixtral_supported
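With 8-bit weights handled by the fused path, the int4-only carve-out for gptq_marlin is no longer needed; membership in the supported list is enough. A standalone sketch of the resulting decision (the helper name is invented; the list contents come from the hunk above):

from typing import Optional

mixtral_supported = ["fp8", "compressed-tensors", "gptq_marlin"]

def uses_quant_mixtral_fallback(quantization: Optional[str]) -> bool:
    # True when the configured method is not handled by the fused MoE path,
    # i.e. the loader should pick the non-fused quantized Mixtral class.
    return quantization is not None and quantization not in mixtral_supported

assert not uses_quant_mixtral_fallback("gptq_marlin")
assert uses_quant_mixtral_fallback("awq")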