From f7a8fa39d828136d0b8bbdd20c262602e5543ffd Mon Sep 17 00:00:00 2001
From: Avshalom Manevich <12231371+avshalomman@users.noreply.github.com>
Date: Mon, 8 Jul 2024 18:00:38 +0300
Subject: [PATCH] [Kernel] reloading fused_moe config on the last chunk (#6210)

---
 .../layers/fused_moe/fused_moe.py | 51 +++++++++++++------
 1 file changed, 36 insertions(+), 15 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 99a5c7d7..a29622b7 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -332,6 +332,31 @@ def get_default_config(
     return config
 
 
+def try_get_optimal_moe_config(
+    w1_shape: Tuple[int, ...],
+    w2_shape: Tuple[int, ...],
+    top_k: int,
+    dtype: Optional[str],
+    M: int,
+    override_config: Optional[Dict[str, Any]] = None,
+):
+    if override_config:
+        config = override_config
+    else:
+        # First try to load optimal config from the file
+        E, _, N = w2_shape
+        configs = get_moe_configs(E, N, dtype)
+
+        if configs:
+            # If an optimal configuration map has been found, look up the
+            # optimal config
+            config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
+        else:
+            # Else use the default config
+            config = get_default_config(M, E, N, w1_shape[2], top_k, dtype)
+    return config
+
+
 def fused_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
@@ -428,22 +453,16 @@ def fused_experts(hidden_states: torch.Tensor,
     CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
     M = min(num_tokens, CHUNK_SIZE)
 
-    if override_config:
-        config = override_config
-    else:
-        # First try to load optimal config from the file
-        configs = get_moe_configs(E, w2.shape[2],
-                                  "float8" if use_fp8 else None)
+    get_config_func = functools.partial(
+        try_get_optimal_moe_config,
+        w1.shape,
+        w2.shape,
+        topk_ids.shape[1],
+        "float8" if use_fp8 else None,
+        override_config=override_config,
+    )
 
-        if configs:
-            # If an optimal configuration map has been found, look up the
-            # optimal config
-            config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
-        else:
-            # Else use the default config
-            config = get_default_config(M, E, N, w1.shape[2],
-                                        topk_ids.shape[1],
-                                        "float8" if use_fp8 else None)
+    config = get_config_func(M)
 
     intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
                                       device=hidden_states.device,
@@ -478,6 +497,8 @@
             intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
             intermediate_cache2 = intermediate_cache2[:tokens_in_chunk]
             intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
+            # reload config to get better performance on the last chunk
+            config = get_config_func(tokens_in_chunk)
 
             curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
             curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]