From 63670fd84a03bb448eb681806122c68e91e1b83e Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 27 Dec 2022 20:58:50 -0800 Subject: [PATCH] Implement generation for GPT --- flash_attn/models/gpt.py | 18 +++-- flash_attn/modules/mha.py | 106 +++++++++++++++++++++------- flash_attn/utils/generation.py | 64 +++++++++++++++++ tests/models/test_gpt.py | 10 --- tests/models/test_gpt_generation.py | 83 ++++++++++++++++++++++ 5 files changed, 242 insertions(+), 39 deletions(-) create mode 100644 flash_attn/utils/generation.py create mode 100644 tests/models/test_gpt_generation.py diff --git a/flash_attn/models/gpt.py b/flash_attn/models/gpt.py index 1ba58f6..0cb1b06 100644 --- a/flash_attn/models/gpt.py +++ b/flash_attn/models/gpt.py @@ -20,6 +20,7 @@ from flash_attn.modules.block import Block from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings from flash_attn.utils.distributed import sync_sequence_parallel_params from flash_attn.utils.pretrained import state_dict_from_pretrained +from flash_attn.utils.generation import GenerationMixin try: from flash_attn.ops.fused_dense import ColumnParallelLinear @@ -61,7 +62,7 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt if process_group is None else {}) parallel_kwargs = {'process_group': process_group} if process_group is not None else {} mixer_cls = partial(mha_cls, num_heads=config.num_attention_heads, dropout=config.attn_pdrop, - softmax_scale=softmax_scale, causal=True, + softmax_scale=softmax_scale, causal=True, layer_idx=layer_idx, rotary_emb_dim=rotary_emb_dim, rotary_emb_scale_base=rotary_emb_scale_base, use_flash_attn=use_flash_attn, **serial_kwargs, **parallel_kwargs, **factory_kwargs) @@ -220,7 +221,7 @@ class GPTModel(GPTPreTrainedModel): if self.process_group is not None: sync_sequence_parallel_params(self, self.process_group) - def forward(self, input_ids, position_ids=None): + def forward(self, input_ids, position_ids=None, inference_params=None): # If using Tensor Parallel with sequence parallel, we combine the batch and the seqlen # dimensions so that we can split on it easily, in case of small batch size. # Only the attention layers need to know the seqlen. @@ -238,12 +239,14 @@ class GPTModel(GPTPreTrainedModel): residual_in_fp32=True ) mixer_kwargs = ({'seqlen': input_ids.shape[1]} if self.process_group is not None else {}) + if inference_params is not None: + mixer_kwargs['inference_params'] = inference_params for layer in self.layers: hidden_states, residual = layer(hidden_states, residual, mixer_kwargs=mixer_kwargs) return hidden_states -class GPTLMHeadModel(GPTPreTrainedModel): +class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin): def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None): factory_kwargs = {'device': device, 'dtype': dtype} @@ -267,8 +270,13 @@ class GPTLMHeadModel(GPTPreTrainedModel): def tie_weights(self): self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight - def forward(self, input_ids, position_ids=None): - hidden_states = self.transformer(input_ids, position_ids=position_ids) + def forward(self, input_ids, position_ids=None, inference_params=None): + """ + inference_params: for generation. 
Adapted from Megatron-LM (and Apex) + https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470 + """ + hidden_states = self.transformer(input_ids, position_ids=position_ids, + inference_params=inference_params) lm_logits = self.lm_head(hidden_states) CausalLMOutput = namedtuple('CausalLMOutput', ['logits']) return CausalLMOutput(logits=lm_logits) diff --git a/flash_attn/modules/mha.py b/flash_attn/modules/mha.py index 6bc3321..97755bc 100644 --- a/flash_attn/modules/mha.py +++ b/flash_attn/modules/mha.py @@ -53,7 +53,7 @@ class FlashSelfAttention(nn.Module): self.dropout_p = attention_dropout self.triton = triton - def forward(self, qkv, cu_seqlens=None, max_seqlen=None): + def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None): """Implements the multihead softmax attention. Arguments --------- @@ -61,6 +61,7 @@ class FlashSelfAttention(nn.Module): If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D). If cu_seqlens is not None and max_seqlen is not None, then qkv has shape (total, 3, H, D), where total is the sum of the sequence lengths in the batch. + causal: if passed, will override self.causal cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, used to index into qkv. max_seqlen: int. Maximum sequence length in the batch. @@ -71,6 +72,7 @@ class FlashSelfAttention(nn.Module): """ assert qkv.dtype in [torch.float16, torch.bfloat16] assert qkv.is_cuda + causal = self.causal if causal is None else causal unpadded = cu_seqlens is not None if unpadded: assert cu_seqlens.dtype == torch.int32 @@ -78,13 +80,13 @@ class FlashSelfAttention(nn.Module): assert isinstance(max_seqlen, int) return flash_attn_unpadded_qkvpacked_func( qkv, cu_seqlens, max_seqlen, self.dropout_p if self.training else 0.0, - softmax_scale=self.softmax_scale, causal=self.causal + softmax_scale=self.softmax_scale, causal=causal ) else: batch_size, seqlen = qkv.shape[0], qkv.shape[1] # Triton version doesn't support dropout if self.triton and (self.dropout_p == 0 or not self.training): - output = flash_attn_qkvpacked_func(qkv, None, self.causal, self.softmax_scale) + output = flash_attn_qkvpacked_func(qkv, None, causal, self.softmax_scale) else: qkv = rearrange(qkv, 'b s ... -> (b s) ...') max_seqlen = seqlen @@ -92,7 +94,7 @@ class FlashSelfAttention(nn.Module): device=qkv.device) output = flash_attn_unpadded_qkvpacked_func( qkv, cu_seqlens, max_seqlen, self.dropout_p if self.training else 0.0, - softmax_scale=self.softmax_scale, causal=self.causal + softmax_scale=self.softmax_scale, causal=causal ) output = rearrange(output, '(b s) ... -> b s ...', b=batch_size) return output @@ -120,12 +122,14 @@ class FlashCrossAttention(nn.Module): self.dropout_p = attention_dropout self.triton = triton - def forward(self, q, kv, cu_seqlens=None, max_seqlen=None, cu_seqlens_k=None, max_seqlen_k=None): + def forward(self, q, kv, causal=None, cu_seqlens=None, max_seqlen=None, + cu_seqlens_k=None, max_seqlen_k=None): """Implements the multihead softmax attention. Arguments --------- q: The tensor containing the query. (B, Sq, H, D) kv: The tensor containing the key and value. (B, Sk, 2, H, D) + causal: if passed, will override self.causal cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, used to index into q. max_seqlen: int. Maximum sequence length in the batch of q. 
@@ -135,6 +139,7 @@ class FlashCrossAttention(nn.Module): """ assert q.dtype in [torch.float16, torch.bfloat16] assert q.is_cuda and kv.is_cuda + causal = self.causal if causal is None else causal unpadded = cu_seqlens is not None if unpadded: assert cu_seqlens.dtype == torch.int32 @@ -147,14 +152,14 @@ class FlashCrossAttention(nn.Module): return flash_attn_unpadded_kvpacked_func( q, kv, cu_seqlens, cu_seqlens_k, max_seqlen, max_seqlen_k, self.dropout_p if self.training else 0.0, - softmax_scale=self.softmax_scale, causal=self.causal + softmax_scale=self.softmax_scale, causal=causal ) else: batch_size, seqlen_q = q.shape[0], q.shape[1] seqlen_k = kv.shape[1] assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3] if self.triton and (self.dropout_p == 0.0 or not self.training): # Triton version doesn't support dropout - output = flash_attn_kvpacked_func(q, kv, None, self.causal, self.softmax_scale) + output = flash_attn_kvpacked_func(q, kv, None, causal, self.softmax_scale) else: q = rearrange(q, 'b s ... -> (b s) ...') kv = rearrange(kv, 'b s ... -> (b s) ...') @@ -165,7 +170,7 @@ class FlashCrossAttention(nn.Module): output = flash_attn_unpadded_kvpacked_func( q, kv, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k, self.dropout_p if self.training else 0.0, - softmax_scale=self.softmax_scale, causal=self.causal + softmax_scale=self.softmax_scale, causal=causal ) output = rearrange(output, '(b s) ... -> b s ...', b=batch_size) return output @@ -187,15 +192,17 @@ class SelfAttention(nn.Module): self.softmax_scale = softmax_scale self.dropout_p = attention_dropout - def forward(self, qkv, key_padding_mask=None): + def forward(self, qkv, causal=None, key_padding_mask=None): """Implements the multihead softmax attention. Arguments --------- qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) + causal: if passed, will override self.causal key_padding_mask: boolean mask to apply to the attention weights. True means to keep, False means to mask out. (B, S) """ batch_size, seqlen = qkv.shape[0], qkv.shape[1] + causal = self.causal if causal is None else causal q, k, v = qkv.unbind(dim=2) softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1]) scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale) @@ -205,7 +212,7 @@ class SelfAttention(nn.Module): padding_mask.masked_fill_(key_padding_mask, 0.0) # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) scores = scores + rearrange(padding_mask, 'b s -> b 1 1 s') - if self.causal: + if causal: # "triu_tril_cuda_template" not implemented for 'BFloat16' # So we have to construct the mask in float causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1) @@ -233,16 +240,18 @@ class CrossAttention(nn.Module): self.softmax_scale = softmax_scale self.dropout_p = attention_dropout - def forward(self, q, kv, key_padding_mask=None): + def forward(self, q, kv, causal=None, key_padding_mask=None): """Implements the multihead softmax attention. Arguments --------- q: The tensor containing the query. (B, Sq, H, D) kv: The tensor containing the key and value. (B, Sk, 2, H, D) + causal: if passed, will override self.causal key_padding_mask: boolean mask to apply to the attention weights. True means to keep, False means to mask out. 
(B, Sk) """ batch_size, seqlen_q = q.shape[0], q.shape[1] + causal = self.causal if causal is None else causal seqlen_k = kv.shape[1] assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3] k, v = kv.unbind(dim=2) @@ -254,7 +263,7 @@ class CrossAttention(nn.Module): padding_mask.masked_fill_(key_padding_mask, 0.0) # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) scores = scores + rearrange(padding_mask, 'b s -> b 1 1 s') - if self.causal: + if causal: # "triu_tril_cuda_template" not implemented for 'BFloat16' # So we have to construct the mask in float causal_mask = torch.triu(torch.full((seqlen_q, seqlen_k), -10000.0, @@ -280,7 +289,7 @@ class MHA(nn.Module): """ def __init__(self, embed_dim, num_heads, cross_attn=False, bias=True, dropout=0.0, - softmax_scale=None, causal=False, dwconv=False, rotary_emb_dim=0, + softmax_scale=None, causal=False, layer_idx=None, dwconv=False, rotary_emb_dim=0, rotary_emb_scale_base=0, fused_bias_fc=False, use_flash_attn=False, return_residual=False, checkpointing=False, device=None, dtype=None) -> None: @@ -294,6 +303,7 @@ class MHA(nn.Module): self.embed_dim = embed_dim self.cross_attn = cross_attn self.causal = causal + self.layer_idx = layer_idx self.dwconv = dwconv self.rotary_emb_dim = rotary_emb_dim self.use_flash_attn = use_flash_attn @@ -315,6 +325,8 @@ class MHA(nn.Module): linear_cls = nn.Linear if not fused_bias_fc else FusedDense linear_resid_cls = (LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)) + inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention + inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention if not self.cross_attn: if not self.return_residual: self.Wqkv = linear_cls(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs) @@ -323,7 +335,6 @@ class MHA(nn.Module): if self.dwconv: self.dwconv_qkv = nn.Conv1d(3 * embed_dim, 3 * embed_dim, kernel_size=3, padding=2, groups=3 * embed_dim) - inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention else: self.Wq = linear_cls(embed_dim, embed_dim, bias=bias, **factory_kwargs) if not self.return_residual: @@ -335,14 +346,41 @@ class MHA(nn.Module): groups=embed_dim) self.dwconv_kv = nn.Conv1d(2 * embed_dim, 2 * embed_dim, kernel_size=3, padding=2, groups=2 * embed_dim) - inner_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout) + self.inner_cross_attn = inner_cross_attn_cls(causal=causal, softmax_scale=softmax_scale, + attention_dropout=dropout) # output projection always have the bias (for now) self.out_proj = linear_cls(embed_dim, embed_dim, **factory_kwargs) + def _update_kv_cache(self, kv, inference_params): + """kv: (batch_size, seqlen, 2, nheads, head_dim) + """ + assert not self.dwconv, 'Generation does not support dwconv yet' + assert self.layer_idx is not None, 'Generation requires layer_idx in the constructor' + # Pre-allocate memory for key-values for inference.
+ if self.layer_idx not in inference_params.key_value_memory_dict: + inference_kv_cache = torch.empty( + inference_params.max_batch_size, inference_params.max_sequence_len, 2, + self.num_heads, self.head_dim, dtype=kv.dtype, device=kv.device + ) + inference_params.key_value_memory_dict[self.layer_idx] = inference_kv_cache + else: + inference_kv_cache = inference_params.key_value_memory_dict[self.layer_idx] + # Adjust key and value for inference + batch_start = inference_params.batch_size_offset + batch_end = batch_start + kv.shape[0] + assert batch_end <= inference_kv_cache.shape[0] + sequence_start = inference_params.sequence_len_offset + sequence_end = sequence_start + kv.shape[1] + assert sequence_end <= inference_kv_cache.shape[1] + # Copy key and values. + inference_kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv + kv = inference_kv_cache[batch_start:batch_end, :sequence_end, ...] + return kv + def forward(self, x, x_kv=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None, - **kwargs): + inference_params=None, **kwargs): """ Arguments: x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if @@ -355,6 +393,8 @@ class MHA(nn.Module): max_seqlen: int. Maximum sequence length in the batch. key_padding_mask: boolean mask, True means to keep, False means to mask out. (batch, seqlen). Only applicable when not using FlashAttention. + inference_params: for generation. Adapted from Megatron-LM (and Apex) + https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470 """ if cu_seqlens is not None: assert max_seqlen is not None @@ -366,6 +406,10 @@ class MHA(nn.Module): assert cu_seqlens is None assert max_seqlen is None assert not self.use_flash_attn + if inference_params is not None: + assert key_padding_mask is None + assert cu_seqlens is None and max_seqlen is None + assert not self.dwconv kwargs = ({'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen, **kwargs} if self.use_flash_attn else {'key_padding_mask': key_padding_mask, **kwargs}) @@ -378,12 +422,22 @@ class MHA(nn.Module): qkv = rearrange(self.dwconv_qkv(rearrange(qkv, 'b s d -> b d s'))[..., :-2], 'b d s -> b s d').contiguous() qkv = rearrange(qkv, '... (three h d) -> ... three h d', three=3, d=self.head_dim) - if self.rotary_emb_dim > 0: - qkv = self.rotary_emb(qkv) - if not self.checkpointing: - context = self.inner_attn(qkv, **kwargs) + if inference_params is None: + if self.rotary_emb_dim > 0: + qkv = self.rotary_emb(qkv) + if not self.checkpointing: + context = self.inner_attn(qkv, **kwargs) + else: + context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs) else: - context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs) + if self.rotary_emb_dim > 0: + qkv = self.rotary_emb(qkv, seqlen_offset=inference_params.sequence_len_offset) + q = qkv[:, :, 0] + kv = self._update_kv_cache(qkv[:, :, 1:], inference_params) + # If we're processing the prompt, causal=None (use self.causal). + # If we're decoding, then causal=False. 
+ causal = None if inference_params.sequence_len_offset == 0 else False + context = self.inner_cross_attn(q, kv, causal=causal) else: if not self.return_residual: q = self.Wq(x) @@ -401,10 +455,14 @@ class MHA(nn.Module): 'b d s -> b s d').contiguous() kv = rearrange(self.dwconv_kv(rearrange(kv, 'b s d -> b d s'))[..., :-2], 'b d s -> b s d').contiguous() - if not self.checkpointing: - context = self.inner_attn(q, kv, **kwargs) + if inference_params is None: + if not self.checkpointing: + context = self.inner_attn(q, kv, **kwargs) + else: + context = torch.utils.checkpoint.checkpoint(self.inner_attn, q, kv, **kwargs) else: - context = torch.utils.checkpoint.checkpoint(self.inner_attn, q, kv, **kwargs) + kv = self._update_kv_cache(kv, inference_params) + context = self.inner_cross_attn(q, kv, causal=False) out = self.out_proj(rearrange(context, '... h d -> ... (h d)')) return out if not self.return_residual else (out, x) diff --git a/flash_attn/utils/generation.py b/flash_attn/utils/generation.py new file mode 100644 index 0000000..13f7fca --- /dev/null +++ b/flash_attn/utils/generation.py @@ -0,0 +1,64 @@ +# Copyright (c) 2022, Tri Dao. +# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31 +from dataclasses import dataclass, field +import torch + +from einops import rearrange + +from transformers.generation import GreedySearchDecoderOnlyOutput + + +@dataclass +class InferenceParams: + """Inference parameters that are passed to the main model in order + to efficiently calculate and store the context during inference.""" + max_sequence_len: int + max_batch_size: int + sequence_len_offset: int = 0 + batch_size_offset: int = 0 + key_value_memory_dict: dict = field(default_factory=dict) + + +def greedy_decode(input_ids, model, max_length): + """Greedy decoding. This is a very simple implementation. + We assume that all sequences in the same batch have the same length.
+ Arguments: + input_ids: (batch, seq_len) + max_length: int + Returns: GreedySearchDecoderOnlyOutput, with the following fields: + sequences: (batch, max_length) + scores: tuples of (batch, vocab_size) + """ + batch_size, seqlen_og = input_ids.shape + inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size) + scores = [] + with torch.inference_mode(): + logits = model(input_ids, inference_params=inference_params).logits[:, -1] + scores.append(logits) + next_token = logits.argmax(dim=-1) + sequences = [next_token] + inference_params.sequence_len_offset = seqlen_og + while True: + position_ids = torch.full((batch_size, 1), inference_params.sequence_len_offset, + dtype=torch.long, device=input_ids.device) + logits = model(rearrange(next_token, 'b -> b 1'), position_ids=position_ids, + inference_params=inference_params).logits[:, -1] + scores.append(logits) + next_token = logits.argmax(dim=-1) + sequences.append(next_token) + inference_params.sequence_len_offset += 1 + if inference_params.sequence_len_offset >= max_length - 1: + break + return GreedySearchDecoderOnlyOutput( + sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1), + scores=tuple(scores) + ) + + +class GenerationMixin: + + def generate(self, input_ids, max_length, return_dict_in_generate=False, output_scores=False): + output = greedy_decode(input_ids, self, max_length) + if not output_scores: + output.scores = None + return output if return_dict_in_generate else output.sequences diff --git a/tests/models/test_gpt.py b/tests/models/test_gpt.py index 0520afa..2f5e777 100644 --- a/tests/models/test_gpt.py +++ b/tests/models/test_gpt.py @@ -23,16 +23,6 @@ def test_gpt2_state_dict(model_name): assert state_dict[k].shape == pretrained_state_dict[k].shape -def get_hf_models(model_name, config, dtype): - pretrained_state_dict = state_dict_from_pretrained(model_name) - model_hf = GPT2LMHeadModelHF(config) - # Missing key(s) in state_dict: "bert.embeddings.position_ids", "cls.predictions.decoder.bias" - # position_ids is a buffer, and predictions.decoder.bias is tied to predictions.bias. - model_hf.load_state_dict(pretrained_state_dict, strict=False) - model_hf.cuda().to(dtype=dtype) - return model_hf - - @pytest.mark.parametrize('model_name', ["gpt2", "gpt2-medium"]) # @pytest.mark.parametrize('model_name', ["gpt2"]) def test_gpt2_non_optimized(model_name): diff --git a/tests/models/test_gpt_generation.py b/tests/models/test_gpt_generation.py new file mode 100644 index 0000000..0aba58f --- /dev/null +++ b/tests/models/test_gpt_generation.py @@ -0,0 +1,83 @@ +import re + +import torch +import pytest + +from einops import rearrange + +from transformers import GPT2Config, GPT2Tokenizer +from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF + +from flash_attn.models.gpt import GPTLMHeadModel +from flash_attn.models.gpt import remap_state_dict_gpt2 +from flash_attn.utils.pretrained import state_dict_from_pretrained +from flash_attn.utils.generation import greedy_decode + + +# TODO: test with rotary embedding +@pytest.mark.parametrize('optimized', [False, True]) +# @pytest.mark.parametrize('optimized', [False]) +@pytest.mark.parametrize('model_name', ["gpt2"]) +def test_greedy_decode(model_name, optimized): + """Check that our implementation of GPT2 generation matches the HF implementation: + the scores in fp16 should be around the same as the HF scores in fp16, when compared to + the HF scores in fp32. 
+ """ + dtype = torch.float16 + rtol, atol = 3e-3, 3e-1 + config = GPT2Config.from_pretrained(model_name) + if optimized: + config.use_flash_attn = True + config.fused_bias_fc = True + config.fused_dense_gelu_dense = True + config.fused_dropout_add_ln = True + + model = GPTLMHeadModel.from_pretrained(model_name, config) + model = model.cuda().to(dtype=dtype) + + model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).cuda() + model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).cuda().to(dtype=dtype) + + model.eval() + model_ref.eval() + model_hf.eval() + + torch.manual_seed(0) + tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + input_ids = tokenizer("Hello, my dog is cute and ", return_tensors="pt").input_ids.cuda() + max_length = 30 + + # Slow generation for reference + sequences = [] + scores = [] + cur_input_ids = input_ids + with torch.inference_mode(): + scores.append(model(cur_input_ids).logits[:, -1]) + sequences.append(scores[-1].argmax(dim=-1)) + for _ in range(input_ids.shape[1] + 1, max_length): + cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], 'b -> b 1')], dim=-1) + scores.append(model(cur_input_ids).logits[:, -1]) + sequences.append(scores[-1].argmax(dim=-1)) + sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1) + scores = tuple(scores) + + out = model.generate(input_ids=input_ids, max_length=max_length, + return_dict_in_generate=True, output_scores=True) + + out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length, + return_dict_in_generate=True, output_scores=True) + out_ref = model_ref.generate(input_ids=input_ids, max_length=max_length, + return_dict_in_generate=True, output_scores=True) + + print(f'Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}') + print(f'Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}') + print(f'HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}') + print(f'HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}') + + assert torch.all(out.sequences == sequences) + assert torch.allclose(torch.stack(out.scores, dim=1), torch.stack(scores, dim=1), + rtol=rtol, atol=atol) + assert torch.all(out.sequences == out_ref.sequences) + assert torch.all(out.sequences == out_hf.sequences) + + assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()
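
For reference, a minimal usage sketch of the generation API this patch adds (it mirrors tests/models/test_gpt_generation.py; the "gpt2" checkpoint, fp16 dtype, and prompt are taken from that test and are not part of the patch itself):

import torch
from transformers import GPT2Config, GPT2Tokenizer
from flash_attn.models.gpt import GPTLMHeadModel

config = GPT2Config.from_pretrained("gpt2")
# The optimized paths (use_flash_attn, fused_bias_fc, ...) are optional, as in the test above.
model = GPTLMHeadModel.from_pretrained("gpt2", config).cuda().to(dtype=torch.float16)
model.eval()

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
input_ids = tokenizer("Hello, my dog is cute and ", return_tensors="pt").input_ids.cuda()

# GenerationMixin.generate calls greedy_decode: one pass over the prompt fills the
# per-layer KV cache (InferenceParams), then each new token is decoded with a single
# query attending to the cached keys/values.
out = model.generate(input_ids=input_ids, max_length=30,
                     return_dict_in_generate=True, output_scores=True)
print(tokenizer.decode(out.sequences[0]))

The causal handling follows from the cache layout: during the prompt pass sequence_len_offset is 0 and the inner attention keeps its causal mask, while during decoding there is only one query token, so inner_cross_attn is called with causal=False against the cached keys and values.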
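
For decoding strategies other than greedy (e.g. sampling), the InferenceParams-driven loop can also be written directly; below is a condensed sketch under the same assumptions as above (model and input_ids already on the GPU), with the argmax standing in for any token-selection rule:

from einops import rearrange
from flash_attn.utils.generation import InferenceParams

max_length = 30
batch_size, seqlen_og = input_ids.shape
inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size)
with torch.inference_mode():
    # Prompt pass: sequence_len_offset == 0, so attention is causal over the whole prompt.
    logits = model(input_ids, inference_params=inference_params).logits[:, -1]
    next_token = logits.argmax(dim=-1)  # replace with sampling if desired
    tokens = [next_token]
    inference_params.sequence_len_offset = seqlen_og
    while inference_params.sequence_len_offset < max_length - 1:
        # Decode step: a single query token attends to the cached keys/values.
        position_ids = torch.full((batch_size, 1), inference_params.sequence_len_offset,
                                  dtype=torch.long, device=input_ids.device)
        logits = model(rearrange(next_token, 'b -> b 1'), position_ids=position_ids,
                       inference_params=inference_params).logits[:, -1]
        next_token = logits.argmax(dim=-1)
        tokens.append(next_token)
        inference_params.sequence_len_offset += 1
    sequences = torch.cat([input_ids, torch.stack(tokens, dim=1)], dim=1)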