[Kernel] reloading fused_moe config on the last chunk (#6210)
This commit is contained in:
parent
717f4bcea0
commit
f7a8fa39d8
@ -332,6 +332,31 @@ def get_default_config(
|
|||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def try_get_optimal_moe_config(
    w1_shape: Tuple[int, ...],
    w2_shape: Tuple[int, ...],
    top_k: int,
    dtype: Optional[str],
    M: int,
    override_config: Optional[Dict[str, Any]] = None,
):
    """Return the kernel launch config for the fused MoE kernel.

    Resolution order: an explicit ``override_config`` always wins; otherwise
    the tuned-config file for this expert layout is consulted, and if no
    tuned file exists a heuristic default is used.

    Args:
        w1_shape: Shape of the first expert weight tensor; only the last
            dimension is read here.
        w2_shape: Shape of the second expert weight tensor, unpacked as
            (num_experts, _, N).
        top_k: Number of experts routed per token.
        dtype: Optional dtype tag used to select the tuned-config file
            (e.g. "float8"), or None.
        M: Number of tokens; used to pick the closest tuned entry.
        override_config: If truthy, returned as-is without any lookup.
    """
    if override_config:
        return override_config

    # No override given: first try the tuned-config file for this layout.
    num_experts, _, inter_size = w2_shape
    tuned_configs = get_moe_configs(num_experts, inter_size, dtype)

    if tuned_configs:
        # A tuned map exists — use the entry whose token count is
        # closest to M.
        closest_m = min(tuned_configs.keys(),
                        key=lambda candidate: abs(candidate - M))
        return tuned_configs[closest_m]

    # No tuned file available: fall back to the heuristic default.
    return get_default_config(M, num_experts, inter_size, w1_shape[2],
                              top_k, dtype)
|
||||||
|
|
||||||
|
|
||||||
def fused_topk(
|
def fused_topk(
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
gating_output: torch.Tensor,
|
gating_output: torch.Tensor,
|
||||||
@ -428,22 +453,16 @@ def fused_experts(hidden_states: torch.Tensor,
|
|||||||
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
|
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
|
||||||
M = min(num_tokens, CHUNK_SIZE)
|
M = min(num_tokens, CHUNK_SIZE)
|
||||||
|
|
||||||
if override_config:
|
get_config_func = functools.partial(
|
||||||
config = override_config
|
try_get_optimal_moe_config,
|
||||||
else:
|
w1.shape,
|
||||||
# First try to load optimal config from the file
|
w2.shape,
|
||||||
configs = get_moe_configs(E, w2.shape[2],
|
topk_ids.shape[1],
|
||||||
"float8" if use_fp8 else None)
|
"float8" if use_fp8 else None,
|
||||||
|
override_config=override_config,
|
||||||
|
)
|
||||||
|
|
||||||
if configs:
|
config = get_config_func(M)
|
||||||
# If an optimal configuration map has been found, look up the
|
|
||||||
# optimal config
|
|
||||||
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
|
|
||||||
else:
|
|
||||||
# Else use the default config
|
|
||||||
config = get_default_config(M, E, N, w1.shape[2],
|
|
||||||
topk_ids.shape[1],
|
|
||||||
"float8" if use_fp8 else None)
|
|
||||||
|
|
||||||
intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
|
intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
|
||||||
device=hidden_states.device,
|
device=hidden_states.device,
|
||||||
@ -478,6 +497,8 @@ def fused_experts(hidden_states: torch.Tensor,
|
|||||||
intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
|
intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
|
||||||
intermediate_cache2 = intermediate_cache2[:tokens_in_chunk]
|
intermediate_cache2 = intermediate_cache2[:tokens_in_chunk]
|
||||||
intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
|
intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
|
||||||
|
# reload config to get better performance on the last chunk
|
||||||
|
config = get_config_func(tokens_in_chunk)
|
||||||
|
|
||||||
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
|
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
|
||||||
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
|
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user