[misc] update utils to support comparing multiple settings (#9140)

This commit is contained in:
youkaichao 2024-10-07 19:51:49 -07:00 committed by GitHub
parent 8eeb857084
commit 04c12f8157
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -310,14 +310,38 @@ def compare_two_settings(model: str,
env2: The second set of environment variables to pass to the API server.
"""
compare_all_settings(
model,
[arg1, arg2],
[env1, env2],
method=method,
max_wait_seconds=max_wait_seconds,
)
def compare_all_settings(model: str,
all_args: List[List[str]],
all_envs: List[Optional[Dict[str, str]]],
*,
method: Literal["generate", "encode"] = "generate",
max_wait_seconds: Optional[float] = None) -> None:
"""
Launch API server with several different sets of arguments/environments
and compare the results of the API calls with the first set of arguments.
Args:
model: The model to test.
all_args: A list of argument lists to pass to the API server.
all_envs: A list of environment dictionaries to pass to the API server.
"""
trust_remote_code = False
for args in (arg1, arg2):
for args in all_args:
if "--trust-remote-code" in args:
trust_remote_code = True
break
tokenizer_mode = "auto"
for args in (arg1, arg2):
for args in all_args:
if "--tokenizer-mode" in args:
tokenizer_mode = args[args.index("--tokenizer-mode") + 1]
break
@ -330,8 +354,10 @@ def compare_two_settings(model: str,
prompt = "Hello, my name is"
token_ids = tokenizer(prompt).input_ids
results = []
for args, env in ((arg1, env1), (arg2, env2)):
ref_results: List = []
for i, (args, env) in enumerate(zip(all_args, all_envs)):
compare_results: List = []
results = ref_results if i == 0 else compare_results
with RemoteOpenAIServer(model,
args,
env_dict=env,
@ -355,13 +381,20 @@ def compare_two_settings(model: str,
else:
assert_never(method)
n = len(results) // 2
arg1_results = results[:n]
arg2_results = results[n:]
for arg1_result, arg2_result in zip(arg1_results, arg2_results):
assert arg1_result == arg2_result, (
f"Results for {model=} are not the same with {arg1=} and {arg2=}. "
f"{arg1_result=} != {arg2_result=}")
if i > 0:
# if any setting fails, raise an error early
ref_args = all_args[0]
ref_envs = all_envs[0]
compare_args = all_args[i]
compare_envs = all_envs[i]
for ref_result, compare_result in zip(ref_results,
compare_results):
assert ref_result == compare_result, (
f"Results for {model=} are not the same.\n"
f"{ref_args=} {ref_envs=}\n"
f"{compare_args=} {compare_envs=}\n"
f"{ref_result=}\n"
f"{compare_result=}\n")
def init_test_distributed_environment(