[Gen] Add option to run generation with FT attention kernel
This commit is contained in:
parent
be1afaa276
commit
a668890fcd
@ -30,6 +30,11 @@ try:
|
||||
except ImportError:
|
||||
RotaryEmbedding = None
|
||||
|
||||
try:
|
||||
import ft_attention
|
||||
except ImportError:
|
||||
ft_attention = None
|
||||
|
||||
|
||||
class FlashSelfAttention(nn.Module):
|
||||
"""Implement the scaled dot product attention with softmax.
|
||||
@ -360,23 +365,32 @@ class MHA(nn.Module):
|
||||
assert self.layer_idx is not None, 'Generation requires layer_idx in the constructor'
|
||||
# Pre-allocate memory for key-values for inference.
|
||||
if self.layer_idx not in inference_params.key_value_memory_dict:
|
||||
inference_kv_cache = torch.empty(
|
||||
kv_cache = torch.empty(
|
||||
inference_params.max_batch_size, inference_params.max_sequence_len, 2,
|
||||
self.num_heads, self.head_dim, dtype=kv.dtype, device=kv.device
|
||||
)
|
||||
inference_params.key_value_memory_dict[self.layer_idx] = inference_kv_cache
|
||||
inference_params.key_value_memory_dict[self.layer_idx] = kv_cache
|
||||
else:
|
||||
inference_kv_cache = inference_params.key_value_memory_dict[self.layer_idx]
|
||||
assert not inference_params.fused_ft_kernel, 'fused_ft_kernel should not take this path'
|
||||
kv_cache = inference_params.key_value_memory_dict[self.layer_idx]
|
||||
# Adjust key and value for inference
|
||||
batch_start = inference_params.batch_size_offset
|
||||
batch_end = batch_start + kv.shape[0]
|
||||
assert batch_end <= inference_kv_cache.shape[0]
|
||||
assert batch_end <= kv_cache.shape[0]
|
||||
sequence_start = inference_params.sequence_len_offset
|
||||
sequence_end = sequence_start + kv.shape[1]
|
||||
assert sequence_end <= inference_kv_cache.shape[1]
|
||||
assert sequence_end <= kv_cache.shape[1]
|
||||
# Copy key and values.
|
||||
inference_kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
|
||||
kv = inference_kv_cache[batch_start:batch_end, :sequence_end, ...]
|
||||
kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
|
||||
kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
|
||||
if inference_params.fused_ft_kernel:
|
||||
# FT kernel requires different layouts for the k_cache and v_cache.
|
||||
assert kv_cache.dtype in [torch.float16, torch.bfloat16, torch.float32]
|
||||
packsize = 4 if kv_cache.dtype == torch.float32 else 8
|
||||
k_cache = rearrange(kv_cache[:, :, 0], 'b s h (d packsize) -> b h d s packsize',
|
||||
packsize=packsize).contiguous()
|
||||
v_cache = rearrange(kv_cache[:, :, 1], 'b s h d -> b h s d').contiguous()
|
||||
inference_params.key_value_memory_dict[self.layer_idx] = (k_cache, v_cache)
|
||||
return kv
|
||||
|
||||
def forward(self, x, x_kv=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None,
|
||||
@ -430,14 +444,24 @@ class MHA(nn.Module):
|
||||
else:
|
||||
context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
|
||||
else:
|
||||
if self.rotary_emb_dim > 0:
|
||||
qkv = self.rotary_emb(qkv, seqlen_offset=inference_params.sequence_len_offset)
|
||||
q = qkv[:, :, 0]
|
||||
kv = self._update_kv_cache(qkv[:, :, 1:], inference_params)
|
||||
# If we're processing the prompt, causal=None (use self.causal).
|
||||
# If we're decoding, then causal=False.
|
||||
causal = None if inference_params.sequence_len_offset == 0 else False
|
||||
context = self.inner_cross_attn(q, kv, causal=causal)
|
||||
if (not inference_params.fused_ft_kernel) or inference_params.sequence_len_offset == 0:
|
||||
if self.rotary_emb_dim > 0:
|
||||
qkv = self.rotary_emb(qkv, seqlen_offset=inference_params.sequence_len_offset)
|
||||
q = qkv[:, :, 0]
|
||||
kv = self._update_kv_cache(qkv[:, :, 1:], inference_params)
|
||||
# If we're processing the prompt, causal=None (use self.causal).
|
||||
# If we're decoding, then causal=False.
|
||||
causal = None if inference_params.sequence_len_offset == 0 else False
|
||||
context = self.inner_cross_attn(q, kv, causal=causal)
|
||||
else:
|
||||
assert ft_attention is not None
|
||||
context = ft_attention.single_query_attention(
|
||||
*rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1),
|
||||
*inference_params.key_value_memory_dict[self.layer_idx],
|
||||
inference_params.lengths_per_sample, inference_params.sequence_len_offset,
|
||||
self.rotary_emb_dim
|
||||
)
|
||||
context = rearrange(context, 'b h d -> b 1 h d')
|
||||
else:
|
||||
if not self.return_residual:
|
||||
q = self.Wq(x)
|
||||
|
||||
@ -1,7 +1,10 @@
|
||||
# Copyright (c) 2022, Tri Dao.
|
||||
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31
|
||||
from typing import Optional
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
@ -17,9 +20,11 @@ class InferenceParams:
|
||||
sequence_len_offset: int = 0
|
||||
batch_size_offset: int = 0
|
||||
key_value_memory_dict: dict = field(default_factory=dict)
|
||||
fused_ft_kernel: bool = False
|
||||
lengths_per_sample: Optional[Tensor] = None
|
||||
|
||||
|
||||
def greedy_decode(input_ids, model, max_length):
|
||||
def greedy_decode(input_ids, model, max_length, fused_ft_kernel=True):
|
||||
"""Greedy decoding. This is a very simple implementation.
|
||||
We assume that all sequences in the same batch have the same length.
|
||||
Arguments:
|
||||
@ -30,7 +35,8 @@ def greedy_decode(input_ids, model, max_length):
|
||||
scores: tuples of (batch, vocab_size)
|
||||
"""
|
||||
batch_size, seqlen_og = input_ids.shape
|
||||
inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size)
|
||||
inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size,
|
||||
fused_ft_kernel=fused_ft_kernel)
|
||||
scores = []
|
||||
with torch.inference_mode():
|
||||
logits = model(input_ids, inference_params=inference_params).logits[:, -1]
|
||||
@ -57,8 +63,9 @@ def greedy_decode(input_ids, model, max_length):
|
||||
|
||||
class GenerationMixin:
|
||||
|
||||
def generate(self, input_ids, max_length, return_dict_in_generate=False, output_scores=False):
|
||||
output = greedy_decode(input_ids, self, max_length)
|
||||
def generate(self, input_ids, max_length, return_dict_in_generate=False, output_scores=False,
|
||||
**kwargs):
|
||||
output = greedy_decode(input_ids, self, max_length, **kwargs)
|
||||
if not output_scores:
|
||||
output.scores = None
|
||||
return output if return_dict_in_generate else output.sequences
|
||||
|
||||
@ -15,10 +15,11 @@ from flash_attn.utils.generation import greedy_decode
|
||||
|
||||
|
||||
# TODO: test with rotary embedding
|
||||
@pytest.mark.parametrize('fused_ft_kernel', [False, True])
|
||||
@pytest.mark.parametrize('optimized', [False, True])
|
||||
# @pytest.mark.parametrize('optimized', [False])
|
||||
# @pytest.mark.parametrize('optimized', [True])
|
||||
@pytest.mark.parametrize('model_name', ["gpt2"])
|
||||
def test_greedy_decode(model_name, optimized):
|
||||
def test_greedy_decode(model_name, optimized, fused_ft_kernel):
|
||||
"""Check that our implementation of GPT2 generation matches the HF implementation:
|
||||
the scores in fp16 should be around the same as the HF scores in fp16, when compared to
|
||||
the HF scores in fp32.
|
||||
@ -62,6 +63,7 @@ def test_greedy_decode(model_name, optimized):
|
||||
scores = tuple(scores)
|
||||
|
||||
out = model.generate(input_ids=input_ids, max_length=max_length,
|
||||
fused_ft_kernel=fused_ft_kernel,
|
||||
return_dict_in_generate=True, output_scores=True)
|
||||
|
||||
out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user