[Speculative decoding] CUDA graph support (#4295)
Co-authored-by: Cade Daniel <edacih@gmail.com>
parent 706588a77d
commit 2e7796f2cf
@@ -611,3 +611,40 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int,
         batch_size,
         max_output_len=output_len,
         force_output_len=True)
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+
+        # Verify equality when cuda graphs allowed.
+        "enforce_eager": False,
+        "model": "JackFram/llama-68m",
+    }])
+@pytest.mark.parametrize(
+    "per_test_common_llm_kwargs",
+    [
+        {
+            # Identical models.
+            "speculative_model": "JackFram/llama-68m",
+            "num_speculative_tokens": 5,
+        },
+    ])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs", [{}])
+@pytest.mark.parametrize("batch_size", [8])
+@pytest.mark.parametrize("output_len", [32])
+@pytest.mark.parametrize("seed", [1])
+def test_spec_decode_cuda_graph(baseline_llm_generator, test_llm_generator,
+                                batch_size, output_len):
+    """Verify spec decode equality when cuda graphs are enabled.
+    """
+    run_greedy_equality_correctness_test(
+        baseline_llm_generator,
+        test_llm_generator,
+        batch_size,
+        max_output_len=output_len,
+        force_output_len=True,
+    )
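Note: the test above builds its engines through shared fixtures, but the same configuration can be sketched directly against vLLM's LLM entrypoint. The sketch below is an assumption for illustration only; the keyword names mirror the kwargs used in the test (speculative_model, num_speculative_tokens, use_v2_block_manager, enforce_eager) and the exact constructor signature may differ across vLLM versions.

# Minimal sketch (not part of this commit): run speculative decoding with
# CUDA graphs enabled, using the same settings the test passes via kwargs.
from vllm import LLM, SamplingParams

llm = LLM(
    model="JackFram/llama-68m",
    speculative_model="JackFram/llama-68m",  # identical draft and target models, as in the test
    num_speculative_tokens=5,
    use_v2_block_manager=True,               # required for spec decode
    enforce_eager=False,                     # allow CUDA graph capture
)

# Greedy sampling so output should match a non-speculative baseline token-for-token.
params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["Hello, my name is"], params)
print(outputs[0].outputs[0].text)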