[SpecDec] Remove Batch Expansion (2/3) (#9298)
parent ec10cb8511
commit 89feb4c84d
@@ -1,3 +1,6 @@
+import random
+from typing import List
+
 import pytest
 import torch
 
@@ -10,31 +13,45 @@ from vllm.worker.worker import Worker
 from .utils import create_batch, create_worker
 
 
-def create_proposal(batch_size: int, propose_len: int, vocab_size: int,
+def create_proposal(propose_lens: List[int], vocab_size: int,
                     device: str) -> SpeculativeProposals:
-    proposal_probs = torch.rand((batch_size, propose_len, vocab_size),
+    batch_size = len(propose_lens)
+    max_propose_len = max(propose_lens)
+    proposal_probs = torch.rand((batch_size, max_propose_len, vocab_size),
                                 device=device)
-    proposal_token_ids = torch.argmax(proposal_probs, dim=-1)
-    proposal_lens = torch.tensor([propose_len] * batch_size, device=device)
+
+    proposal_token_ids = torch.full((batch_size, max_propose_len),
+                                    fill_value=-1,
+                                    device=device)
+    for i in range(batch_size):
+        proposal_token_ids[i][:propose_lens[i]] = torch.argmax(
+            proposal_probs[i][:propose_lens[i]], dim=-1)
+
+    propose_lens = torch.tensor(propose_lens, device=device)
     return SpeculativeProposals(proposal_token_ids, proposal_probs,
-                                proposal_lens)
+                                propose_lens)
 
 
 def assert_score_equal(score1: SpeculativeScores,
                        score2: SpeculativeScores) -> None:
     assert torch.allclose(score1.probs, score2.probs)
     assert torch.allclose(score1.logprobs, score2.logprobs)
-    assert torch.equal(score1.token_ids, score2.token_ids)
+    assert torch.equal(
+        score1.token_ids,
+        score2.token_ids), f"{score1.token_ids}, {score2.token_ids}"
 
 
 @pytest.mark.parametrize('model_name', ['facebook/opt-125m'])
 @pytest.mark.parametrize('batch_size', [1, 2, 4, 8, 16])
-@pytest.mark.parametrize('propose_len', [1, 3, 5])
+@pytest.mark.parametrize('max_propose_len', [1, 3, 5])
+@pytest.mark.parametrize('mixed_propose_len', [True])
 @pytest.mark.parametrize('device', ['cuda'])
-def test_scoroer(model_name: str, batch_size: int, propose_len: int,
-                 device: str) -> None:
+def test_scorer(model_name: str, batch_size: int, max_propose_len: int,
+                mixed_propose_len: bool, device: str) -> None:
     """
-    Compare the batch expansion scorer and mqa scorer return the same score
+    Compare the batch expansion scorer and mqa scorer return the same score.
+    We test for both queries with the same propose length and different
+    propose length.
     """
     seed = 0
     block_size = 32
@@ -46,13 +63,22 @@ def test_scoroer(model_name: str, batch_size: int, propose_len: int,
     should_modify_greedy_probs_inplace = True
 
     vocab_size = scorer_worker.vocab_size
-    proposals = create_proposal(batch_size, propose_len, vocab_size, device)
+
+    if not mixed_propose_len:
+        propose_lens = [max_propose_len] * batch_size
+    else:
+        non_zero_cnt = random.randint(0, batch_size)
+        propose_lens = [max_propose_len
+                        ] * non_zero_cnt + [0] * (batch_size - non_zero_cnt)
+        random.shuffle(propose_lens)
+
+    proposals = create_proposal(propose_lens, vocab_size, device)
     seq_group_metadatalist, _, _ = create_batch(batch_size,
-                                                propose_len,
+                                                max_propose_len,
                                                 block_size=block_size,
                                                 num_gpu_blocks=num_gpu_blocks)
     requests = ExecuteModelRequest(seq_group_metadatalist,
-                                   num_lookahead_slots=propose_len)
+                                   num_lookahead_slots=max_propose_len)
 
     batch_expansion_scorer = BatchExpansionTop1Scorer(scorer_worker, device,
                                                       vocab_size)
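
Note on the test change above: the reworked create_proposal pads per-request proposals to a common length, so a request with propose length 0 contributes a row of -1s. A minimal standalone sketch of that convention (plain PyTorch only; the values are illustrative, not taken from the test):

# Sketch of the padding convention used by the reworked create_proposal():
# propose lengths may differ per request, and unused slots are filled with -1.
import torch

propose_lens = [3, 0, 3, 3]        # the mixed_propose_len=True case
vocab_size = 8
batch_size, max_propose_len = len(propose_lens), max(propose_lens)

proposal_probs = torch.rand(batch_size, max_propose_len, vocab_size)
proposal_token_ids = torch.full((batch_size, max_propose_len), -1)
for i, plen in enumerate(propose_lens):
    proposal_token_ids[i, :plen] = torch.argmax(proposal_probs[i, :plen],
                                                dim=-1)
# Row 1 (propose length 0) stays all -1; the other rows hold argmax token ids.
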
@@ -186,11 +186,8 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
     # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
     use_cuda_graph: bool
 
-    # Number of query tokens for each request in the batch.
-    # Currently, we require that all requests have the same number of query
-    # tokens during the decoding phase. When speculavie decoding is enabled,
-    # decode_query_len might be greater than 1. In all other cases, it is 1.
-    decode_query_len: Optional[int] = None
+    # Max number of query tokens among request in the batch.
+    max_decode_query_len: Optional[int] = None
 
     _cached_prefill_metadata: Optional[
         "BlocksparseFlashAttentionMetadata"] = None
@@ -111,11 +111,8 @@ class FlashAttentionMetadata(AttentionMetadata):
     # Maximum query length in the batch.
     max_query_len: Optional[int]
 
-    # Number of query tokens for each request in the batch.
-    # Currently, we require that all requests have the same number of query
-    # tokens during the decoding phase. When speculavie decoding is enabled,
-    # decode_query_len might be greater than 1. In all other cases, it is 1.
-    decode_query_len: Optional[int]
+    # Max number of query tokens among request in the batch.
+    max_decode_query_len: Optional[int]
 
     # Maximum sequence length among prefill batch. 0 if there are decoding
     # requests only.
@@ -173,9 +170,9 @@ class FlashAttentionMetadata(AttentionMetadata):
             slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
             seq_lens=self.seq_lens[:self.num_prefills],
             seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
-            decode_query_len=0,
             max_query_len=self.max_query_len,
             max_prefill_seq_len=self.max_prefill_seq_len,
+            max_decode_query_len=0,
             max_decode_seq_len=0,
             query_start_loc=self.query_start_loc[:self.num_prefills + 1],
             seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
@@ -202,12 +199,14 @@ class FlashAttentionMetadata(AttentionMetadata):
             slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
             seq_lens=None,
             seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
-            decode_query_len=self.decode_query_len,
+            max_decode_query_len=self.max_decode_query_len,
             max_query_len=self.max_query_len,
             max_prefill_seq_len=0,
             max_decode_seq_len=self.max_decode_seq_len,
-            query_start_loc=None,
-            seq_start_loc=None,
+            query_start_loc=self.query_start_loc[self.num_prefills:]
+            if self.query_start_loc is not None else None,
+            seq_start_loc=self.seq_start_loc[self.num_prefills:]
+            if self.seq_start_loc is not None else None,
             context_lens_tensor=None,
             block_tables=self.block_tables[self.num_prefills:],
             use_cuda_graph=self.use_cuda_graph,
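
Background on the query_start_loc / seq_start_loc tensors that the decode metadata now carries: they are cumulative offsets over per-request lengths, the "cu_seqlens" layout a varlen attention kernel consumes. A standalone illustration (names are local to the example, not vLLM APIs):

# cu_seqlens layout: entry i marks where request i's tokens begin in the
# flattened batch; the final entry is the total token count.
import torch

query_lens = [2, 1, 3]                       # decode query tokens per request
cu_seqlens_q = torch.zeros(len(query_lens) + 1, dtype=torch.int32)
cu_seqlens_q[1:] = torch.cumsum(torch.tensor(query_lens), dim=0)
# cu_seqlens_q is now tensor([0, 2, 3, 6], dtype=torch.int32)
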
@@ -413,9 +412,9 @@ class FlashAttentionMetadataBuilder(
         max_query_len = max(query_lens)
         decode_query_lens = query_lens[self.num_prefills:]
         if len(decode_query_lens) > 0:
-            decode_query_len = max(decode_query_lens)
+            max_decode_query_len = max(decode_query_lens)
         else:
-            decode_query_len = 1
+            max_decode_query_len = 1
         max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
         max_decode_seq_len = max(self.curr_seq_lens, default=0)
         num_decode_tokens = self.num_decode_tokens
@@ -468,7 +467,7 @@ class FlashAttentionMetadataBuilder(
             seq_lens=seq_lens,
             seq_lens_tensor=seq_lens_tensor,
             max_query_len=max_query_len,
-            decode_query_len=decode_query_len,
+            max_decode_query_len=max_decode_query_len,
             max_prefill_seq_len=max_prefill_seq_len,
             max_decode_seq_len=max_decode_seq_len,
             query_start_loc=query_start_loc,
@@ -714,20 +713,37 @@ def unified_flash_attention(
 
     if decode_meta := attn_metadata.decode_metadata:
         # Decoding run.
-        _, num_head, head_dim = decode_query.shape
-        decode_query = decode_query.reshape(-1, decode_meta.decode_query_len,
-                                            num_head, head_dim)
-        decode_output = flash_attn_with_kvcache(
-            q=decode_query,
-            k_cache=key_cache,
-            v_cache=value_cache,
-            block_table=decode_meta.block_tables,
-            cache_seqlens=decode_meta.seq_lens_tensor,
-            softmax_scale=softmax_scale,
-            causal=True,
-            alibi_slopes=alibi_slopes,
-            softcap=logits_soft_cap,
-        ).squeeze(1)
+        # Use flash_attn_varlen_func kernel for speculative decoding
+        # because different queries might have different lengths.
+        assert decode_meta.max_decode_query_len is not None
+        if decode_meta.max_decode_query_len > 1:
+            decode_output = flash_attn_varlen_func(
+                q=decode_query,
+                k=key_cache,
+                v=value_cache,
+                cu_seqlens_q=decode_meta.query_start_loc,
+                max_seqlen_q=decode_meta.max_decode_query_len,
+                cu_seqlens_k=decode_meta.seq_start_loc,
+                max_seqlen_k=decode_meta.max_decode_seq_len,
+                softmax_scale=softmax_scale,
+                causal=True,
+                alibi_slopes=alibi_slopes,
+                softcap=logits_soft_cap,
+                block_table=decode_meta.block_tables,
+            )
+        else:
+            # Use flash_attn_with_kvcache for normal decoding.
+            decode_output = flash_attn_with_kvcache(
+                q=decode_query.unsqueeze(1),
+                k_cache=key_cache,
+                v_cache=value_cache,
+                block_table=decode_meta.block_tables,
+                cache_seqlens=decode_meta.seq_lens_tensor,
+                softmax_scale=softmax_scale,
+                causal=True,
+                alibi_slopes=alibi_slopes,
+                softcap=logits_soft_cap,
+            ).squeeze(1)
 
     if prefill_output is None:
         assert decode_output is not None
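
The comments added in the hunk above carry the core idea of this PR: with the MQA scorer, a decode request can have more than one query token, so the decode path dispatches to flash_attn_varlen_func whenever max_decode_query_len > 1 and keeps flash_attn_with_kvcache for the single-token case. A hedged sketch of what the varlen path computes, written with plain PyTorch SDPA so it runs without the FlashAttention package (the kernel arguments in the diff are the authoritative interface):

# Reference semantics of variable-length decode attention, one request at a
# time, using scaled_dot_product_attention instead of the fused kernel.
import torch
import torch.nn.functional as F

num_heads, head_dim = 4, 32
query_lens = [2, 1, 3]                 # speculative decode: >1 query/request
seq_lens = [10, 7, 12]                 # cached KV length per request
cu_q = [0] + torch.cumsum(torch.tensor(query_lens), 0).tolist()

q = torch.randn(sum(query_lens), num_heads, head_dim)   # flattened queries
outputs = []
for i, kv_len in enumerate(seq_lens):
    qi = q[cu_q[i]:cu_q[i + 1]].transpose(0, 1)          # (heads, q_len, dim)
    k = torch.randn(num_heads, kv_len, head_dim)         # stand-in for KV cache
    v = torch.randn(num_heads, kv_len, head_dim)
    # Causal over the tail of the sequence: query j attends to every cached
    # token plus the proposed tokens up to and including position j.
    q_len = qi.shape[1]
    mask = torch.ones(q_len, kv_len, dtype=torch.bool).tril(kv_len - q_len)
    outputs.append(F.scaled_dot_product_attention(qi, k, v, attn_mask=mask))
out = torch.cat([o.transpose(0, 1) for o in outputs], dim=0)  # flat layout again
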
@@ -739,7 +755,6 @@ def unified_flash_attention(
         # Chunked prefill does not work with speculative decoding.
         # Therefore, the query length for decode should be 1 in chunked prefill.
         assert decode_meta is not None
-        assert decode_meta.decode_query_len == 1
         decode_output = decode_output.squeeze(1)
         output = torch.cat([prefill_output, decode_output], dim=0)
         return output.view(num_tokens, hidden_size)
@@ -121,11 +121,8 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
     # so far).
     context_lens_tensor: Optional[torch.Tensor]
 
-    # Number of query tokens for each request in the batch.
-    # Currently, we require that all requests have the same number of query
-    # tokens during the decoding phase. When speculavie decoding is enabled,
-    # decode_query_len might be greater than 1. In all other cases, it is 1.
-    decode_query_len: Optional[int] = None
+    # Max number of query tokens among request in the batch.
+    max_decode_query_len: Optional[int] = None
 
     _cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None
     _cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None
@@ -313,7 +313,7 @@ class CommonAttentionState(AttentionState):
             seq_lens=None,
             seq_lens_tensor=self._graph_seq_lens[:batch_size],
             max_query_len=1,
-            decode_query_len=1,
+            max_decode_query_len=1,
             max_prefill_seq_len=0,
             max_decode_seq_len=self.runner.max_seq_len_to_capture,
             query_start_loc=None,
@@ -118,11 +118,8 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
     # Maximum query length in the batch. None for decoding.
     max_query_len: Optional[int] = None
 
-    # Number of query tokens for each request in the batch.
-    # Currently, we require that all requests have the same number of query
-    # tokens during the decoding phase. When speculavie decoding is enabled,
-    # decode_query_len might be greater than 1. In all other cases, it is 1.
-    decode_query_len: Optional[int] = None
+    # Max number of query tokens among request in the batch.
+    max_decode_query_len: Optional[int] = None
 
     # (batch_size + 1,). The cumulative subquery lengths of the sequences in
     # the batch, used to index into subquery. E.g., if the subquery length
@@ -18,6 +18,7 @@ class MQAScorer(SpeculativeScorer):
         target_seq_id_start = max(
             get_all_seq_ids(execute_model_req.seq_group_metadata_list)) + 1
         all_proposal_tokens = proposals.proposal_token_ids.tolist()
+        all_proposal_lengths = proposals.proposal_lens.tolist()
         for i, seq_group_metadata in enumerate(
                 execute_model_req.seq_group_metadata_list):
             seq_data_dict = seq_group_metadata.seq_data
@@ -27,7 +28,8 @@ class MQAScorer(SpeculativeScorer):
             seq_data: SequenceData = seq_data_dict[seq_id]
             prompt_token_ids = seq_data.get_prompt_token_ids()
             output_token_ids = seq_data.get_output_token_ids()
-            proposal_token_ids = all_proposal_tokens[i]
+            proposal_token_ids = all_proposal_tokens[
+                i][:all_proposal_lengths[i]]
             new_output_token_ids = [*output_token_ids, *proposal_token_ids]
 
             target_seq_id = target_seq_id_start + i
@@ -62,18 +64,42 @@ class MQAScorer(SpeculativeScorer):
 
         target_sampler_output = target_sampler_output[0]
 
-        bs, k = proposals.proposal_token_ids.shape
-        all_tokens = target_sampler_output.sampled_token_ids.reshape(bs, k + 1)
-        all_probs = target_sampler_output.sampled_token_probs.reshape(
-            bs, k + 1, self._vocab_size)
-        all_logprobs = target_sampler_output.logprobs.reshape(
-            bs, k + 1, self._vocab_size)
+        k = execute_model_req.num_lookahead_slots
+        bs = len(execute_model_req.seq_group_metadata_list)
+        target_token_ids = target_sampler_output.sampled_token_ids
+        target_probs = target_sampler_output.sampled_token_probs
+        target_logprobs = target_sampler_output.logprobs
+        # If all requests have the same number of query tokens, we can avoid
+        # the for loop to build output for better performance.
+        if min(all_proposal_lengths) == k:
+            bs, _ = proposals.proposal_token_ids.shape
+            all_tokens = target_token_ids.reshape(bs, k + 1)
+            all_probs = target_probs.reshape(bs, k + 1, self._vocab_size)
+            all_logprobs = target_logprobs.reshape(bs, k + 1, self._vocab_size)
+        else:
+            all_tokens = target_token_ids.new_full(size=(bs, k + 1),
+                                                   fill_value=-1)
+            all_probs = target_probs.new_zeros(*all_tokens.shape,
+                                               self._vocab_size)
+            all_logprobs = target_logprobs.new_full(size=all_probs.shape,
+                                                    fill_value=-float("inf"))
+            target_token_ids = target_token_ids.flatten()
+            start_loc = 0
+            for i, proposed_len in enumerate(all_proposal_lengths):
+                output_len = proposed_len + 1
+                end_loc = start_loc + output_len
+                all_tokens[
+                    i, :output_len] = target_token_ids[start_loc:end_loc]
+                all_probs[i, :output_len] = target_probs[start_loc:end_loc]
+                all_logprobs[
+                    i, :output_len] = target_logprobs[start_loc:end_loc]
+                start_loc = end_loc
 
         hidden_states = None
         if target_sampler_output.hidden_states is not None:
             hidden_states = target_sampler_output.hidden_states.reshape(
                 bs, (k + 1), -1)
 
         return SpeculativeScores(probs=all_probs,
                                  token_ids=all_tokens,
                                  logprobs=all_logprobs,
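
The else branch above rebuilds fixed-shape (bs, k + 1) tensors from a flat sampler output in which each request contributed proposal_len + 1 tokens. A small self-contained sketch of that scatter with dummy token ids (no vLLM types involved):

# Scatter a flat, variable-length output back into a padded (bs, k + 1)
# tensor; slots beyond proposal_len + 1 are left at the fill value -1.
import torch

k = 3
all_proposal_lengths = [3, 0, 2]                 # per-request proposal lens
flat_token_ids = torch.arange(sum(n + 1 for n in all_proposal_lengths))

bs = len(all_proposal_lengths)
all_tokens = flat_token_ids.new_full((bs, k + 1), fill_value=-1)
start_loc = 0
for i, proposed_len in enumerate(all_proposal_lengths):
    end_loc = start_loc + proposed_len + 1
    all_tokens[i, :proposed_len + 1] = flat_token_ids[start_loc:end_loc]
    start_loc = end_loc
# all_tokens == [[0, 1, 2, 3], [4, -1, -1, -1], [5, 6, 7, -1]]
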
@@ -190,12 +190,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
                 "[Speculative Decoding] Disabling MQA scorer as the "
                 "MQA is only available with flash attn backend.")
 
-        if ngram_prompt_lookup_max > 0:
-            disable_mqa_scorer = True
-            logger.info(
-                "[Speculative Decoding] Disabling MQA scorer as the "
-                "NGramWorker does not support MQA scorer.")
-
         if "model_config" in draft_worker_kwargs and \
             draft_worker_kwargs["model_config"].max_model_len < \
             scorer_worker.model_config.max_model_len: