From e6a8026489cedee3d39f76e1fbfecda9f278e424 Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Mon, 18 Sep 2023 16:08:44 -0700
Subject: [PATCH] [Gen] Rename max_sequence_len->max_seqlen, sequence_len_offset->seqlen_offset

---
 flash_attn/modules/mha.py      | 48 ++++++++----------
 flash_attn/utils/generation.py | 93 +++++++++++++++++-----------------
 tests/models/test_gpt.py       |  6 +--
 3 files changed, 72 insertions(+), 75 deletions(-)

diff --git a/flash_attn/modules/mha.py b/flash_attn/modules/mha.py
index 233f84f..4894dac 100644
--- a/flash_attn/modules/mha.py
+++ b/flash_attn/modules/mha.py
@@ -300,7 +300,7 @@ def _update_kv_cache(kv, inference_params, layer_idx):
     if layer_idx not in inference_params.key_value_memory_dict:
         kv_cache = torch.empty(
             inference_params.max_batch_size,
-            inference_params.max_sequence_len,
+            inference_params.max_seqlen,
             2,
             num_heads,
             head_dim,
@@ -313,7 +313,7 @@
     # Adjust key and value for inference
     batch_start = inference_params.batch_size_offset
     batch_end = batch_start + kv.shape[0]
-    sequence_start = inference_params.sequence_len_offset
+    sequence_start = inference_params.seqlen_offset
     sequence_end = sequence_start + kv.shape[1]
     assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0])
     assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2])
@@ -445,12 +445,12 @@ class MHA(nn.Module):
         q: (batch_size, seqlen_q, nheads, head_dim)
         kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
         """
-        assert inference_params is not None and inference_params.sequence_len_offset > 0
+        assert inference_params is not None and inference_params.seqlen_offset > 0
         assert self.use_flash_attn
         if self.rotary_emb_dim > 0:
             assert self.rotary_emb.scale is None, "This code path does not support xPos"
             self.rotary_emb._update_cos_sin_cache(
-                inference_params.max_sequence_len, device=q.device, dtype=q.dtype
+                inference_params.max_seqlen, device=q.device, dtype=q.dtype
             )
             rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
         else:
@@ -460,7 +460,7 @@ class MHA(nn.Module):
         cache_seqlens = (
             inference_params.lengths_per_sample[:batch]
             if inference_params.lengths_per_sample is not None
-            else inference_params.sequence_len_offset
+            else inference_params.seqlen_offset
         )
         context = flash_attn_with_kvcache(
             q,
@@ -480,11 +480,11 @@ class MHA(nn.Module):
     def _update_kvcache_attention(self, q, kv, inference_params):
         """Write kv to inference_params, then do attention"""
         if (
-            inference_params.sequence_len_offset == 0
+            inference_params.seqlen_offset == 0
             or flash_attn_with_kvcache is None
             or not self.use_flash_attn
         ):
-            # TODO: this only uses sequence_len_offset and not lengths_per_sample.
+            # TODO: this only uses seqlen_offset and not lengths_per_sample.
             kv = self._update_kv_cache(kv, inference_params)
             return self.inner_cross_attn(q, kv)
         else:
@@ -493,7 +493,7 @@ class MHA(nn.Module):
             cache_seqlens = (
                 inference_params.lengths_per_sample[:batch]
                 if inference_params.lengths_per_sample is not None
-                else inference_params.sequence_len_offset
+                else inference_params.seqlen_offset
             )
             return flash_attn_with_kvcache(
                 q,
@@ -561,12 +561,10 @@ class MHA(nn.Module):
                 else (
                     inference_params.lengths_per_sample
                     if inference_params.lengths_per_sample is not None
-                    else inference_params.sequence_len_offset
+                    else inference_params.seqlen_offset
                 )
             )
-            rotary_max_seqlen = (
-                inference_params.max_sequence_len if inference_params is not None else None
-            )
+            rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
         batch, seqlen = x.shape[:2]
         if not self.cross_attn and self.num_heads_kv == self.num_heads:
             assert x_kv is None and mixer_subset is None
@@ -581,7 +579,7 @@ class MHA(nn.Module):
             qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
             if (
                 inference_params is None
-                or inference_params.sequence_len_offset == 0
+                or inference_params.seqlen_offset == 0
                 or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                 or not self.use_flash_attn
             ):
@@ -632,7 +630,7 @@ class MHA(nn.Module):
             ).contiguous()
             if (
                 inference_params is None
-                or inference_params.sequence_len_offset == 0
+                or inference_params.seqlen_offset == 0
                 or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                 or not self.use_flash_attn
             ):
@@ -772,12 +770,12 @@ class ParallelMHA(nn.Module):
         q: (batch_size, seqlen_q, nheads, head_dim)
         kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
         """
-        assert inference_params is not None and inference_params.sequence_len_offset > 0
+        assert inference_params is not None and inference_params.seqlen_offset > 0
         assert self.use_flash_attn
         if self.rotary_emb_dim > 0:
             assert self.rotary_emb.scale is None, "This code path does not support xPos"
             self.rotary_emb._update_cos_sin_cache(
-                inference_params.max_sequence_len, device=q.device, dtype=q.dtype
+                inference_params.max_seqlen, device=q.device, dtype=q.dtype
             )
             rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
         else:
@@ -787,7 +785,7 @@ class ParallelMHA(nn.Module):
         cache_seqlens = (
             inference_params.lengths_per_sample[:batch]
             if inference_params.lengths_per_sample is not None
-            else inference_params.sequence_len_offset
+            else inference_params.seqlen_offset
         )
         context = flash_attn_with_kvcache(
             q,
@@ -806,8 +804,8 @@ class ParallelMHA(nn.Module):

     def _update_kvcache_attention(self, q, kv, inference_params):
         """Write kv to inference_params, then do attention"""
-        if inference_params.sequence_len_offset == 0 or not self.use_flash_attn:
-            # TODO: this only uses sequence_len_offset and not lengths_per_sample.
+        if inference_params.seqlen_offset == 0 or not self.use_flash_attn:
+            # TODO: this only uses seqlen_offset and not lengths_per_sample.
             kv = self._update_kv_cache(kv, inference_params)
             return self.inner_cross_attn(q, kv)
         else:
@@ -816,7 +814,7 @@ class ParallelMHA(nn.Module):
             cache_seqlens = (
                 inference_params.lengths_per_sample[:batch]
                 if inference_params.lengths_per_sample is not None
-                else inference_params.sequence_len_offset
+                else inference_params.seqlen_offset
             )
             context = flash_attn_with_kvcache(
                 q,
@@ -847,17 +845,15 @@ class ParallelMHA(nn.Module):
                 else (
                     inference_params.lengths_per_sample
                     if inference_params.lengths_per_sample is not None
-                    else inference_params.sequence_len_offset
+                    else inference_params.seqlen_offset
                 )
             )
-            rotary_max_seqlen = (
-                inference_params.max_sequence_len if inference_params is not None else None
-            )
+            rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
         if self.num_heads_kv == self.num_heads:
             qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim)
             if (
                 inference_params is None
-                or inference_params.sequence_len_offset == 0
+                or inference_params.seqlen_offset == 0
                 or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                 or not self.use_flash_attn
             ):
@@ -892,7 +888,7 @@ class ParallelMHA(nn.Module):
             )
             if (
                 inference_params is None
-                or inference_params.sequence_len_offset == 0
+                or inference_params.seqlen_offset == 0
                 or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                 or not self.use_flash_attn
             ):
diff --git a/flash_attn/utils/generation.py b/flash_attn/utils/generation.py
index 6e671ce..7afcd92 100644
--- a/flash_attn/utils/generation.py
+++ b/flash_attn/utils/generation.py
@@ -20,13 +20,20 @@ class InferenceParams:
     """Inference parameters that are passed to the main model in order
     to efficienly calculate and store the context during inference."""

-    max_sequence_len: int
+    max_seqlen: int
     max_batch_size: int
-    sequence_len_offset: int = 0
+    seqlen_offset: int = 0
     batch_size_offset: int = 0
     key_value_memory_dict: dict = field(default_factory=dict)
     lengths_per_sample: Optional[Tensor] = None

+    def reset(self, max_seqlen, max_batch_size):
+        self.max_seqlen = max_seqlen
+        self.max_batch_size = max_batch_size
+        self.seqlen_offset = 0
+        if self.lengths_per_sample is not None:
+            self.lengths_per_sample.zero_()
+

 # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
 # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
@@ -127,19 +134,16 @@ def decode(
             tensor_parallel=tensor_parallel,
         )
         inference_params = model._decoding_cache.inference_params
-        inference_params.max_sequence_len = max_length
-        inference_params.max_batch_size = batch_size
-        inference_params.sequence_len_offset = 0
-        inference_params.lengths_per_sample.zero_()
+        inference_params.reset(max_length, batch_size)
     else:
-        inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size)
+        inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)

     def get_logits(input_ids, inference_params):
-        decoding = inference_params.sequence_len_offset > 0
+        decoding = inference_params.seqlen_offset > 0
         if decoding:
             position_ids = torch.full(
                 (batch_size, 1),
-                inference_params.sequence_len_offset,
+                inference_params.seqlen_offset,
                 dtype=torch.long,
                 device=input_ids.device,
             )
@@ -154,24 +158,24 @@ def decode(
             ).logits.squeeze(dim=1)
         else:
             logits = model._decoding_cache.run(
-                input_ids, position_ids, inference_params.sequence_len_offset
+                input_ids, position_ids, inference_params.seqlen_offset
             ).clone()
         return logits[..., :vocab_size] if vocab_size is not None else logits

     def sample_tokens(logits, inference_params):
-        if teacher_outputs is None or teacher_output_len <= inference_params.sequence_len_offset:
+        if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset:
             token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
         else:
-            token = teacher_outputs[:, inference_params.sequence_len_offset]
+            token = teacher_outputs[:, inference_params.seqlen_offset]
         # return rearrange(token, "b -> b 1")
         return token.unsqueeze(1)

     def should_stop(current_token, inference_params):
-        if inference_params.sequence_len_offset == 0:
+        if inference_params.seqlen_offset == 0:
             return False
         if eos_token_id is not None and (current_token == eos_token_id).all():
             return True
-        if inference_params.sequence_len_offset >= max_length - 1:
+        if inference_params.seqlen_offset >= max_length - 1:
             return True
         return False

@@ -185,7 +189,7 @@ def decode(
     scores, sequences = [], [input_ids]
     while not should_stop(sequences[-1], inference_params):
         scores.append(get_logits(sequences[-1], inference_params))
-        inference_params.sequence_len_offset += sequences[-1].shape[1]
+        inference_params.seqlen_offset += sequences[-1].shape[1]
         sequences.append(sample_tokens(scores[-1], inference_params))
     if enable_timing:
         end.record()
@@ -256,6 +260,7 @@ def sample_speculative(logits, logits_draft, tokens_draft, top_k=1, top_p=0.0, t
     return tokens, first_rejected_idx + 1


+@torch.inference_mode()
 def decode_speculative(
     input_ids,
     model,
@@ -303,15 +308,11 @@ def decode_speculative(
             tensor_parallel=tensor_parallel,
         )
         inference_params_draft = model_draft._decoding_cache.inference_params
-        inference_params_draft.max_sequence_len = max_length
-        inference_params_draft.max_batch_size = batch_size
-        inference_params_draft.sequence_len_offset = 0
-        inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size)
+        inference_params_draft.reset(max_length, batch_size)
+        inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)
     else:
-        inference_params_draft = InferenceParams(
-            max_sequence_len=max_length, max_batch_size=batch_size
-        )
-        inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size)
+        inference_params_draft = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)
+        inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)

     def logits_forward_fn(model, input_ids, position_ids, inference_params, cg=False):
         if not cg:
@@ -323,7 +324,7 @@ def decode_speculative(
             ).logits.squeeze(dim=1)
         else:
             return model._decoding_cache.run(
-                input_ids, position_ids, inference_params.sequence_len_offset
+                input_ids, position_ids, inference_params.seqlen_offset
             ).clone()

     logits_postprocess_fn = (
@@ -365,13 +366,13 @@ def decode_speculative(
             assert seqlen == 1
             position_ids = repeat(
                 torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
-                + inference_params.sequence_len_offset,
+                + inference_params.seqlen_offset,
                 "s -> b s",
                 b=batch_size,
             )
             # position_ids = torch.full(
             #     (batch_size, 1),
-            #     inference_params.sequence_len_offset,
+            #     inference_params.seqlen_offset,
             #     dtype=torch.long,
             #     device=input_ids.device,
             # )
@@ -380,7 +381,7 @@ def decode_speculative(
         logits = logits_postprocess_fn(
             logits_forward_fn(model, input_ids, position_ids, inference_params, cg=decoding and cg)
         )
-        inference_params.sequence_len_offset += input_ids.shape[1]
+        inference_params.seqlen_offset += input_ids.shape[1]
         scores = [logits]
         next_token = sample_fn(logits)
         sequences.append(next_token)
@@ -388,7 +389,7 @@ def decode_speculative(
             if i < num_tokens - 1 or last_token_logits:
                 position_ids = torch.full(
                     (batch_size, 1),
-                    inference_params_draft.sequence_len_offset,
+                    inference_params_draft.seqlen_offset,
                     dtype=torch.long,
                     device=input_ids.device,
                 )
@@ -401,7 +402,7 @@ def decode_speculative(
                         cg=cg,
                     )
                 )
-                inference_params.sequence_len_offset += 1
+                inference_params.seqlen_offset += 1
                 scores.append(logits)
                 if i < num_tokens - 1:
                     next_token = sample_fn(logits)
@@ -476,8 +477,8 @@ def decode_speculative(
         scores.append(logits[:1, : num_generated_tokens[0]])
         # Note that @model has not evaluated the last sampled token yet, so we'll need to pass
         # that in the next time we call @model.
-        inference_params.sequence_len_offset = seqlen_og + num_generated_tokens[0].item() - 1
-        inference_params_draft.sequence_len_offset = inference_params.sequence_len_offset
+        inference_params.seqlen_offset = seqlen_og + num_generated_tokens[0].item() - 1
+        inference_params_draft.seqlen_offset = inference_params.seqlen_offset
         if debug:
             cur_ids = torch.cat([input_ids, sequences[-1]], dim=1)
             scores_ref = model(
@@ -486,10 +487,10 @@ def decode_speculative(
             print((scores[-1] - scores_ref[:, :-1]).abs().max())

     while True:
-        # sequence_len_offset is total length generated - 1
-        if inference_params.sequence_len_offset >= max_length - 1:
+        # seqlen_offset is total length generated - 1
+        if inference_params.seqlen_offset >= max_length - 1:
             break
-        if inference_params.sequence_len_offset >= max_length - 2:
+        if inference_params.seqlen_offset >= max_length - 2:
             # Don't do speculative sampling, just sample 1 token from the model
             tokens, scores_new = sample_tokens_main(sequences[-1][:, -1:], num_tokens=1)
             sequences.append(tokens)
@@ -497,7 +498,7 @@ def decode_speculative(
             break
         # Sample from draft model
         n_spec_tokens = min(
-            speculative_lookahead, max_length - inference_params_draft.sequence_len_offset - 2
+            speculative_lookahead, max_length - inference_params_draft.seqlen_offset - 2
         )
         tokens_draft, scores_draft = sample_tokens_draft(
             sequences[-1][:, -1:], num_tokens=n_spec_tokens
@@ -510,9 +511,9 @@ def decode_speculative(
         # Evaluate the draft tokens with the model
         position_ids = repeat(
             torch.arange(
-                inference_params.sequence_len_offset,
+                inference_params.seqlen_offset,
                 # 1 extra token from last time that hasn't been passed through model
-                inference_params.sequence_len_offset + n_spec_tokens + 1,
+                inference_params.seqlen_offset + n_spec_tokens + 1,
                 dtype=torch.long,
                 device=input_ids.device,
             ),
@@ -525,7 +526,7 @@ def decode_speculative(
             inference_params=inference_params,
         ).logits  # (batch, n_spec_tokens, vocab_size)
         logits = logits_postprocess_fn(logits)
-        inference_params.sequence_len_offset += 1
+        inference_params.seqlen_offset += 1
         if debug:
             logits_ref = model(
                 torch.cat([cur_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
@@ -539,8 +540,8 @@ def decode_speculative(
             print(num_generated_tokens)
         sequences.append(tokens[:1, : num_generated_tokens[0]])
         scores.append(logits[:1, : num_generated_tokens[0]])
-        inference_params.sequence_len_offset += num_generated_tokens[0].item() - 1
-        inference_params_draft.sequence_len_offset = inference_params.sequence_len_offset
+        inference_params.seqlen_offset += num_generated_tokens[0].item() - 1
+        inference_params_draft.seqlen_offset = inference_params.seqlen_offset
         # breakpoint()
         if debug:
             cur_ids = torch.cat([cur_ids, sequences[-1]], dim=1)
@@ -679,9 +680,9 @@ def update_graph_cache(
         )
     lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
     cache.inference_params = InferenceParams(
-        max_sequence_len=max_seqlen,
+        max_seqlen=max_seqlen,
         max_batch_size=batch_size,
-        sequence_len_offset=seqlen_og,
+        seqlen_offset=seqlen_og,
         key_value_memory_dict=inf_cache,
         lengths_per_sample=lengths_per_sample,
     )
@@ -705,7 +706,7 @@ def update_graph_cache(
         )

     cache.run = dispatch
-    cache.inference_params.sequence_len_offset = 0  # Reset so it's not confusing
+    cache.inference_params.seqlen_offset = 0  # Reset so it's not confusing
     return cache


@@ -713,10 +714,10 @@ def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None,
     device = next(iter(model.parameters())).device
     input_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
     position_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
-    sequence_len_offset_og = inference_params.sequence_len_offset
+    seqlen_offset_og = inference_params.seqlen_offset
     # TD [2023-04-14]: important for correctness of the FT's attention kernel, as seqlen_cpu is
     # used to determine the size of smem. Hence seqlen_cpu must be >= lengths_per_sample.
-    inference_params.sequence_len_offset = max_seqlen - 1
+    inference_params.seqlen_offset = max_seqlen - 1
     inference_params.lengths_per_sample[:] = max_seqlen - 1

     # Warmup before capture
@@ -755,5 +756,5 @@ def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None,
         graph.replay()
         return logits.clone()

-    inference_params.sequence_len_offset = sequence_len_offset_og
+    inference_params.seqlen_offset = seqlen_offset_og
     return run
diff --git a/tests/models/test_gpt.py b/tests/models/test_gpt.py
index 2e18451..09c6556 100644
--- a/tests/models/test_gpt.py
+++ b/tests/models/test_gpt.py
@@ -364,14 +364,14 @@ def test_gpt2_multiple_token_generation(model_name, optimized):
     logits_ref = model(input_ids).logits

     # Run 10 tokens, then pass in another 4, then another 6, to see if we get the same logits
-    inference_params = InferenceParams(max_sequence_len=20, max_batch_size=1)
+    inference_params = InferenceParams(max_seqlen=20, max_batch_size=1)
     logits_10 = model(input_ids[:, :10], inference_params=inference_params).logits
-    inference_params.sequence_len_offset += 10
+    inference_params.seqlen_offset += 10
     position_ids = torch.arange(10, 14, dtype=torch.long, device=device)
     logits_1014 = model(
         input_ids[:, 10:14], position_ids=position_ids, inference_params=inference_params
     ).logits
-    inference_params.sequence_len_offset += 4
+    inference_params.seqlen_offset += 4
     position_ids = torch.arange(14, 20, dtype=torch.long, device=device)
     logits_1420 = model(
         input_ids[:, 14:20], position_ids=position_ids, inference_params=inference_params
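
Usage sketch (not from the commit itself): a minimal example of the renamed InferenceParams API, assuming a flash-attn build that includes this patch. The model forward calls are hypothetical placeholders and left commented out; only the InferenceParams bookkeeping mirrors the code changed above.

    from flash_attn.utils.generation import InferenceParams

    max_seqlen, batch_size, prompt_len = 256, 2, 16
    # Renamed fields: max_seqlen (was max_sequence_len), seqlen_offset (was sequence_len_offset).
    inference_params = InferenceParams(max_seqlen=max_seqlen, max_batch_size=batch_size)

    # Prefill: the whole prompt goes through the model once, then the offset advances
    # by the prompt length (as in decode() above).
    # logits = model(prompt_ids, inference_params=inference_params).logits
    inference_params.seqlen_offset += prompt_len

    # Decode: one token per step, advancing seqlen_offset by 1 so the KV cache is
    # written at the correct position.
    for _ in range(8):
        # logits = model(next_token, position_ids=..., inference_params=inference_params).logits
        inference_params.seqlen_offset += 1

    # Reuse the same object (and its key_value_memory_dict) for a new prompt via the
    # reset() helper added in this commit.
    inference_params.reset(max_seqlen, batch_size)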