From 78b6c4845ac9aa57ccf7e42cf4c7d3c4cdef14cf Mon Sep 17 00:00:00 2001
From: akhoroshev
Date: Fri, 15 Mar 2024 04:18:07 +0300
Subject: [PATCH] Dynamically configure shared memory size for
 moe_align_block_size_kernel (#3376)

---
 csrc/moe_align_block_size_kernels.cu | 42 +++++++++++++++++++---------
 1 file changed, 29 insertions(+), 13 deletions(-)

diff --git a/csrc/moe_align_block_size_kernels.cu b/csrc/moe_align_block_size_kernels.cu
index de6a0ec0..138615a4 100644
--- a/csrc/moe_align_block_size_kernels.cu
+++ b/csrc/moe_align_block_size_kernels.cu
@@ -7,10 +7,17 @@
 #include "cuda_compat.h"
 #include "dispatch_utils.h"
 
-const static size_t NUM_MAX_EXPERTS = 64;
 #define CEILDIV(x,y) (((x) + (y) - 1) / (y))
 
 namespace vllm {
+
+namespace {
+__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, int32_t col) {
+  // don't worry about overflow because num_experts is relatively small
+  return row * total_col + col;
+}
+}
+
 template <typename scalar_t>
 __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
                                             int32_t *sorted_token_ids,
@@ -21,10 +28,14 @@ __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
                                             size_t numel) {
   const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
   const size_t start_idx = threadIdx.x * tokens_per_thread;
-  __shared__ int32_t tokens_cnts[NUM_MAX_EXPERTS + 1][NUM_MAX_EXPERTS];
-  __shared__ int32_t cumsum[NUM_MAX_EXPERTS + 1];
+
+  extern __shared__ int32_t shared_mem[];
+
+  int32_t* tokens_cnts = shared_mem; // 2d tensor with shape (num_experts + 1, num_experts)
+  int32_t* cumsum = shared_mem + (num_experts + 1) * num_experts; // 1d tensor with shape (num_experts + 1)
+
   for (int i = 0; i < num_experts; ++i) {
-    tokens_cnts[threadIdx.x + 1][i] = 0;
+    tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
   }
 
   /**
@@ -33,15 +44,15 @@ __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
    * to expert expert_index.
    */
   for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
-    ++tokens_cnts[threadIdx.x + 1][topk_ids[i]];
+    ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])];
   }
 
   __syncthreads();
 
   // For each expert we accumulate the token counts from the different threads.
-  tokens_cnts[0][threadIdx.x] = 0;
+  tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
   for (int i = 1; i <= blockDim.x; ++i) {
-    tokens_cnts[i][threadIdx.x] += tokens_cnts[i-1][threadIdx.x];
+    tokens_cnts[index(num_experts, i, threadIdx.x)] += tokens_cnts[index(num_experts, i-1, threadIdx.x)];
   }
 
   __syncthreads();
@@ -50,7 +61,7 @@ __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
   if (threadIdx.x == 0) {
     cumsum[0] = 0;
     for (int i = 1; i <= num_experts; ++i) {
-      cumsum[i] = cumsum[i-1] + CEILDIV(tokens_cnts[blockDim.x][i - 1], block_size) * block_size;
+      cumsum[i] = cumsum[i-1] + CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], block_size) * block_size;
     }
     *total_tokens_post_pad = cumsum[num_experts];
   }
@@ -78,9 +89,9 @@ __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
     * stores the indices of the tokens processed by the expert with expert_id within
     * the current thread's token shard.
     */
-    int32_t rank_post_pad = tokens_cnts[threadIdx.x][expert_id] + cumsum[expert_id];
+    int32_t rank_post_pad = tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + cumsum[expert_id];
     sorted_token_ids[rank_post_pad] = i;
-    ++tokens_cnts[threadIdx.x][expert_id];
+    ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
   }
 }
 }
@@ -93,11 +104,16 @@ void moe_align_block_size(
   torch::Tensor experts_ids,
   torch::Tensor num_tokens_post_pad) {
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  assert(num_experts <= NUM_MAX_EXPERTS);
   VLLM_DISPATCH_INTEGRAL_TYPES(
     topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
-      vllm::moe_align_block_size_kernel<scalar_t><<<1, num_experts, 0, stream>>>(
-        topk_ids.data_ptr<scalar_t>(),
+      // calc needed amount of shared mem for `tokens_cnts` and `cumsum` tensors
+      const int32_t shared_mem = ((num_experts + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t);
+
+      // set dynamic shared mem
+      auto kernel = vllm::moe_align_block_size_kernel<scalar_t>;
+      AT_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem));
+      kernel<<<1, num_experts, shared_mem, stream>>>(
+        topk_ids.data_ptr<scalar_t>(),
         sorted_token_ids.data_ptr<int32_t>(),
         experts_ids.data_ptr<int32_t>(),
         num_tokens_post_pad.data_ptr<int32_t>(),
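The pattern the patch adopts generalizes beyond this kernel, so a self-contained toy program sketching it is shown below. The kernel name histogram_cumsum and all host-side names are invented for illustration; only the extern __shared__ partitioning and the cudaFuncSetAttribute-before-launch sequence mirror the patch. One dynamic shared memory allocation is carved into a (blockDim.x + 1) x num_bins counter table plus a (num_bins + 1) cumulative-sum array; the host computes the exact byte count, raises the kernel's dynamic shared memory cap, and passes the size as the third launch parameter.

// build with: nvcc -o smem_demo smem_demo.cu  (hypothetical standalone sketch)
#include <cstdint>
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

// Toy kernel mirroring the patched layout: one dynamic shared memory block
// carved into a (blockDim.x + 1) x num_bins counter table followed by a
// (num_bins + 1) cumulative-sum array. Launched with num_bins threads, so
// blockDim.x == num_bins, as in the patched kernel.
__global__ void histogram_cumsum(const int32_t* __restrict__ bins,
                                 int32_t* __restrict__ out_cumsum,
                                 int num_bins, int numel) {
  extern __shared__ int32_t shared_mem[];
  int32_t* counts = shared_mem;                                // (blockDim.x + 1) * num_bins
  int32_t* cumsum = shared_mem + (blockDim.x + 1) * num_bins;  // num_bins + 1

  // Each thread zeroes and then fills its own row (threadIdx.x + 1).
  for (int b = 0; b < num_bins; ++b)
    counts[(threadIdx.x + 1) * num_bins + b] = 0;
  for (int i = threadIdx.x; i < numel; i += blockDim.x)
    ++counts[(threadIdx.x + 1) * num_bins + bins[i]];
  __syncthreads();

  // Column-wise prefix sum across thread rows, one column per thread.
  counts[threadIdx.x] = 0;  // row 0 of this thread's column
  for (int r = 1; r <= blockDim.x; ++r)
    counts[r * num_bins + threadIdx.x] += counts[(r - 1) * num_bins + threadIdx.x];
  __syncthreads();

  // Thread 0 turns the per-bin totals (last row) into a running sum.
  if (threadIdx.x == 0) {
    cumsum[0] = 0;
    for (int b = 1; b <= num_bins; ++b)
      cumsum[b] = cumsum[b - 1] + counts[blockDim.x * num_bins + (b - 1)];
  }
  __syncthreads();
  out_cumsum[threadIdx.x] = cumsum[threadIdx.x + 1];
}

int main() {
  const int num_bins = 64, numel = 1 << 16;
  std::vector<int32_t> h(numel);
  for (int i = 0; i < numel; ++i) h[i] = i % num_bins;

  int32_t *d_bins, *d_out;
  cudaMalloc(&d_bins, numel * sizeof(int32_t));
  cudaMalloc(&d_out, num_bins * sizeof(int32_t));
  cudaMemcpy(d_bins, h.data(), numel * sizeof(int32_t), cudaMemcpyHostToDevice);

  // Size the dynamic shared memory exactly as the patch does.
  const int shared_mem =
      ((num_bins + 1) * num_bins + (num_bins + 1)) * sizeof(int32_t);
  // Opt in to a larger-than-default dynamic shared memory cap before launch
  // (the patch wraps this call in AT_CUDA_CHECK; plain error handling here).
  cudaFuncSetAttribute(histogram_cumsum,
                       cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem);
  histogram_cumsum<<<1, num_bins, shared_mem>>>(d_bins, d_out, num_bins, numel);

  std::vector<int32_t> out(num_bins);
  cudaMemcpy(out.data(), d_out, num_bins * sizeof(int32_t), cudaMemcpyDeviceToHost);
  printf("cumsum over all bins = %d (expect %d)\n", out[num_bins - 1], numel);
  cudaFree(d_bins);
  cudaFree(d_out);
  return 0;
}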
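A note on why the cudaFuncSetAttribute call accompanies the dynamic sizing: statically declared __shared__ arrays, like the old tokens_cnts and cumsum, are limited to 48 KB per block, and the old code therefore capped experts at 64 via the assert while always occupying (65 * 64 + 65) * 4 = 16,900 bytes. With the dynamic layout, 128 experts already need (129 * 128 + 129) * 4 = 66,564 bytes, which exceeds the 48 KB of dynamic shared memory a kernel may receive by default; setting cudaFuncAttributeMaxDynamicSharedMemorySize opts the kernel in to larger allocations, up to the per-architecture limit reported by cudaDevAttrMaxSharedMemoryPerBlockOptin.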