Implement generation for GPT
parent 9d797d8848
commit 63670fd84a
flash_attn/models/gpt.py
@@ -20,6 +20,7 @@ from flash_attn.modules.block import Block
 from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
 from flash_attn.utils.distributed import sync_sequence_parallel_params
 from flash_attn.utils.pretrained import state_dict_from_pretrained
+from flash_attn.utils.generation import GenerationMixin
 
 try:
     from flash_attn.ops.fused_dense import ColumnParallelLinear
@@ -61,7 +62,7 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt
                      if process_group is None else {})
     parallel_kwargs = {'process_group': process_group} if process_group is not None else {}
     mixer_cls = partial(mha_cls, num_heads=config.num_attention_heads, dropout=config.attn_pdrop,
-                        softmax_scale=softmax_scale, causal=True,
+                        softmax_scale=softmax_scale, causal=True, layer_idx=layer_idx,
                         rotary_emb_dim=rotary_emb_dim, rotary_emb_scale_base=rotary_emb_scale_base,
                         use_flash_attn=use_flash_attn,
                         **serial_kwargs, **parallel_kwargs, **factory_kwargs)
@@ -220,7 +221,7 @@ class GPTModel(GPTPreTrainedModel):
         if self.process_group is not None:
             sync_sequence_parallel_params(self, self.process_group)
 
-    def forward(self, input_ids, position_ids=None):
+    def forward(self, input_ids, position_ids=None, inference_params=None):
         # If using Tensor Parallel with sequence parallel, we combine the batch and the seqlen
         # dimensions so that we can split on it easily, in case of small batch size.
         # Only the attention layers need to know the seqlen.
@@ -238,12 +239,14 @@ class GPTModel(GPTPreTrainedModel):
             residual_in_fp32=True
         )
         mixer_kwargs = ({'seqlen': input_ids.shape[1]} if self.process_group is not None else {})
+        if inference_params is not None:
+            mixer_kwargs['inference_params'] = inference_params
         for layer in self.layers:
             hidden_states, residual = layer(hidden_states, residual, mixer_kwargs=mixer_kwargs)
         return hidden_states
 
 
-class GPTLMHeadModel(GPTPreTrainedModel):
+class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
 
     def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None):
         factory_kwargs = {'device': device, 'dtype': dtype}
@@ -267,8 +270,13 @@ class GPTLMHeadModel(GPTPreTrainedModel):
     def tie_weights(self):
         self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight
 
-    def forward(self, input_ids, position_ids=None):
-        hidden_states = self.transformer(input_ids, position_ids=position_ids)
+    def forward(self, input_ids, position_ids=None, inference_params=None):
+        """
+        inference_params: for generation. Adapted from Megatron-LM (and Apex)
+        https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
+        """
+        hidden_states = self.transformer(input_ids, position_ids=position_ids,
+                                         inference_params=inference_params)
         lm_logits = self.lm_head(hidden_states)
         CausalLMOutput = namedtuple('CausalLMOutput', ['logits'])
         return CausalLMOutput(logits=lm_logits)
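Before the attention-layer changes, it may help to see how inference_params is meant to flow through the model. The following sketch is not part of the commit; it assumes flash_attn is installed with the modules above and that a CUDA GPU is available, and it mirrors the greedy_decode loop added in flash_attn/utils/generation.py further down.

import torch
from transformers import GPT2Config
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.utils.generation import InferenceParams

config = GPT2Config()
model = GPTLMHeadModel(config, device='cuda', dtype=torch.float16)
model.eval()

input_ids = torch.randint(0, config.vocab_size, (1, 8), device='cuda')
inference_params = InferenceParams(max_sequence_len=32, max_batch_size=1)

with torch.inference_mode():
    # Prompt pass: every attention layer fills its slot in key_value_memory_dict.
    logits = model(input_ids, inference_params=inference_params).logits[:, -1]
    inference_params.sequence_len_offset = input_ids.shape[1]
    # One decode step: a single new token plus its position id, reusing the cached keys/values.
    next_token = logits.argmax(dim=-1, keepdim=True)
    position_ids = torch.full((1, 1), inference_params.sequence_len_offset,
                              dtype=torch.long, device='cuda')
    logits = model(next_token, position_ids=position_ids,
                   inference_params=inference_params).logits[:, -1]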
flash_attn/modules/mha.py
@@ -53,7 +53,7 @@ class FlashSelfAttention(nn.Module):
         self.dropout_p = attention_dropout
         self.triton = triton
 
-    def forward(self, qkv, cu_seqlens=None, max_seqlen=None):
+    def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
         """Implements the multihead softmax attention.
         Arguments
         ---------
@@ -61,6 +61,7 @@ class FlashSelfAttention(nn.Module):
                 If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D).
                 If cu_seqlens is not None and max_seqlen is not None, then qkv has shape
                 (total, 3, H, D), where total is the sum of the sequence lengths in the batch.
+            causal: if passed, will override self.causal
             cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                 of the sequences in the batch, used to index into qkv.
             max_seqlen: int. Maximum sequence length in the batch.
@@ -71,6 +72,7 @@ class FlashSelfAttention(nn.Module):
         """
         assert qkv.dtype in [torch.float16, torch.bfloat16]
         assert qkv.is_cuda
+        causal = self.causal if causal is None else causal
         unpadded = cu_seqlens is not None
         if unpadded:
             assert cu_seqlens.dtype == torch.int32
@@ -78,13 +80,13 @@ class FlashSelfAttention(nn.Module):
             assert isinstance(max_seqlen, int)
             return flash_attn_unpadded_qkvpacked_func(
                 qkv, cu_seqlens, max_seqlen, self.dropout_p if self.training else 0.0,
-                softmax_scale=self.softmax_scale, causal=self.causal
+                softmax_scale=self.softmax_scale, causal=causal
             )
         else:
             batch_size, seqlen = qkv.shape[0], qkv.shape[1]
             # Triton version doesn't support dropout
             if self.triton and (self.dropout_p == 0 or not self.training):
-                output = flash_attn_qkvpacked_func(qkv, None, self.causal, self.softmax_scale)
+                output = flash_attn_qkvpacked_func(qkv, None, causal, self.softmax_scale)
             else:
                 qkv = rearrange(qkv, 'b s ... -> (b s) ...')
                 max_seqlen = seqlen
@@ -92,7 +94,7 @@ class FlashSelfAttention(nn.Module):
                                           device=qkv.device)
                 output = flash_attn_unpadded_qkvpacked_func(
                     qkv, cu_seqlens, max_seqlen, self.dropout_p if self.training else 0.0,
-                    softmax_scale=self.softmax_scale, causal=self.causal
+                    softmax_scale=self.softmax_scale, causal=causal
                 )
                 output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
             return output
@@ -120,12 +122,14 @@ class FlashCrossAttention(nn.Module):
         self.dropout_p = attention_dropout
         self.triton = triton
 
-    def forward(self, q, kv, cu_seqlens=None, max_seqlen=None, cu_seqlens_k=None, max_seqlen_k=None):
+    def forward(self, q, kv, causal=None, cu_seqlens=None, max_seqlen=None,
+                cu_seqlens_k=None, max_seqlen_k=None):
         """Implements the multihead softmax attention.
         Arguments
         ---------
             q: The tensor containing the query. (B, Sq, H, D)
             kv: The tensor containing the key and value. (B, Sk, 2, H, D)
+            causal: if passed, will override self.causal
             cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                 of the sequences in the batch, used to index into q.
             max_seqlen: int. Maximum sequence length in the batch of q.
@@ -135,6 +139,7 @@ class FlashCrossAttention(nn.Module):
         """
         assert q.dtype in [torch.float16, torch.bfloat16]
         assert q.is_cuda and kv.is_cuda
+        causal = self.causal if causal is None else causal
         unpadded = cu_seqlens is not None
         if unpadded:
             assert cu_seqlens.dtype == torch.int32
@@ -147,14 +152,14 @@ class FlashCrossAttention(nn.Module):
             return flash_attn_unpadded_kvpacked_func(
                 q, kv, cu_seqlens, cu_seqlens_k, max_seqlen, max_seqlen_k,
                 self.dropout_p if self.training else 0.0,
-                softmax_scale=self.softmax_scale, causal=self.causal
+                softmax_scale=self.softmax_scale, causal=causal
             )
         else:
             batch_size, seqlen_q = q.shape[0], q.shape[1]
             seqlen_k = kv.shape[1]
             assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3]
             if self.triton and (self.dropout_p == 0.0 or not self.training):  # Triton version doesn't support dropout
-                output = flash_attn_kvpacked_func(q, kv, None, self.causal, self.softmax_scale)
+                output = flash_attn_kvpacked_func(q, kv, None, causal, self.softmax_scale)
             else:
                 q = rearrange(q, 'b s ... -> (b s) ...')
                 kv = rearrange(kv, 'b s ... -> (b s) ...')
@@ -165,7 +170,7 @@ class FlashCrossAttention(nn.Module):
                 output = flash_attn_unpadded_kvpacked_func(
                     q, kv, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
                     self.dropout_p if self.training else 0.0,
-                    softmax_scale=self.softmax_scale, causal=self.causal
+                    softmax_scale=self.softmax_scale, causal=causal
                 )
                 output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
             return output
@@ -187,15 +192,17 @@ class SelfAttention(nn.Module):
         self.softmax_scale = softmax_scale
         self.dropout_p = attention_dropout
 
-    def forward(self, qkv, key_padding_mask=None):
+    def forward(self, qkv, causal=None, key_padding_mask=None):
         """Implements the multihead softmax attention.
         Arguments
         ---------
             qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
+            causal: if passed, will override self.causal
             key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                 False means to mask out. (B, S)
         """
         batch_size, seqlen = qkv.shape[0], qkv.shape[1]
+        causal = self.causal if causal is None else causal
         q, k, v = qkv.unbind(dim=2)
         softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
         scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale)
@@ -205,7 +212,7 @@ class SelfAttention(nn.Module):
             padding_mask.masked_fill_(key_padding_mask, 0.0)
             # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
             scores = scores + rearrange(padding_mask, 'b s -> b 1 1 s')
-        if self.causal:
+        if causal:
             # "triu_tril_cuda_template" not implemented for 'BFloat16'
             # So we have to construct the mask in float
             causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
@@ -233,16 +240,18 @@ class CrossAttention(nn.Module):
         self.softmax_scale = softmax_scale
         self.dropout_p = attention_dropout
 
-    def forward(self, q, kv, key_padding_mask=None):
+    def forward(self, q, kv, causal=None, key_padding_mask=None):
         """Implements the multihead softmax attention.
         Arguments
         ---------
             q: The tensor containing the query. (B, Sq, H, D)
             kv: The tensor containing the key and value. (B, Sk, 2, H, D)
+            causal: if passed, will override self.causal
             key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                 False means to mask out. (B, Sk)
         """
         batch_size, seqlen_q = q.shape[0], q.shape[1]
+        causal = self.causal if causal is None else causal
         seqlen_k = kv.shape[1]
         assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3]
         k, v = kv.unbind(dim=2)
@@ -254,7 +263,7 @@ class CrossAttention(nn.Module):
             padding_mask.masked_fill_(key_padding_mask, 0.0)
             # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
             scores = scores + rearrange(padding_mask, 'b s -> b 1 1 s')
-        if self.causal:
+        if causal:
             # "triu_tril_cuda_template" not implemented for 'BFloat16'
             # So we have to construct the mask in float
             causal_mask = torch.triu(torch.full((seqlen_q, seqlen_k), -10000.0,
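The causal override added to all four attention classes exists for the decode step: a single new query attends to every cached key, so the causal mask is a no-op there (and, with unequal query/key lengths, would mask the wrong positions). A small plain-PyTorch check of that equivalence, independent of this commit:

import math
import torch

torch.manual_seed(0)
seqlen, nheads, d = 5, 2, 8
q = torch.randn(1, seqlen, nheads, d)
k = torch.randn(1, seqlen, nheads, d)
v = torch.randn(1, seqlen, nheads, d)

def attn(q, k, v, causal):
    scores = torch.einsum('bthd,bshd->bhts', q, k) / math.sqrt(d)
    if causal:
        mask = torch.triu(torch.full(scores.shape[-2:], float('-inf')), 1)
        scores = scores + mask
    return torch.einsum('bhts,bshd->bthd', scores.softmax(dim=-1), v)

# The last row of a causal pass over the whole sequence ...
out_prompt = attn(q, k, v, causal=True)[:, -1:]
# ... equals a non-causal pass of just the last query against the full key/value set.
out_decode = attn(q[:, -1:], k, v, causal=False)
assert torch.allclose(out_prompt, out_decode, atol=1e-6)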
@@ -280,7 +289,7 @@ class MHA(nn.Module):
     """
 
     def __init__(self, embed_dim, num_heads, cross_attn=False, bias=True, dropout=0.0,
-                 softmax_scale=None, causal=False, dwconv=False, rotary_emb_dim=0,
+                 softmax_scale=None, causal=False, layer_idx=None, dwconv=False, rotary_emb_dim=0,
                  rotary_emb_scale_base=0,
                  fused_bias_fc=False, use_flash_attn=False, return_residual=False,
                  checkpointing=False, device=None, dtype=None) -> None:
@@ -294,6 +303,7 @@ class MHA(nn.Module):
         self.embed_dim = embed_dim
         self.cross_attn = cross_attn
         self.causal = causal
+        self.layer_idx = layer_idx
         self.dwconv = dwconv
         self.rotary_emb_dim = rotary_emb_dim
         self.use_flash_attn = use_flash_attn
@@ -315,6 +325,8 @@ class MHA(nn.Module):
         linear_cls = nn.Linear if not fused_bias_fc else FusedDense
         linear_resid_cls = (LinearResidual if not fused_bias_fc
                             else partial(FusedDense, return_residual=True))
+        inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
+        inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
         if not self.cross_attn:
             if not self.return_residual:
                 self.Wqkv = linear_cls(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
@@ -323,7 +335,6 @@ class MHA(nn.Module):
             if self.dwconv:
                 self.dwconv_qkv = nn.Conv1d(3 * embed_dim, 3 * embed_dim, kernel_size=3, padding=2,
                                             groups=3 * embed_dim)
-            inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
         else:
             self.Wq = linear_cls(embed_dim, embed_dim, bias=bias, **factory_kwargs)
             if not self.return_residual:
@@ -335,14 +346,41 @@ class MHA(nn.Module):
                                           groups=embed_dim)
                 self.dwconv_kv = nn.Conv1d(2 * embed_dim, 2 * embed_dim, kernel_size=3, padding=2,
                                            groups=2 * embed_dim)
-            inner_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
         self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale,
                                          attention_dropout=dropout)
+        self.inner_cross_attn = inner_cross_attn_cls(causal=causal, softmax_scale=softmax_scale,
+                                                     attention_dropout=dropout)
         # output projection always have the bias (for now)
         self.out_proj = linear_cls(embed_dim, embed_dim, **factory_kwargs)
+
+    def _update_kv_cache(self, kv, inference_params):
+        """kv: (batch_size, 1, nheads, head_dim)
+        """
+        assert not self.dwconv, 'Generation does not support dwconv yet'
+        assert self.layer_idx is not None, 'Generation requires layer_idx in the constructor'
+        # Pre-allocate memory for key-values for inference.
+        if self.layer_idx not in inference_params.key_value_memory_dict:
+            inference_kv_cache = torch.empty(
+                inference_params.max_batch_size, inference_params.max_sequence_len, 2,
+                self.num_heads, self.head_dim, dtype=kv.dtype, device=kv.device
+            )
+            inference_params.key_value_memory_dict[self.layer_idx] = inference_kv_cache
+        else:
+            inference_kv_cache = inference_params.key_value_memory_dict[self.layer_idx]
+        # Adjust key and value for inference
+        batch_start = inference_params.batch_size_offset
+        batch_end = batch_start + kv.shape[0]
+        assert batch_end <= inference_kv_cache.shape[0]
+        sequence_start = inference_params.sequence_len_offset
+        sequence_end = sequence_start + kv.shape[1]
+        assert sequence_end <= inference_kv_cache.shape[1]
+        # Copy key and values.
+        inference_kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
+        kv = inference_kv_cache[batch_start:batch_end, :sequence_end, ...]
+        return kv
 
     def forward(self, x, x_kv=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None,
-                **kwargs):
+                inference_params=None, **kwargs):
         """
         Arguments:
             x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
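A standalone illustration (plain PyTorch, not part of the commit) of the cache layout that _update_kv_cache manages: one pre-allocated (max_batch_size, max_sequence_len, 2, nheads, head_dim) buffer per layer, written at the current batch and sequence offsets and read back up to the last filled position:

import torch

max_batch_size, max_sequence_len, nheads, head_dim = 4, 16, 2, 8
cache = torch.empty(max_batch_size, max_sequence_len, 2, nheads, head_dim)

def update_cache(kv, batch_offset, seqlen_offset):
    batch_end = batch_offset + kv.shape[0]
    seq_end = seqlen_offset + kv.shape[1]
    cache[batch_offset:batch_end, seqlen_offset:seq_end] = kv
    # Return keys/values for every position these sequences have seen so far.
    return cache[batch_offset:batch_end, :seq_end]

prompt_kv = torch.randn(2, 6, 2, nheads, head_dim)             # prompt of length 6, batch of 2
kv = update_cache(prompt_kv, batch_offset=0, seqlen_offset=0)  # (2, 6, 2, nheads, head_dim)
step_kv = torch.randn(2, 1, 2, nheads, head_dim)               # one newly decoded token
kv = update_cache(step_kv, batch_offset=0, seqlen_offset=6)    # (2, 7, 2, nheads, head_dim)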
@@ -355,6 +393,8 @@ class MHA(nn.Module):
             max_seqlen: int. Maximum sequence length in the batch.
             key_padding_mask: boolean mask, True means to keep, False means to mask out.
                 (batch, seqlen). Only applicable when not using FlashAttention.
+            inference_params: for generation. Adapted from Megatron-LM (and Apex)
+            https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
         """
         if cu_seqlens is not None:
             assert max_seqlen is not None
@@ -366,6 +406,10 @@ class MHA(nn.Module):
             assert cu_seqlens is None
             assert max_seqlen is None
             assert not self.use_flash_attn
+        if inference_params is not None:
+            assert key_padding_mask is None
+            assert cu_seqlens is None and max_seqlen is None
+            assert not self.dwconv
 
         kwargs = ({'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen, **kwargs}
                   if self.use_flash_attn else {'key_padding_mask': key_padding_mask, **kwargs})
@@ -378,12 +422,22 @@ class MHA(nn.Module):
                 qkv = rearrange(self.dwconv_qkv(rearrange(qkv, 'b s d -> b d s'))[..., :-2],
                                 'b d s -> b s d').contiguous()
             qkv = rearrange(qkv, '... (three h d) -> ... three h d', three=3, d=self.head_dim)
-            if self.rotary_emb_dim > 0:
-                qkv = self.rotary_emb(qkv)
-            if not self.checkpointing:
-                context = self.inner_attn(qkv, **kwargs)
+            if inference_params is None:
+                if self.rotary_emb_dim > 0:
+                    qkv = self.rotary_emb(qkv)
+                if not self.checkpointing:
+                    context = self.inner_attn(qkv, **kwargs)
+                else:
+                    context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
             else:
-                context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
+                if self.rotary_emb_dim > 0:
+                    qkv = self.rotary_emb(qkv, seqlen_offset=inference_params.sequence_len_offset)
+                q = qkv[:, :, 0]
+                kv = self._update_kv_cache(qkv[:, :, 1:], inference_params)
+                # If we're processing the prompt, causal=None (use self.causal).
+                # If we're decoding, then causal=False.
+                causal = None if inference_params.sequence_len_offset == 0 else False
+                context = self.inner_cross_attn(q, kv, causal=causal)
         else:
             if not self.return_residual:
                 q = self.Wq(x)
@@ -401,10 +455,14 @@ class MHA(nn.Module):
                                 'b d s -> b s d').contiguous()
                 kv = rearrange(self.dwconv_kv(rearrange(kv, 'b s d -> b d s'))[..., :-2],
                                'b d s -> b s d').contiguous()
-            if not self.checkpointing:
-                context = self.inner_attn(q, kv, **kwargs)
+            if inference_params is None:
+                if not self.checkpointing:
+                    context = self.inner_attn(q, kv, **kwargs)
+                else:
+                    context = torch.utils.checkpoint.checkpoint(self.inner_attn, q, kv, **kwargs)
             else:
-                context = torch.utils.checkpoint.checkpoint(self.inner_attn, q, kv, **kwargs)
+                kv = self._update_kv_cache(kv, inference_params)
+                context = self.inner_cross_attn(q, kv, causal=False)
         out = self.out_proj(rearrange(context, '... h d -> ... (h d)'))
         return out if not self.return_residual else (out, x)
 
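To see the two inference branches above end to end, here is a minimal sketch (assuming flash_attn is importable; it uses the non-flash PyTorch attention classes, so it runs on CPU in fp32) of a single MHA layer serving a prompt pass and then a one-token decode step:

import torch
from flash_attn.modules.mha import MHA
from flash_attn.utils.generation import InferenceParams

mha = MHA(embed_dim=64, num_heads=4, causal=True, layer_idx=0)
mha.eval()

params = InferenceParams(max_sequence_len=16, max_batch_size=2)
x_prompt = torch.randn(2, 6, 64)                         # (batch, seqlen, hidden_dim)

with torch.inference_mode():
    out_prompt = mha(x_prompt, inference_params=params)  # fills key_value_memory_dict[0]
    params.sequence_len_offset = x_prompt.shape[1]
    x_step = torch.randn(2, 1, 64)                       # one new token per sequence
    out_step = mha(x_step, inference_params=params)      # attends to all 7 cached positions

print(out_prompt.shape, out_step.shape)                  # (2, 6, 64) and (2, 1, 64)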
flash_attn/utils/generation.py (new file, 64 lines)
@@ -0,0 +1,64 @@
+# Copyright (c) 2022, Tri Dao.
+# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31
+from dataclasses import dataclass, field
+import torch
+
+from einops import rearrange
+
+from transformers.generation import GreedySearchDecoderOnlyOutput
+
+
+@dataclass
+class InferenceParams:
+    """Inference parameters that are passed to the main model in order
+    to efficiently calculate and store the context during inference."""
+    max_sequence_len: int
+    max_batch_size: int
+    sequence_len_offset: int = 0
+    batch_size_offset: int = 0
+    key_value_memory_dict: dict = field(default_factory=dict)
+
+
+def greedy_decode(input_ids, model, max_length):
+    """Greedy decoding. This is a very simple implementation.
+    We assume that all sequences in the same batch have the same length.
+    Arguments:
+        input_ids: (batch, seq_len)
+        max_length: int
+    Returns: GreedySearchDecoderOnlyOutput, with the following fields:
+        sequences: (batch, max_length)
+        scores: tuple of (batch, vocab_size) tensors
+    """
+    batch_size, seqlen_og = input_ids.shape
+    inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size)
+    scores = []
+    with torch.inference_mode():
+        logits = model(input_ids, inference_params=inference_params).logits[:, -1]
+        scores.append(logits)
+        next_token = logits.argmax(dim=-1)
+        sequences = [next_token]
+        inference_params.sequence_len_offset = seqlen_og
+        while True:
+            position_ids = torch.full((batch_size, 1), inference_params.sequence_len_offset,
+                                      dtype=torch.long, device=input_ids.device)
+            logits = model(rearrange(next_token, 'b -> b 1'), position_ids=position_ids,
+                           inference_params=inference_params).logits[:, -1]
+            scores.append(logits)
+            next_token = logits.argmax(dim=-1)
+            sequences.append(next_token)
+            inference_params.sequence_len_offset += 1
+            if inference_params.sequence_len_offset >= max_length - 1:
+                break
+    return GreedySearchDecoderOnlyOutput(
+        sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1),
+        scores=tuple(scores)
+    )
+
+
+class GenerationMixin:
+
+    def generate(self, input_ids, max_length, return_dict_in_generate=False, output_scores=False):
+        output = greedy_decode(input_ids, self, max_length)
+        if not output_scores:
+            output.scores = None
+        return output if return_dict_in_generate else output.sequences
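For reference, a minimal sketch (assuming flash_attn is installed, a CUDA GPU with fp16 support is available, and the gpt2 weights can be downloaded) of driving the new mixin the same way the test below does:

import torch
from transformers import GPT2Config, GPT2Tokenizer
from flash_attn.models.gpt import GPTLMHeadModel

config = GPT2Config.from_pretrained('gpt2')
model = GPTLMHeadModel.from_pretrained('gpt2', config).cuda().to(dtype=torch.float16)
model.eval()

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
input_ids = tokenizer('Hello, my dog is cute and ', return_tensors='pt').input_ids.cuda()
out = model.generate(input_ids=input_ids, max_length=30,
                     return_dict_in_generate=True, output_scores=True)
print(tokenizer.decode(out.sequences[0]))
print(len(out.scores), out.scores[0].shape)   # one (batch, vocab_size) tensor per generated step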
tests/models/test_gpt.py
@@ -23,16 +23,6 @@ def test_gpt2_state_dict(model_name):
         assert state_dict[k].shape == pretrained_state_dict[k].shape
 
 
-def get_hf_models(model_name, config, dtype):
-    pretrained_state_dict = state_dict_from_pretrained(model_name)
-    model_hf = GPT2LMHeadModelHF(config)
-    # Missing key(s) in state_dict: "bert.embeddings.position_ids", "cls.predictions.decoder.bias"
-    # position_ids is a buffer, and predictions.decoder.bias is tied to predictions.bias.
-    model_hf.load_state_dict(pretrained_state_dict, strict=False)
-    model_hf.cuda().to(dtype=dtype)
-    return model_hf
-
-
 @pytest.mark.parametrize('model_name', ["gpt2", "gpt2-medium"])
 # @pytest.mark.parametrize('model_name', ["gpt2"])
 def test_gpt2_non_optimized(model_name):
tests/models/test_gpt_generation.py (new file, 83 lines)
@@ -0,0 +1,83 @@
+import re
+
+import torch
+import pytest
+
+from einops import rearrange
+
+from transformers import GPT2Config, GPT2Tokenizer
+from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF
+
+from flash_attn.models.gpt import GPTLMHeadModel
+from flash_attn.models.gpt import remap_state_dict_gpt2
+from flash_attn.utils.pretrained import state_dict_from_pretrained
+from flash_attn.utils.generation import greedy_decode
+
+
+# TODO: test with rotary embedding
+@pytest.mark.parametrize('optimized', [False, True])
+# @pytest.mark.parametrize('optimized', [False])
+@pytest.mark.parametrize('model_name', ["gpt2"])
+def test_greedy_decode(model_name, optimized):
+    """Check that our implementation of GPT2 generation matches the HF implementation:
+    the scores in fp16 should be around the same as the HF scores in fp16, when compared to
+    the HF scores in fp32.
+    """
+    dtype = torch.float16
+    rtol, atol = 3e-3, 3e-1
+    config = GPT2Config.from_pretrained(model_name)
+    if optimized:
+        config.use_flash_attn = True
+        config.fused_bias_fc = True
+        config.fused_dense_gelu_dense = True
+        config.fused_dropout_add_ln = True
+
+    model = GPTLMHeadModel.from_pretrained(model_name, config)
+    model = model.cuda().to(dtype=dtype)
+
+    model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).cuda()
+    model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).cuda().to(dtype=dtype)
+
+    model.eval()
+    model_ref.eval()
+    model_hf.eval()
+
+    torch.manual_seed(0)
+    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+    input_ids = tokenizer("Hello, my dog is cute and ", return_tensors="pt").input_ids.cuda()
+    max_length = 30
+
+    # Slow generation for reference
+    sequences = []
+    scores = []
+    cur_input_ids = input_ids
+    with torch.inference_mode():
+        scores.append(model(cur_input_ids).logits[:, -1])
+        sequences.append(scores[-1].argmax(dim=-1))
+        for _ in range(input_ids.shape[1] + 1, max_length):
+            cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], 'b -> b 1')], dim=-1)
+            scores.append(model(cur_input_ids).logits[:, -1])
+            sequences.append(scores[-1].argmax(dim=-1))
+    sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)
+    scores = tuple(scores)
+
+    out = model.generate(input_ids=input_ids, max_length=max_length,
+                         return_dict_in_generate=True, output_scores=True)
+
+    out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
+                               return_dict_in_generate=True, output_scores=True)
+    out_ref = model_ref.generate(input_ids=input_ids, max_length=max_length,
+                                 return_dict_in_generate=True, output_scores=True)
+
+    print(f'Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
+    print(f'Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
+    print(f'HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
+    print(f'HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
+
+    assert torch.all(out.sequences == sequences)
+    assert torch.allclose(torch.stack(out.scores, dim=1), torch.stack(scores, dim=1),
+                          rtol=rtol, atol=atol)
+    assert torch.all(out.sequences == out_ref.sequences)
+    assert torch.all(out.sequences == out_hf.sequences)
+
+    assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()