[misc] update utils to support comparing multiple settings (#9140)
This commit is contained in:
parent
8eeb857084
commit
04c12f8157
@ -310,14 +310,38 @@ def compare_two_settings(model: str,
|
|||||||
env2: The second set of environment variables to pass to the API server.
|
env2: The second set of environment variables to pass to the API server.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
compare_all_settings(
|
||||||
|
model,
|
||||||
|
[arg1, arg2],
|
||||||
|
[env1, env2],
|
||||||
|
method=method,
|
||||||
|
max_wait_seconds=max_wait_seconds,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def compare_all_settings(model: str,
|
||||||
|
all_args: List[List[str]],
|
||||||
|
all_envs: List[Optional[Dict[str, str]]],
|
||||||
|
*,
|
||||||
|
method: Literal["generate", "encode"] = "generate",
|
||||||
|
max_wait_seconds: Optional[float] = None) -> None:
|
||||||
|
"""
|
||||||
|
Launch API server with several different sets of arguments/environments
|
||||||
|
and compare the results of the API calls with the first set of arguments.
|
||||||
|
Args:
|
||||||
|
model: The model to test.
|
||||||
|
all_args: A list of argument lists to pass to the API server.
|
||||||
|
all_envs: A list of environment dictionaries to pass to the API server.
|
||||||
|
"""
|
||||||
|
|
||||||
trust_remote_code = False
|
trust_remote_code = False
|
||||||
for args in (arg1, arg2):
|
for args in all_args:
|
||||||
if "--trust-remote-code" in args:
|
if "--trust-remote-code" in args:
|
||||||
trust_remote_code = True
|
trust_remote_code = True
|
||||||
break
|
break
|
||||||
|
|
||||||
tokenizer_mode = "auto"
|
tokenizer_mode = "auto"
|
||||||
for args in (arg1, arg2):
|
for args in all_args:
|
||||||
if "--tokenizer-mode" in args:
|
if "--tokenizer-mode" in args:
|
||||||
tokenizer_mode = args[args.index("--tokenizer-mode") + 1]
|
tokenizer_mode = args[args.index("--tokenizer-mode") + 1]
|
||||||
break
|
break
|
||||||
@ -330,8 +354,10 @@ def compare_two_settings(model: str,
|
|||||||
|
|
||||||
prompt = "Hello, my name is"
|
prompt = "Hello, my name is"
|
||||||
token_ids = tokenizer(prompt).input_ids
|
token_ids = tokenizer(prompt).input_ids
|
||||||
results = []
|
ref_results: List = []
|
||||||
for args, env in ((arg1, env1), (arg2, env2)):
|
for i, (args, env) in enumerate(zip(all_args, all_envs)):
|
||||||
|
compare_results: List = []
|
||||||
|
results = ref_results if i == 0 else compare_results
|
||||||
with RemoteOpenAIServer(model,
|
with RemoteOpenAIServer(model,
|
||||||
args,
|
args,
|
||||||
env_dict=env,
|
env_dict=env,
|
||||||
@ -355,13 +381,20 @@ def compare_two_settings(model: str,
|
|||||||
else:
|
else:
|
||||||
assert_never(method)
|
assert_never(method)
|
||||||
|
|
||||||
n = len(results) // 2
|
if i > 0:
|
||||||
arg1_results = results[:n]
|
# if any setting fails, raise an error early
|
||||||
arg2_results = results[n:]
|
ref_args = all_args[0]
|
||||||
for arg1_result, arg2_result in zip(arg1_results, arg2_results):
|
ref_envs = all_envs[0]
|
||||||
assert arg1_result == arg2_result, (
|
compare_args = all_args[i]
|
||||||
f"Results for {model=} are not the same with {arg1=} and {arg2=}. "
|
compare_envs = all_envs[i]
|
||||||
f"{arg1_result=} != {arg2_result=}")
|
for ref_result, compare_result in zip(ref_results,
|
||||||
|
compare_results):
|
||||||
|
assert ref_result == compare_result, (
|
||||||
|
f"Results for {model=} are not the same.\n"
|
||||||
|
f"{ref_args=} {ref_envs=}\n"
|
||||||
|
f"{compare_args=} {compare_envs=}\n"
|
||||||
|
f"{ref_result=}\n"
|
||||||
|
f"{compare_result=}\n")
|
||||||
|
|
||||||
|
|
||||||
def init_test_distributed_environment(
|
def init_test_distributed_environment(
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user