diff --git a/flash_attn/models/gpt.py b/flash_attn/models/gpt.py
index 3edcf63..77744b4 100644
--- a/flash_attn/models/gpt.py
+++ b/flash_attn/models/gpt.py
@@ -426,20 +426,24 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
         if self.process_group is not None:
             sync_shared_params(self, self.process_group)
 
-    def forward(self, input_ids, position_ids=None, inference_params=None):
+    def forward(self, input_ids, position_ids=None, inference_params=None, last_token_only=False):
         """
             inference_params: for generation. Adapted from Megatron-LM (and Apex)
             https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
+            last_token_only: whether to return the logit for the last token only,
+                of shape (batch_size, vocab_size)
         """
         hidden_states = self.transformer(input_ids, position_ids=position_ids,
                                          inference_params=inference_params)
+        if last_token_only:
+            hidden_states = hidden_states[:, -1]
         if self.project_out is not None:
             hidden_states = self.project_out(hidden_states)
         lm_logits = self.lm_head(hidden_states)
         # During inference, we want the full logit for sampling
         if isinstance(self.lm_head, ColumnParallelLinear) and inference_params is not None:
             lm_logits, _ = all_gather_raw(lm_logits, self.lm_head.process_group)
-            lm_logits = rearrange(lm_logits, '(n b) s d -> b s (n d)', b=hidden_states.shape[0])
+            lm_logits = rearrange(lm_logits, '(n b) ... d -> b ... (n d)', b=hidden_states.shape[0])
         CausalLMOutput = namedtuple('CausalLMOutput', ['logits'])
         return CausalLMOutput(logits=lm_logits)
 
diff --git a/flash_attn/utils/generation.py b/flash_attn/utils/generation.py
index ce4367b..6527534 100644
--- a/flash_attn/utils/generation.py
+++ b/flash_attn/utils/generation.py
@@ -112,7 +112,7 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
                 torch.distributed.barrier()
             torch.cuda.synchronize()
             start = time.time()
-        logits = model(input_ids, inference_params=inference_params).logits[:, -1]
+        logits = model(input_ids, inference_params=inference_params, last_token_only=True).logits
         if vocab_size is not None:
             logits = logits[..., :vocab_size]
         scores.append(logits if not cg else logits.clone())
@@ -127,7 +127,7 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
                                        dtype=torch.long, device=input_ids.device)
             if not cg:
                 logits = model(rearrange(next_token, 'b -> b 1'), position_ids=position_ids,
-                               inference_params=inference_params).logits[:, -1]
+                               inference_params=inference_params, last_token_only=True).logits
             else:
                 logits = model._decoding_cache.run(rearrange(next_token, 'b -> b 1'), position_ids,
                                                    inference_params.sequence_len_offset)
@@ -269,8 +269,8 @@ def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None,
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(n_warmups):
-           logits = model(input_ids, position_ids=position_ids,
-                          inference_params=inference_params).logits[:, -1]
+           logits = model(input_ids, position_ids=position_ids, inference_params=inference_params,
+                          last_token_only=True).logits
        s.synchronize()
        # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
        # which requires that graph launch and non-captured launch to not overlap (I think,
@@ -282,8 +282,8 @@ def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None,
    # To allow capture, automatically sets a side stream as the current stream in the context
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=mempool):
-       logits = model(input_ids, position_ids=position_ids,
-                      inference_params=inference_params).logits[:, -1]
+       logits = model(input_ids, position_ids=position_ids, inference_params=inference_params,
+                      last_token_only=True).logits
 
    def run(new_input_ids, new_position_ids, seqlen):
        inference_params.lengths_per_sample[:] = seqlen
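
Usage sketch for the new `last_token_only` flag (not part of this diff; `model` and `input_ids` below are assumed to be an already-constructed GPTLMHeadModel and a (batch_size, seqlen) LongTensor). The flag slices the hidden states down to the final position before project_out / lm_head, so the returned logits have shape (batch_size, vocab_size) instead of (batch_size, seqlen, vocab_size):

import torch

with torch.inference_mode():
    # Full logits for every position: (batch_size, seqlen, vocab_size)
    all_logits = model(input_ids).logits
    # With last_token_only=True, only the last position is projected,
    # so logits come back as (batch_size, vocab_size).
    last_logits = model(input_ids, last_token_only=True).logits

# The two should agree up to numerical noise.
assert torch.allclose(all_logits[:, -1], last_logits, atol=1e-3)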