[Kernel] reloading fused_moe config on the last chunk (#6210)

This commit is contained in:
Avshalom Manevich 2024-07-08 18:00:38 +03:00 committed by GitHub
parent 717f4bcea0
commit f7a8fa39d8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -332,6 +332,31 @@ def get_default_config(
return config
def try_get_optimal_moe_config(
w1_shape: Tuple[int, ...],
w2_shape: Tuple[int, ...],
top_k: int,
dtype: Optional[str],
M: int,
override_config: Optional[Dict[str, Any]] = None,
):
if override_config:
config = override_config
else:
# First try to load optimal config from the file
E, _, N = w2_shape
configs = get_moe_configs(E, N, dtype)
if configs:
# If an optimal configuration map has been found, look up the
# optimal config
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
else:
# Else use the default config
config = get_default_config(M, E, N, w1_shape[2], top_k, dtype)
return config
def fused_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
@ -428,22 +453,16 @@ def fused_experts(hidden_states: torch.Tensor,
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
M = min(num_tokens, CHUNK_SIZE)
if override_config:
config = override_config
else:
# First try to load optimal config from the file
configs = get_moe_configs(E, w2.shape[2],
"float8" if use_fp8 else None)
get_config_func = functools.partial(
try_get_optimal_moe_config,
w1.shape,
w2.shape,
topk_ids.shape[1],
"float8" if use_fp8 else None,
override_config=override_config,
)
if configs:
# If an optimal configuration map has been found, look up the
# optimal config
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
else:
# Else use the default config
config = get_default_config(M, E, N, w1.shape[2],
topk_ids.shape[1],
"float8" if use_fp8 else None)
config = get_config_func(M)
intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
device=hidden_states.device,
@ -478,6 +497,8 @@ def fused_experts(hidden_states: torch.Tensor,
intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
intermediate_cache2 = intermediate_cache2[:tokens_in_chunk]
intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
# reload config to get better performance on the last chunk
config = get_config_func(tokens_in_chunk)
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]