[SpecDec] Remove Batch Expansion (2/3) (#9298)

Lily Liu 2024-10-11 22:13:37 -07:00 committed by GitHub
parent ec10cb8511
commit 89feb4c84d
8 changed files with 122 additions and 70 deletions
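For context on the change: batch expansion scores a request with k proposed tokens by expanding it into roughly k + 1 single-token decode sequences, while the MQA scorer keeps one sequence per request and scores all proposed tokens as one multi-token query. An illustrative sketch (not vLLM code) of the difference in scorer batch rows:

from typing import List

def batch_expansion_rows(proposal_lens: List[int]) -> int:
    # One expanded single-token sequence per proposed token plus the bonus token.
    return sum(plen + 1 for plen in proposal_lens)

def mqa_rows(proposal_lens: List[int]) -> int:
    # One sequence per request; its query simply spans plen + 1 tokens.
    return len(proposal_lens)

lens = [5, 5, 0, 3]  # a mixed batch, including a request that skips speculation
print(batch_expansion_rows(lens), mqa_rows(lens))  # 17 4

The diffs below extend the MQA scorer and the flash-attention decode path to handle proposals of different lengths (including zero-length ones), so the batch-expansion fallback is no longer required for mixed batches.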


@@ -1,3 +1,6 @@
+import random
+from typing import List
+
 import pytest
 import torch
@@ -10,31 +13,45 @@ from vllm.worker.worker import Worker
 from .utils import create_batch, create_worker


-def create_proposal(batch_size: int, propose_len: int, vocab_size: int,
+def create_proposal(propose_lens: List[int], vocab_size: int,
                     device: str) -> SpeculativeProposals:
-    proposal_probs = torch.rand((batch_size, propose_len, vocab_size),
+    batch_size = len(propose_lens)
+    max_propose_len = max(propose_lens)
+    proposal_probs = torch.rand((batch_size, max_propose_len, vocab_size),
                                 device=device)
-    proposal_token_ids = torch.argmax(proposal_probs, dim=-1)
-    proposal_lens = torch.tensor([propose_len] * batch_size, device=device)
+
+    proposal_token_ids = torch.full((batch_size, max_propose_len),
+                                    fill_value=-1,
+                                    device=device)
+    for i in range(batch_size):
+        proposal_token_ids[i][:propose_lens[i]] = torch.argmax(
+            proposal_probs[i][:propose_lens[i]], dim=-1)
+
+    propose_lens = torch.tensor(propose_lens, device=device)
     return SpeculativeProposals(proposal_token_ids, proposal_probs,
-                                proposal_lens)
+                                propose_lens)


 def assert_score_equal(score1: SpeculativeScores,
                        score2: SpeculativeScores) -> None:
     assert torch.allclose(score1.probs, score2.probs)
     assert torch.allclose(score1.logprobs, score2.logprobs)
-    assert torch.equal(score1.token_ids, score2.token_ids)
+    assert torch.equal(
+        score1.token_ids,
+        score2.token_ids), f"{score1.token_ids}, {score2.token_ids}"


 @pytest.mark.parametrize('model_name', ['facebook/opt-125m'])
 @pytest.mark.parametrize('batch_size', [1, 2, 4, 8, 16])
-@pytest.mark.parametrize('propose_len', [1, 3, 5])
+@pytest.mark.parametrize('max_propose_len', [1, 3, 5])
+@pytest.mark.parametrize('mixed_propose_len', [True])
 @pytest.mark.parametrize('device', ['cuda'])
-def test_scoroer(model_name: str, batch_size: int, propose_len: int,
-                 device: str) -> None:
+def test_scorer(model_name: str, batch_size: int, max_propose_len: int,
+                mixed_propose_len: bool, device: str) -> None:
     """
-    Compare the batch expansion scorer and mqa scorer return the same score
+    Check that the batch expansion scorer and the MQA scorer return the same
+    scores, both for queries with the same propose length and for queries
+    with different propose lengths.
     """
     seed = 0
     block_size = 32
@@ -46,13 +63,22 @@ def test_scoroer(model_name: str, batch_size: int, propose_len: int,
     should_modify_greedy_probs_inplace = True
     vocab_size = scorer_worker.vocab_size

-    proposals = create_proposal(batch_size, propose_len, vocab_size, device)
+    if not mixed_propose_len:
+        propose_lens = [max_propose_len] * batch_size
+    else:
+        non_zero_cnt = random.randint(0, batch_size)
+        propose_lens = [max_propose_len
+                        ] * non_zero_cnt + [0] * (batch_size - non_zero_cnt)
+        random.shuffle(propose_lens)
+
+    proposals = create_proposal(propose_lens, vocab_size, device)
     seq_group_metadatalist, _, _ = create_batch(batch_size,
-                                                propose_len,
+                                                max_propose_len,
                                                 block_size=block_size,
                                                 num_gpu_blocks=num_gpu_blocks)
+
     requests = ExecuteModelRequest(seq_group_metadatalist,
-                                   num_lookahead_slots=propose_len)
+                                   num_lookahead_slots=max_propose_len)

     batch_expansion_scorer = BatchExpansionTop1Scorer(scorer_worker, device,
                                                       vocab_size)


@@ -186,11 +186,8 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
     # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
     use_cuda_graph: bool

-    # Number of query tokens for each request in the batch.
-    # Currently, we require that all requests have the same number of query
-    # tokens during the decoding phase. When speculavie decoding is enabled,
-    # decode_query_len might be greater than 1. In all other cases, it is 1.
-    decode_query_len: Optional[int] = None
+    # Max number of query tokens among requests in the batch.
+    max_decode_query_len: Optional[int] = None

     _cached_prefill_metadata: Optional[
         "BlocksparseFlashAttentionMetadata"] = None


@@ -111,11 +111,8 @@ class FlashAttentionMetadata(AttentionMetadata):
     # Maximum query length in the batch.
     max_query_len: Optional[int]

-    # Number of query tokens for each request in the batch.
-    # Currently, we require that all requests have the same number of query
-    # tokens during the decoding phase. When speculavie decoding is enabled,
-    # decode_query_len might be greater than 1. In all other cases, it is 1.
-    decode_query_len: Optional[int]
+    # Max number of query tokens among requests in the batch.
+    max_decode_query_len: Optional[int]

     # Maximum sequence length among prefill batch. 0 if there are decoding
     # requests only.
@@ -173,9 +170,9 @@ class FlashAttentionMetadata(AttentionMetadata):
             slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
             seq_lens=self.seq_lens[:self.num_prefills],
             seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
-            decode_query_len=0,
             max_query_len=self.max_query_len,
             max_prefill_seq_len=self.max_prefill_seq_len,
+            max_decode_query_len=0,
             max_decode_seq_len=0,
             query_start_loc=self.query_start_loc[:self.num_prefills + 1],
             seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
@@ -202,12 +199,14 @@ class FlashAttentionMetadata(AttentionMetadata):
             slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
             seq_lens=None,
             seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
-            decode_query_len=self.decode_query_len,
+            max_decode_query_len=self.max_decode_query_len,
             max_query_len=self.max_query_len,
             max_prefill_seq_len=0,
             max_decode_seq_len=self.max_decode_seq_len,
-            query_start_loc=None,
-            seq_start_loc=None,
+            query_start_loc=self.query_start_loc[self.num_prefills:]
+            if self.query_start_loc is not None else None,
+            seq_start_loc=self.seq_start_loc[self.num_prefills:]
+            if self.seq_start_loc is not None else None,
             context_lens_tensor=None,
             block_tables=self.block_tables[self.num_prefills:],
             use_cuda_graph=self.use_cuda_graph,
@@ -413,9 +412,9 @@ class FlashAttentionMetadataBuilder(
             max_query_len = max(query_lens)
             decode_query_lens = query_lens[self.num_prefills:]
             if len(decode_query_lens) > 0:
-                decode_query_len = max(decode_query_lens)
+                max_decode_query_len = max(decode_query_lens)
             else:
-                decode_query_len = 1
+                max_decode_query_len = 1
             max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
             max_decode_seq_len = max(self.curr_seq_lens, default=0)
             num_decode_tokens = self.num_decode_tokens
@@ -468,7 +467,7 @@ class FlashAttentionMetadataBuilder(
             seq_lens=seq_lens,
             seq_lens_tensor=seq_lens_tensor,
             max_query_len=max_query_len,
-            decode_query_len=decode_query_len,
+            max_decode_query_len=max_decode_query_len,
             max_prefill_seq_len=max_prefill_seq_len,
             max_decode_seq_len=max_decode_seq_len,
             query_start_loc=query_start_loc,
@@ -714,20 +713,37 @@ def unified_flash_attention(

     if decode_meta := attn_metadata.decode_metadata:
         # Decoding run.
-        _, num_head, head_dim = decode_query.shape
-        decode_query = decode_query.reshape(-1, decode_meta.decode_query_len,
-                                            num_head, head_dim)
-        decode_output = flash_attn_with_kvcache(
-            q=decode_query,
-            k_cache=key_cache,
-            v_cache=value_cache,
-            block_table=decode_meta.block_tables,
-            cache_seqlens=decode_meta.seq_lens_tensor,
-            softmax_scale=softmax_scale,
-            causal=True,
-            alibi_slopes=alibi_slopes,
-            softcap=logits_soft_cap,
-        ).squeeze(1)
+        # Use the flash_attn_varlen_func kernel for speculative decoding
+        # because different queries might have different lengths.
+        assert decode_meta.max_decode_query_len is not None
+        if decode_meta.max_decode_query_len > 1:
+            decode_output = flash_attn_varlen_func(
+                q=decode_query,
+                k=key_cache,
+                v=value_cache,
+                cu_seqlens_q=decode_meta.query_start_loc,
+                max_seqlen_q=decode_meta.max_decode_query_len,
+                cu_seqlens_k=decode_meta.seq_start_loc,
+                max_seqlen_k=decode_meta.max_decode_seq_len,
+                softmax_scale=softmax_scale,
+                causal=True,
+                alibi_slopes=alibi_slopes,
+                softcap=logits_soft_cap,
+                block_table=decode_meta.block_tables,
+            )
+        else:
+            # Use flash_attn_with_kvcache for normal decoding.
+            decode_output = flash_attn_with_kvcache(
+                q=decode_query.unsqueeze(1),
+                k_cache=key_cache,
+                v_cache=value_cache,
+                block_table=decode_meta.block_tables,
+                cache_seqlens=decode_meta.seq_lens_tensor,
+                softmax_scale=softmax_scale,
+                causal=True,
+                alibi_slopes=alibi_slopes,
+                softcap=logits_soft_cap,
+            ).squeeze(1)

     if prefill_output is None:
         assert decode_output is not None
@@ -739,7 +755,6 @@ def unified_flash_attention(
     # Chunked prefill does not work with speculative decoding.
     # Therefore, the query length for decode should be 1 in chunked prefill.
     assert decode_meta is not None
-    assert decode_meta.decode_query_len == 1
     decode_output = decode_output.squeeze(1)
     output = torch.cat([prefill_output, decode_output], dim=0)
     return output.view(num_tokens, hidden_size)
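When max_decode_query_len is greater than 1, the new decode path above passes cumulative start offsets (query_start_loc / seq_start_loc) to flash_attn_varlen_func instead of assuming a fixed query length. A minimal sketch of how such offsets can be derived from per-request decode query lengths (illustrative only, using plain torch rather than the vLLM metadata builder):

import torch
from typing import List

def build_cu_seqlens(decode_query_lens: List[int]) -> torch.Tensor:
    # (batch_size + 1,) int32 tensor: entry i is where request i's query tokens
    # start in the packed query tensor; the final entry is the total token count.
    lens = torch.tensor(decode_query_lens, dtype=torch.int32)
    return torch.nn.functional.pad(
        torch.cumsum(lens, dim=0, dtype=torch.int32), (1, 0))

# Three decode requests with 3, 1, and 2 proposed tokens (plus one bonus token each).
print(build_cu_seqlens([4, 2, 3]))  # tensor([0, 4, 6, 9], dtype=torch.int32)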


@@ -121,11 +121,8 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
     # so far).
     context_lens_tensor: Optional[torch.Tensor]

-    # Number of query tokens for each request in the batch.
-    # Currently, we require that all requests have the same number of query
-    # tokens during the decoding phase. When speculavie decoding is enabled,
-    # decode_query_len might be greater than 1. In all other cases, it is 1.
-    decode_query_len: Optional[int] = None
+    # Max number of query tokens among requests in the batch.
+    max_decode_query_len: Optional[int] = None

     _cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None
     _cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None


@@ -313,7 +313,7 @@ class CommonAttentionState(AttentionState):
             seq_lens=None,
             seq_lens_tensor=self._graph_seq_lens[:batch_size],
             max_query_len=1,
-            decode_query_len=1,
+            max_decode_query_len=1,
             max_prefill_seq_len=0,
             max_decode_seq_len=self.runner.max_seq_len_to_capture,
             query_start_loc=None,


@@ -118,11 +118,8 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
     # Maximum query length in the batch. None for decoding.
     max_query_len: Optional[int] = None

-    # Number of query tokens for each request in the batch.
-    # Currently, we require that all requests have the same number of query
-    # tokens during the decoding phase. When speculavie decoding is enabled,
-    # decode_query_len might be greater than 1. In all other cases, it is 1.
-    decode_query_len: Optional[int] = None
+    # Max number of query tokens among requests in the batch.
+    max_decode_query_len: Optional[int] = None

     # (batch_size + 1,). The cumulative subquery lengths of the sequences in
     # the batch, used to index into subquery. E.g., if the subquery length


@@ -18,6 +18,7 @@ class MQAScorer(SpeculativeScorer):
         target_seq_id_start = max(
             get_all_seq_ids(execute_model_req.seq_group_metadata_list)) + 1
         all_proposal_tokens = proposals.proposal_token_ids.tolist()
+        all_proposal_lengths = proposals.proposal_lens.tolist()
         for i, seq_group_metadata in enumerate(
                 execute_model_req.seq_group_metadata_list):
             seq_data_dict = seq_group_metadata.seq_data
@@ -27,7 +28,8 @@ class MQAScorer(SpeculativeScorer):
             seq_data: SequenceData = seq_data_dict[seq_id]
             prompt_token_ids = seq_data.get_prompt_token_ids()
             output_token_ids = seq_data.get_output_token_ids()
-            proposal_token_ids = all_proposal_tokens[i]
+            proposal_token_ids = all_proposal_tokens[
+                i][:all_proposal_lengths[i]]
             new_output_token_ids = [*output_token_ids, *proposal_token_ids]

             target_seq_id = target_seq_id_start + i
@@ -62,18 +64,42 @@ class MQAScorer(SpeculativeScorer):
         target_sampler_output = target_sampler_output[0]

-        bs, k = proposals.proposal_token_ids.shape
-        all_tokens = target_sampler_output.sampled_token_ids.reshape(bs, k + 1)
-
-        all_probs = target_sampler_output.sampled_token_probs.reshape(
-            bs, k + 1, self._vocab_size)
-        all_logprobs = target_sampler_output.logprobs.reshape(
-            bs, k + 1, self._vocab_size)
+        k = execute_model_req.num_lookahead_slots
+        bs = len(execute_model_req.seq_group_metadata_list)
+        target_token_ids = target_sampler_output.sampled_token_ids
+        target_probs = target_sampler_output.sampled_token_probs
+        target_logprobs = target_sampler_output.logprobs
+        # If all requests have the same number of query tokens, we can avoid
+        # the for loop to build output for better performance.
+        if min(all_proposal_lengths) == k:
+            bs, _ = proposals.proposal_token_ids.shape
+            all_tokens = target_token_ids.reshape(bs, k + 1)
+            all_probs = target_probs.reshape(bs, k + 1, self._vocab_size)
+            all_logprobs = target_logprobs.reshape(bs, k + 1, self._vocab_size)
+        else:
+            all_tokens = target_token_ids.new_full(size=(bs, k + 1),
+                                                   fill_value=-1)
+            all_probs = target_probs.new_zeros(*all_tokens.shape,
+                                               self._vocab_size)
+            all_logprobs = target_logprobs.new_full(size=all_probs.shape,
+                                                    fill_value=-float("inf"))
+            target_token_ids = target_token_ids.flatten()
+            start_loc = 0
+            for i, proposed_len in enumerate(all_proposal_lengths):
+                output_len = proposed_len + 1
+                end_loc = start_loc + output_len
+                all_tokens[
+                    i, :output_len] = target_token_ids[start_loc:end_loc]
+                all_probs[i, :output_len] = target_probs[start_loc:end_loc]
+                all_logprobs[
+                    i, :output_len] = target_logprobs[start_loc:end_loc]
+                start_loc = end_loc

         hidden_states = None
         if target_sampler_output.hidden_states is not None:
             hidden_states = target_sampler_output.hidden_states.reshape(
                 bs, (k + 1), -1)

         return SpeculativeScores(probs=all_probs,
                                  token_ids=all_tokens,
                                  logprobs=all_logprobs,
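When proposal lengths differ, the target model's outputs arrive as one flat run of tokens per request (the proposed tokens plus a bonus token), and the loop above scatters them into padded per-request rows. A standalone sketch of that unpacking with dummy data (assumed shapes, not the vLLM SpeculativeScores plumbing):

import torch
from typing import List

def unpack_tokens(flat_token_ids: torch.Tensor, proposal_lens: List[int],
                  k: int) -> torch.Tensor:
    bs = len(proposal_lens)
    # -1 marks padded slots for requests that proposed fewer than k tokens.
    out = flat_token_ids.new_full((bs, k + 1), fill_value=-1)
    start = 0
    for i, plen in enumerate(proposal_lens):
        out_len = plen + 1  # proposed tokens plus the bonus token
        out[i, :out_len] = flat_token_ids[start:start + out_len]
        start += out_len
    return out

flat = torch.arange(7)  # target outputs for proposals of length [2, 0, 2]
print(unpack_tokens(flat, [2, 0, 2], k=2))
# tensor([[ 0,  1,  2],
#         [ 3, -1, -1],
#         [ 4,  5,  6]])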


@@ -190,12 +190,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
                     "[Speculative Decoding] Disabling MQA scorer as the "
                     "MQA is only available with flash attn backend.")

-            if ngram_prompt_lookup_max > 0:
-                disable_mqa_scorer = True
-                logger.info(
-                    "[Speculative Decoding] Disabling MQA scorer as the "
-                    "NGramWorker does not support MQA scorer.")
-
             if "model_config" in draft_worker_kwargs and \
                 draft_worker_kwargs["model_config"].max_model_len < \
                 scorer_worker.model_config.max_model_len: