[misc] update utils to support comparing multiple settings (#9140)
This commit is contained in:
parent
8eeb857084
commit
04c12f8157
@ -310,14 +310,38 @@ def compare_two_settings(model: str,
|
||||
env2: The second set of environment variables to pass to the API server.
|
||||
"""
|
||||
|
||||
compare_all_settings(
|
||||
model,
|
||||
[arg1, arg2],
|
||||
[env1, env2],
|
||||
method=method,
|
||||
max_wait_seconds=max_wait_seconds,
|
||||
)
|
||||
|
||||
|
||||
def compare_all_settings(model: str,
|
||||
all_args: List[List[str]],
|
||||
all_envs: List[Optional[Dict[str, str]]],
|
||||
*,
|
||||
method: Literal["generate", "encode"] = "generate",
|
||||
max_wait_seconds: Optional[float] = None) -> None:
|
||||
"""
|
||||
Launch API server with several different sets of arguments/environments
|
||||
and compare the results of the API calls with the first set of arguments.
|
||||
Args:
|
||||
model: The model to test.
|
||||
all_args: A list of argument lists to pass to the API server.
|
||||
all_envs: A list of environment dictionaries to pass to the API server.
|
||||
"""
|
||||
|
||||
trust_remote_code = False
|
||||
for args in (arg1, arg2):
|
||||
for args in all_args:
|
||||
if "--trust-remote-code" in args:
|
||||
trust_remote_code = True
|
||||
break
|
||||
|
||||
tokenizer_mode = "auto"
|
||||
for args in (arg1, arg2):
|
||||
for args in all_args:
|
||||
if "--tokenizer-mode" in args:
|
||||
tokenizer_mode = args[args.index("--tokenizer-mode") + 1]
|
||||
break
|
||||
@ -330,8 +354,10 @@ def compare_two_settings(model: str,
|
||||
|
||||
prompt = "Hello, my name is"
|
||||
token_ids = tokenizer(prompt).input_ids
|
||||
results = []
|
||||
for args, env in ((arg1, env1), (arg2, env2)):
|
||||
ref_results: List = []
|
||||
for i, (args, env) in enumerate(zip(all_args, all_envs)):
|
||||
compare_results: List = []
|
||||
results = ref_results if i == 0 else compare_results
|
||||
with RemoteOpenAIServer(model,
|
||||
args,
|
||||
env_dict=env,
|
||||
@ -355,13 +381,20 @@ def compare_two_settings(model: str,
|
||||
else:
|
||||
assert_never(method)
|
||||
|
||||
n = len(results) // 2
|
||||
arg1_results = results[:n]
|
||||
arg2_results = results[n:]
|
||||
for arg1_result, arg2_result in zip(arg1_results, arg2_results):
|
||||
assert arg1_result == arg2_result, (
|
||||
f"Results for {model=} are not the same with {arg1=} and {arg2=}. "
|
||||
f"{arg1_result=} != {arg2_result=}")
|
||||
if i > 0:
|
||||
# if any setting fails, raise an error early
|
||||
ref_args = all_args[0]
|
||||
ref_envs = all_envs[0]
|
||||
compare_args = all_args[i]
|
||||
compare_envs = all_envs[i]
|
||||
for ref_result, compare_result in zip(ref_results,
|
||||
compare_results):
|
||||
assert ref_result == compare_result, (
|
||||
f"Results for {model=} are not the same.\n"
|
||||
f"{ref_args=} {ref_envs=}\n"
|
||||
f"{compare_args=} {compare_envs=}\n"
|
||||
f"{ref_result=}\n"
|
||||
f"{compare_result=}\n")
|
||||
|
||||
|
||||
def init_test_distributed_environment(
|
||||
|
||||
Loading…
Reference in New Issue
Block a user