Implement generation for GPT

Tri Dao 2022-12-27 20:58:50 -08:00
parent 9d797d8848
commit 63670fd84a
5 changed files with 242 additions and 39 deletions

View File

@ -20,6 +20,7 @@ from flash_attn.modules.block import Block
from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
from flash_attn.utils.distributed import sync_sequence_parallel_params
from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.utils.generation import GenerationMixin
try:
from flash_attn.ops.fused_dense import ColumnParallelLinear
@ -61,7 +62,7 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt
if process_group is None else {})
parallel_kwargs = {'process_group': process_group} if process_group is not None else {}
mixer_cls = partial(mha_cls, num_heads=config.num_attention_heads, dropout=config.attn_pdrop,
softmax_scale=softmax_scale, causal=True,
softmax_scale=softmax_scale, causal=True, layer_idx=layer_idx,
rotary_emb_dim=rotary_emb_dim, rotary_emb_scale_base=rotary_emb_scale_base,
use_flash_attn=use_flash_attn,
**serial_kwargs, **parallel_kwargs, **factory_kwargs)
@ -220,7 +221,7 @@ class GPTModel(GPTPreTrainedModel):
if self.process_group is not None:
sync_sequence_parallel_params(self, self.process_group)
def forward(self, input_ids, position_ids=None):
def forward(self, input_ids, position_ids=None, inference_params=None):
# If using Tensor Parallel with sequence parallel, we combine the batch and the seqlen
# dimensions so that we can split on it easily, in case of small batch size.
# Only the attention layers need to know the seqlen.
@ -238,12 +239,14 @@ class GPTModel(GPTPreTrainedModel):
residual_in_fp32=True
)
mixer_kwargs = ({'seqlen': input_ids.shape[1]} if self.process_group is not None else {})
if inference_params is not None:
mixer_kwargs['inference_params'] = inference_params
for layer in self.layers:
hidden_states, residual = layer(hidden_states, residual, mixer_kwargs=mixer_kwargs)
return hidden_states
class GPTLMHeadModel(GPTPreTrainedModel):
class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None):
factory_kwargs = {'device': device, 'dtype': dtype}
@ -267,8 +270,13 @@ class GPTLMHeadModel(GPTPreTrainedModel):
def tie_weights(self):
self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight
def forward(self, input_ids, position_ids=None):
hidden_states = self.transformer(input_ids, position_ids=position_ids)
def forward(self, input_ids, position_ids=None, inference_params=None):
"""
inference_params: for generation. Adapted from Megatron-LM (and Apex)
https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
"""
hidden_states = self.transformer(input_ids, position_ids=position_ids,
inference_params=inference_params)
lm_logits = self.lm_head(hidden_states)
CausalLMOutput = namedtuple('CausalLMOutput', ['logits'])
return CausalLMOutput(logits=lm_logits)
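A condensed sketch (not part of the commit; sample_greedy is an illustrative name) of how a caller drives the new inference_params argument: one forward pass over the prompt populates the per-layer KV cache, then each step feeds a single token with an explicit position offset. It mirrors greedy_decode in flash_attn/utils/generation.py, added later in this commit.

import torch
from flash_attn.utils.generation import InferenceParams

def sample_greedy(model, input_ids, max_length):
    # Assumes all prompts in the batch have the same length, as greedy_decode does.
    batch_size, seqlen = input_ids.shape
    params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size)
    with torch.inference_mode():
        # Prompt pass: fills the KV cache for positions [0, seqlen).
        logits = model(input_ids, inference_params=params).logits[:, -1]
        next_token = logits.argmax(dim=-1, keepdim=True)
        tokens = [next_token]
        params.sequence_len_offset = seqlen
        while params.sequence_len_offset < max_length - 1:
            # Decode pass: one token, whose position is the current offset.
            position_ids = torch.full_like(next_token, params.sequence_len_offset)
            logits = model(next_token, position_ids=position_ids,
                           inference_params=params).logits[:, -1]
            next_token = logits.argmax(dim=-1, keepdim=True)
            tokens.append(next_token)
            params.sequence_len_offset += 1
    return torch.cat([input_ids, *tokens], dim=1)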

View File

@ -53,7 +53,7 @@ class FlashSelfAttention(nn.Module):
self.dropout_p = attention_dropout
self.triton = triton
def forward(self, qkv, cu_seqlens=None, max_seqlen=None):
def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
"""Implements the multihead softmax attention.
Arguments
---------
@ -61,6 +61,7 @@ class FlashSelfAttention(nn.Module):
If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D).
If cu_seqlens is not None and max_seqlen is not None, then qkv has shape
(total, 3, H, D), where total is the sum of the sequence lengths in the batch.
causal: if passed, will override self.causal
cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into qkv.
max_seqlen: int. Maximum sequence length in the batch.
@ -71,6 +72,7 @@ class FlashSelfAttention(nn.Module):
"""
assert qkv.dtype in [torch.float16, torch.bfloat16]
assert qkv.is_cuda
causal = self.causal if causal is None else causal
unpadded = cu_seqlens is not None
if unpadded:
assert cu_seqlens.dtype == torch.int32
@ -78,13 +80,13 @@ class FlashSelfAttention(nn.Module):
assert isinstance(max_seqlen, int)
return flash_attn_unpadded_qkvpacked_func(
qkv, cu_seqlens, max_seqlen, self.dropout_p if self.training else 0.0,
softmax_scale=self.softmax_scale, causal=self.causal
softmax_scale=self.softmax_scale, causal=causal
)
else:
batch_size, seqlen = qkv.shape[0], qkv.shape[1]
# Triton version doesn't support dropout
if self.triton and (self.dropout_p == 0 or not self.training):
output = flash_attn_qkvpacked_func(qkv, None, self.causal, self.softmax_scale)
output = flash_attn_qkvpacked_func(qkv, None, causal, self.softmax_scale)
else:
qkv = rearrange(qkv, 'b s ... -> (b s) ...')
max_seqlen = seqlen
@ -92,7 +94,7 @@ class FlashSelfAttention(nn.Module):
device=qkv.device)
output = flash_attn_unpadded_qkvpacked_func(
qkv, cu_seqlens, max_seqlen, self.dropout_p if self.training else 0.0,
softmax_scale=self.softmax_scale, causal=self.causal
softmax_scale=self.softmax_scale, causal=causal
)
output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
return output
@ -120,12 +122,14 @@ class FlashCrossAttention(nn.Module):
self.dropout_p = attention_dropout
self.triton = triton
def forward(self, q, kv, cu_seqlens=None, max_seqlen=None, cu_seqlens_k=None, max_seqlen_k=None):
def forward(self, q, kv, causal=None, cu_seqlens=None, max_seqlen=None,
cu_seqlens_k=None, max_seqlen_k=None):
"""Implements the multihead softmax attention.
Arguments
---------
q: The tensor containing the query. (B, Sq, H, D)
kv: The tensor containing the key and value. (B, Sk, 2, H, D)
causal: if passed, will override self.causal
cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into q.
max_seqlen: int. Maximum sequence length in the batch of q.
@ -135,6 +139,7 @@ class FlashCrossAttention(nn.Module):
"""
assert q.dtype in [torch.float16, torch.bfloat16]
assert q.is_cuda and kv.is_cuda
causal = self.causal if causal is None else causal
unpadded = cu_seqlens is not None
if unpadded:
assert cu_seqlens.dtype == torch.int32
@ -147,14 +152,14 @@ class FlashCrossAttention(nn.Module):
return flash_attn_unpadded_kvpacked_func(
q, kv, cu_seqlens, cu_seqlens_k, max_seqlen, max_seqlen_k,
self.dropout_p if self.training else 0.0,
softmax_scale=self.softmax_scale, causal=self.causal
softmax_scale=self.softmax_scale, causal=causal
)
else:
batch_size, seqlen_q = q.shape[0], q.shape[1]
seqlen_k = kv.shape[1]
assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3]
if self.triton and (self.dropout_p == 0.0 or not self.training): # Triton version doesn't support dropout
output = flash_attn_kvpacked_func(q, kv, None, self.causal, self.softmax_scale)
output = flash_attn_kvpacked_func(q, kv, None, causal, self.softmax_scale)
else:
q = rearrange(q, 'b s ... -> (b s) ...')
kv = rearrange(kv, 'b s ... -> (b s) ...')
@ -165,7 +170,7 @@ class FlashCrossAttention(nn.Module):
output = flash_attn_unpadded_kvpacked_func(
q, kv, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
self.dropout_p if self.training else 0.0,
softmax_scale=self.softmax_scale, causal=self.causal
softmax_scale=self.softmax_scale, causal=causal
)
output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
return output
@ -187,15 +192,17 @@ class SelfAttention(nn.Module):
self.softmax_scale = softmax_scale
self.dropout_p = attention_dropout
def forward(self, qkv, key_padding_mask=None):
def forward(self, qkv, causal=None, key_padding_mask=None):
"""Implements the multihead softmax attention.
Arguments
---------
qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
causal: if passed, will override self.causal
key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
False means to mask out. (B, S)
"""
batch_size, seqlen = qkv.shape[0], qkv.shape[1]
causal = self.causal if causal is None else causal
q, k, v = qkv.unbind(dim=2)
softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale)
@ -205,7 +212,7 @@ class SelfAttention(nn.Module):
padding_mask.masked_fill_(key_padding_mask, 0.0)
# TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
scores = scores + rearrange(padding_mask, 'b s -> b 1 1 s')
if self.causal:
if causal:
# "triu_tril_cuda_template" not implemented for 'BFloat16'
# So we have to construct the mask in float
causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
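The per-call causal override added here is what lets generation skip the mask while decoding: a single new query sits at the last position and may attend to every cached key, so causal masking is a no-op for that step (hence causal=False in the inference branch of MHA.forward further down). A minimal sketch, not part of the commit, checking that equivalence with the same score/mask convention as SelfAttention:

import math
import torch

torch.manual_seed(0)
B, T, H, D = 1, 5, 2, 8
q, k, v = (torch.randn(B, T, H, D) for _ in range(3))
scale = 1.0 / math.sqrt(D)

# Full causal pass over the whole sequence (what the prompt pass does).
scores = torch.einsum('bthd,bshd->bhts', q, k * scale)
mask = torch.triu(torch.full((T, T), -10000.0), 1)
out_causal = torch.einsum('bhts,bshd->bthd', torch.softmax(scores + mask, dim=-1), v)

# Decode step: the last query over all keys, no mask (causal=False).
scores_last = torch.einsum('bthd,bshd->bhts', q[:, -1:], k * scale)
out_last = torch.einsum('bhts,bshd->bthd', torch.softmax(scores_last, dim=-1), v)

# The last row of the causal mask is all zeros, so the two outputs match exactly.
assert torch.allclose(out_causal[:, -1:], out_last)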
@ -233,16 +240,18 @@ class CrossAttention(nn.Module):
self.softmax_scale = softmax_scale
self.dropout_p = attention_dropout
def forward(self, q, kv, key_padding_mask=None):
def forward(self, q, kv, causal=None, key_padding_mask=None):
"""Implements the multihead softmax attention.
Arguments
---------
q: The tensor containing the query. (B, Sq, H, D)
kv: The tensor containing the key and value. (B, Sk, 2, H, D)
causal: if passed, will override self.causal
key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
False means to mask out. (B, Sk)
"""
batch_size, seqlen_q = q.shape[0], q.shape[1]
causal = self.causal if causal is None else causal
seqlen_k = kv.shape[1]
assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3]
k, v = kv.unbind(dim=2)
@ -254,7 +263,7 @@ class CrossAttention(nn.Module):
padding_mask.masked_fill_(key_padding_mask, 0.0)
# TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
scores = scores + rearrange(padding_mask, 'b s -> b 1 1 s')
if self.causal:
if causal:
# "triu_tril_cuda_template" not implemented for 'BFloat16'
# So we have to construct the mask in float
causal_mask = torch.triu(torch.full((seqlen_q, seqlen_k), -10000.0,
@ -280,7 +289,7 @@ class MHA(nn.Module):
"""
def __init__(self, embed_dim, num_heads, cross_attn=False, bias=True, dropout=0.0,
softmax_scale=None, causal=False, dwconv=False, rotary_emb_dim=0,
softmax_scale=None, causal=False, layer_idx=None, dwconv=False, rotary_emb_dim=0,
rotary_emb_scale_base=0,
fused_bias_fc=False, use_flash_attn=False, return_residual=False,
checkpointing=False, device=None, dtype=None) -> None:
@ -294,6 +303,7 @@ class MHA(nn.Module):
self.embed_dim = embed_dim
self.cross_attn = cross_attn
self.causal = causal
self.layer_idx = layer_idx
self.dwconv = dwconv
self.rotary_emb_dim = rotary_emb_dim
self.use_flash_attn = use_flash_attn
@ -315,6 +325,8 @@ class MHA(nn.Module):
linear_cls = nn.Linear if not fused_bias_fc else FusedDense
linear_resid_cls = (LinearResidual if not fused_bias_fc
else partial(FusedDense, return_residual=True))
inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
if not self.cross_attn:
if not self.return_residual:
self.Wqkv = linear_cls(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
@ -323,7 +335,6 @@ class MHA(nn.Module):
if self.dwconv:
self.dwconv_qkv = nn.Conv1d(3 * embed_dim, 3 * embed_dim, kernel_size=3, padding=2,
groups=3 * embed_dim)
inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
else:
self.Wq = linear_cls(embed_dim, embed_dim, bias=bias, **factory_kwargs)
if not self.return_residual:
@ -335,14 +346,41 @@ class MHA(nn.Module):
groups=embed_dim)
self.dwconv_kv = nn.Conv1d(2 * embed_dim, 2 * embed_dim, kernel_size=3, padding=2,
groups=2 * embed_dim)
inner_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale,
attention_dropout=dropout)
self.inner_cross_attn = inner_cross_attn_cls(causal=causal, softmax_scale=softmax_scale,
attention_dropout=dropout)
# output projection always has the bias (for now)
self.out_proj = linear_cls(embed_dim, embed_dim, **factory_kwargs)
def _update_kv_cache(self, kv, inference_params):
"""kv: (batch_size, 1, nheads, head_dim)
"""
assert not self.dwconv, 'Generation does not support dwconv yet'
assert self.layer_idx is not None, 'Generation requires layer_idx in the constructor'
# Pre-allocate memory for key-values for inference.
if self.layer_idx not in inference_params.key_value_memory_dict:
inference_kv_cache = torch.empty(
inference_params.max_batch_size, inference_params.max_sequence_len, 2,
self.num_heads, self.head_dim, dtype=kv.dtype, device=kv.device
)
inference_params.key_value_memory_dict[self.layer_idx] = inference_kv_cache
else:
inference_kv_cache = inference_params.key_value_memory_dict[self.layer_idx]
# Adjust key and value for inference
batch_start = inference_params.batch_size_offset
batch_end = batch_start + kv.shape[0]
assert batch_end <= inference_kv_cache.shape[0]
sequence_start = inference_params.sequence_len_offset
sequence_end = sequence_start + kv.shape[1]
assert sequence_end <= inference_kv_cache.shape[1]
# Copy key and values.
inference_kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
kv = inference_kv_cache[batch_start:batch_end, :sequence_end, ...]
return kv
def forward(self, x, x_kv=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None,
**kwargs):
inference_params=None, **kwargs):
"""
Arguments:
x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
@ -355,6 +393,8 @@ class MHA(nn.Module):
max_seqlen: int. Maximum sequence length in the batch.
key_padding_mask: boolean mask, True means to keep, False means to mask out.
(batch, seqlen). Only applicable when not using FlashAttention.
inference_params: for generation. Adapted from Megatron-LM (and Apex)
https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
"""
if cu_seqlens is not None:
assert max_seqlen is not None
@ -366,6 +406,10 @@ class MHA(nn.Module):
assert cu_seqlens is None
assert max_seqlen is None
assert not self.use_flash_attn
if inference_params is not None:
assert key_padding_mask is None
assert cu_seqlens is None and max_seqlen is None
assert not self.dwconv
kwargs = ({'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen, **kwargs}
if self.use_flash_attn else {'key_padding_mask': key_padding_mask, **kwargs})
@ -378,12 +422,22 @@ class MHA(nn.Module):
qkv = rearrange(self.dwconv_qkv(rearrange(qkv, 'b s d -> b d s'))[..., :-2],
'b d s -> b s d').contiguous()
qkv = rearrange(qkv, '... (three h d) -> ... three h d', three=3, d=self.head_dim)
if self.rotary_emb_dim > 0:
qkv = self.rotary_emb(qkv)
if not self.checkpointing:
context = self.inner_attn(qkv, **kwargs)
if inference_params is None:
if self.rotary_emb_dim > 0:
qkv = self.rotary_emb(qkv)
if not self.checkpointing:
context = self.inner_attn(qkv, **kwargs)
else:
context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
else:
context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
if self.rotary_emb_dim > 0:
qkv = self.rotary_emb(qkv, seqlen_offset=inference_params.sequence_len_offset)
q = qkv[:, :, 0]
kv = self._update_kv_cache(qkv[:, :, 1:], inference_params)
# If we're processing the prompt, causal=None (use self.causal).
# If we're decoding, then causal=False.
causal = None if inference_params.sequence_len_offset == 0 else False
context = self.inner_cross_attn(q, kv, causal=causal)
else:
if not self.return_residual:
q = self.Wq(x)
@ -401,10 +455,14 @@ class MHA(nn.Module):
'b d s -> b s d').contiguous()
kv = rearrange(self.dwconv_kv(rearrange(kv, 'b s d -> b d s'))[..., :-2],
'b d s -> b s d').contiguous()
if not self.checkpointing:
context = self.inner_attn(q, kv, **kwargs)
if inference_params is None:
if not self.checkpointing:
context = self.inner_attn(q, kv, **kwargs)
else:
context = torch.utils.checkpoint.checkpoint(self.inner_attn, q, kv, **kwargs)
else:
context = torch.utils.checkpoint.checkpoint(self.inner_attn, q, kv, **kwargs)
kv = self._update_kv_cache(kv, inference_params)
context = self.inner_cross_attn(q, kv, causal=False)
out = self.out_proj(rearrange(context, '... h d -> ... (h d)'))
return out if not self.return_residual else (out, x)
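The core of the MHA change is the pre-allocate-then-slice KV cache in _update_kv_cache: reserve a (max_batch, max_seqlen, 2, nheads, head_dim) buffer per layer on first use, write the new keys/values at the current sequence offset, and return a view over everything cached so far. A standalone sketch under simplifying assumptions (plain dict keyed by layer index, batch offset fixed at 0; update_kv_cache is a hypothetical helper, not the method above):

import torch

def update_kv_cache(kv, cache, layer_idx, seqlen_offset, max_seqlen, max_batch):
    """kv: (batch, new_seqlen, 2, nheads, head_dim). Returns the cache contents up to and
    including the newly written positions."""
    if layer_idx not in cache:
        _, _, _, nheads, head_dim = kv.shape
        cache[layer_idx] = torch.empty(max_batch, max_seqlen, 2, nheads, head_dim,
                                       dtype=kv.dtype, device=kv.device)
    buf = cache[layer_idx]
    batch, new_seqlen = kv.shape[0], kv.shape[1]
    end = seqlen_offset + new_seqlen
    buf[:batch, seqlen_offset:end] = kv  # write this step's keys/values in place
    return buf[:batch, :end]             # view over everything cached so far

# The prompt writes seqlen tokens at offset 0; each decode step then writes one token.
cache = {}
kv = update_kv_cache(torch.randn(2, 7, 2, 12, 64), cache, layer_idx=0,
                     seqlen_offset=0, max_seqlen=32, max_batch=2)
kv = update_kv_cache(torch.randn(2, 1, 2, 12, 64), cache, layer_idx=0,
                     seqlen_offset=7, max_seqlen=32, max_batch=2)
assert kv.shape == (2, 8, 2, 12, 64)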

View File

@ -0,0 +1,64 @@
# Copyright (c) 2022, Tri Dao.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31
from dataclasses import dataclass, field
import torch
from einops import rearrange
from transformers.generation import GreedySearchDecoderOnlyOutput
@dataclass
class InferenceParams:
"""Inference parameters that are passed to the main model in order
to efficiently calculate and store the context during inference."""
max_sequence_len: int
max_batch_size: int
sequence_len_offset: int = 0
batch_size_offset: int = 0
key_value_memory_dict: dict = field(default_factory=dict)
def greedy_decode(input_ids, model, max_length):
"""Greedy decoding. This is a very simple implementation.
We assume that all sequences in the same batch have the same length.
Arguments:
input_ids: (batch, seq_len)
max_length: int
Returns: GreedySearchDecoderOnlyOutput, with the following fields:
sequences: (batch, max_length)
scores: tuple of (batch, vocab_size) logits, one per generated token
"""
batch_size, seqlen_og = input_ids.shape
inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size)
scores = []
with torch.inference_mode():
logits = model(input_ids, inference_params=inference_params).logits[:, -1]
scores.append(logits)
next_token = logits.argmax(dim=-1)
sequences = [next_token]
inference_params.sequence_len_offset = seqlen_og
while True:
position_ids = torch.full((batch_size, 1), inference_params.sequence_len_offset,
dtype=torch.long, device=input_ids.device)
logits = model(rearrange(next_token, 'b -> b 1'), position_ids=position_ids,
inference_params=inference_params).logits[:, -1]
scores.append(logits)
next_token = logits.argmax(dim=-1)
sequences.append(next_token)
inference_params.sequence_len_offset += 1
if inference_params.sequence_len_offset >= max_length - 1:
break
return GreedySearchDecoderOnlyOutput(
sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1),
scores=tuple(scores)
)
class GenerationMixin:
def generate(self, input_ids, max_length, return_dict_in_generate=False, output_scores=False):
output = greedy_decode(input_ids, self, max_length)
if not output_scores:
output.scores = None
return output if return_dict_in_generate else output.sequences
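A hedged usage sketch of the new GenerationMixin.generate, following the calling convention exercised by the new test at the end of this commit (the from_pretrained signature is taken from that test and assumed here rather than documented elsewhere):

import torch
from transformers import GPT2Config, GPT2Tokenizer
from flash_attn.models.gpt import GPTLMHeadModel

config = GPT2Config.from_pretrained('gpt2')
model = GPTLMHeadModel.from_pretrained('gpt2', config).cuda().to(dtype=torch.float16)
model.eval()
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
input_ids = tokenizer('Hello, my dog is cute and ', return_tensors='pt').input_ids.cuda()
out = model.generate(input_ids=input_ids, max_length=30,
                     return_dict_in_generate=True, output_scores=True)
print(tokenizer.decode(out.sequences[0]))
# out.scores is a tuple of (batch, vocab_size) logits, one per generated position.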

View File

@ -23,16 +23,6 @@ def test_gpt2_state_dict(model_name):
assert state_dict[k].shape == pretrained_state_dict[k].shape
def get_hf_models(model_name, config, dtype):
pretrained_state_dict = state_dict_from_pretrained(model_name)
model_hf = GPT2LMHeadModelHF(config)
# Missing key(s) in state_dict: "bert.embeddings.position_ids", "cls.predictions.decoder.bias"
# position_ids is a buffer, and predictions.decoder.bias is tied to predictions.bias.
model_hf.load_state_dict(pretrained_state_dict, strict=False)
model_hf.cuda().to(dtype=dtype)
return model_hf
@pytest.mark.parametrize('model_name', ["gpt2", "gpt2-medium"])
# @pytest.mark.parametrize('model_name', ["gpt2"])
def test_gpt2_non_optimized(model_name):

View File

@ -0,0 +1,83 @@
import re
import torch
import pytest
from einops import rearrange
from transformers import GPT2Config, GPT2Tokenizer
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.gpt import remap_state_dict_gpt2
from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.utils.generation import greedy_decode
# TODO: test with rotary embedding
@pytest.mark.parametrize('optimized', [False, True])
# @pytest.mark.parametrize('optimized', [False])
@pytest.mark.parametrize('model_name', ["gpt2"])
def test_greedy_decode(model_name, optimized):
"""Check that our implementation of GPT2 generation matches the HF implementation:
the fp16 scores should deviate from the fp32 HF scores by roughly as much as the fp16 HF
scores do.
"""
dtype = torch.float16
rtol, atol = 3e-3, 3e-1
config = GPT2Config.from_pretrained(model_name)
if optimized:
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_dense_gelu_dense = True
config.fused_dropout_add_ln = True
model = GPTLMHeadModel.from_pretrained(model_name, config)
model = model.cuda().to(dtype=dtype)
model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).cuda()
model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).cuda().to(dtype=dtype)
model.eval()
model_ref.eval()
model_hf.eval()
torch.manual_seed(0)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
input_ids = tokenizer("Hello, my dog is cute and ", return_tensors="pt").input_ids.cuda()
max_length = 30
# Slow generation for reference
sequences = []
scores = []
cur_input_ids = input_ids
with torch.inference_mode():
scores.append(model(cur_input_ids).logits[:, -1])
sequences.append(scores[-1].argmax(dim=-1))
for _ in range(input_ids.shape[1] + 1, max_length):
cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], 'b -> b 1')], dim=-1)
scores.append(model(cur_input_ids).logits[:, -1])
sequences.append(scores[-1].argmax(dim=-1))
sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)
scores = tuple(scores)
out = model.generate(input_ids=input_ids, max_length=max_length,
return_dict_in_generate=True, output_scores=True)
out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
return_dict_in_generate=True, output_scores=True)
out_ref = model_ref.generate(input_ids=input_ids, max_length=max_length,
return_dict_in_generate=True, output_scores=True)
print(f'Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
print(f'Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
print(f'HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
print(f'HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
assert torch.all(out.sequences == sequences)
assert torch.allclose(torch.stack(out.scores, dim=1), torch.stack(scores, dim=1),
rtol=rtol, atol=atol)
assert torch.all(out.sequences == out_ref.sequences)
assert torch.all(out.sequences == out_hf.sequences)
assert ((torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()
        < 3 * (torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item())
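The final assertion is the test's acceptance criterion: the fp16 flash-attn scores may drift from the fp32 HF reference, but by no more than 3x the drift of HF's own fp16 scores. A small hypothetical helper (not in the commit) restating it:

import torch

def max_score_diff(scores_a, scores_b):
    """scores_*: tuples of (batch, vocab_size) logits, one entry per generated position."""
    return (torch.stack(scores_a, dim=1) - torch.stack(scores_b, dim=1)).abs().max().item()

# Equivalent to the assertion above:
# assert max_score_diff(out.scores, out_ref.scores) < 3 * max_score_diff(out_hf.scores, out_ref.scores)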