[Gen] Fix FT kernel smem size, CG when batch size changed
commit 311d6606bf
parent 96d10f6545

@@ -29,14 +29,13 @@
 #define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, stream) \
     size_t smem_sz = mmha::smem_size_in_bytes<T, DO_CROSS_ATTENTION>(params, THDS_PER_VALUE, THDS_PER_BLOCK); \
-    dim3 grid(params.num_heads, params.batch_size); \
-    mmha::masked_multihead_attention_kernel<T, \
-                                            Dh, \
-                                            Dh_MAX, \
-                                            THDS_PER_KEY, \
-                                            THDS_PER_VALUE, \
-                                            THDS_PER_BLOCK, \
-                                            DO_CROSS_ATTENTION><<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)
+    auto kernel = mmha::masked_multihead_attention_kernel<T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, \
+                                                          THDS_PER_BLOCK, DO_CROSS_ATTENTION>; \
+    if (smem_sz >= 48 * 1024) { \
+        cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz); \
+    } \
+    dim3 grid(params.num_heads, params.batch_size); \
+    kernel<<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)

 ////////////////////////////////////////////////////////////////////////////////////////////////////
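
The macro now launches the kernel through a function pointer so it can first call cudaFuncSetAttribute with cudaFuncAttributeMaxDynamicSharedMemorySize: CUDA allows at most 48 KB of dynamic shared memory per block by default, and launches that request more fail unless the kernel opts in. A rough sketch of when that matters, assuming (this is an assumption, not the exact mmha::smem_size_in_bytes formula) that the decoding kernel needs on the order of one float of shared memory per cached timestep:

# Rough sketch only, not flash_attn's API: estimate when the fused FT decoding
# kernel is likely to need the >48 KB dynamic shared-memory opt-in.
def may_need_smem_optin(cached_timesteps: int, bytes_per_timestep: int = 4) -> bool:
    # Assumption: shared-memory use grows roughly as one float (4 bytes) per
    # cached timestep (for the QK^T logits); the real budget comes from
    # mmha::smem_size_in_bytes on the C++ side.
    return cached_timesteps * bytes_per_timestep >= 48 * 1024

print(may_need_smem_optin(3400))    # False: short contexts fit under 48 KB
print(may_need_smem_optin(15000))   # True: roughly the regime the new (14000, 15000) test case hits

Under this rough bound the opt-in only starts to matter somewhere above ~12K cached tokens, which matches the very long (14000, 15000) case added in the new test below.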

@@ -490,10 +490,15 @@ class MHA(nn.Module):
         else:
             assert inference_params.fused_ft_kernel
             assert ft_attention is not None
+            batch_start = inference_params.batch_size_offset
+            batch_end = batch_start + qkv.shape[0]
+            k_cache, v_cache = inference_params.key_value_memory_dict[self.layer_idx]
+            lengths_per_sample = (inference_params.lengths_per_sample[batch_start:batch_end]
+                                  if inference_params.lengths_per_sample is not None else None)
             context = ft_attention.single_query_attention(
                 *rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1),
-                *inference_params.key_value_memory_dict[self.layer_idx],
-                inference_params.lengths_per_sample, inference_params.sequence_len_offset,
+                k_cache[batch_start:batch_end], v_cache[batch_start:batch_end],
+                lengths_per_sample, inference_params.sequence_len_offset,
                 self.rotary_emb_dim,
                 # neox_rotary_style
                 (not self.rotary_emb.interleaved) if self.rotary_emb_dim > 0 else True
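
Both MHA here and ParallelMHA in the next hunk now hand the FT kernel only the slice of the preallocated key/value cache that belongs to the current batch (rows batch_start:batch_end), together with the matching slice of lengths_per_sample. Because slicing along the batch dimension returns a view, the kernel still updates the shared cache in place while leaving other rows untouched. A minimal sketch with made-up shapes (the real k_cache/v_cache layout is whatever flash_attn allocates for the FT kernel):

import torch

# Toy shapes only; the real cache layout is set up by flash_attn, not here.
max_batch, max_seqlen, nheads, headdim = 8, 64, 4, 16
k_cache = torch.zeros(max_batch, max_seqlen, nheads, headdim)

batch_start = 2                 # plays the role of inference_params.batch_size_offset
cur_batch = 3                   # qkv.shape[0] for this decoding step
batch_end = batch_start + cur_batch

# Slicing gives a view: writes into k_cache_cur land in the shared cache,
# but only in the rows owned by the current batch.
k_cache_cur = k_cache[batch_start:batch_end]
k_cache_cur[:, 0] = 1.0         # stand-in for the kernel appending new keys
assert k_cache[batch_start:batch_end, 0].eq(1.0).all()
assert k_cache[:batch_start].eq(0.0).all()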

@@ -605,11 +610,16 @@ class ParallelMHA(nn.Module):
         else:
             assert inference_params.fused_ft_kernel
             assert ft_attention is not None
+            batch_start = inference_params.batch_size_offset
+            batch_end = batch_start + qkv.shape[0]
+            k_cache, v_cache = inference_params.key_value_memory_dict[self.layer_idx]
+            lengths_per_sample = (inference_params.lengths_per_sample[batch_start:batch_end]
+                                  if inference_params.lengths_per_sample is not None else None)
             context = ft_attention.single_query_attention(
                 *rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1),
-                *inference_params.key_value_memory_dict[self.layer_idx],
-                inference_params.lengths_per_sample, inference_params.sequence_len_offset,
+                k_cache[batch_start:batch_end], v_cache[batch_start:batch_end],
+                lengths_per_sample, inference_params.sequence_len_offset,
                 self.rotary_emb_dim,
                 # neox_rotary_style
                 (not self.rotary_emb.interleaved) if self.rotary_emb_dim > 0 else True
             )

@@ -238,15 +238,16 @@ def update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, tensor_p
         )
     cache.mempool = torch.cuda.graphs.graph_pool_handle()
     for s_type in range(seqlen_to_seqlen_type(seqlen_og), seqlen_to_seqlen_type(max_seqlen) + 1):
-        if s_type not in cache.callables:
+        if (batch_size, s_type) not in cache.callables:
             max_seqlen_ = min(max(seqlen_og, seqlen_type_to_max_seqlen(s_type)), max_seqlen)
-            cache.callables[s_type] = capture_graph(
+            cache.callables[batch_size, s_type] = capture_graph(
                 model, cache.inference_params, batch_size, max_seqlen_, mempool=cache.mempool,
                 n_warmups=n_warmups
             )

     def dispatch(input_ids, position_ids, seqlen):
-        return cache.callables[seqlen_to_seqlen_type(seqlen)](input_ids, position_ids, seqlen)
+        batch_size = input_ids.shape[0]
+        return cache.callables[batch_size, seqlen_to_seqlen_type(seqlen)](input_ids, position_ids, seqlen)

     cache.run = dispatch
     cache.inference_params.sequence_len_offset = 0  # Reset so it's not confusing
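
update_graph_cache now keys captured CUDA graphs by (batch_size, seqlen_type) instead of seqlen_type alone, and dispatch reads the batch size off input_ids, so changing the batch size between generate() calls captures a fresh graph rather than replaying one recorded against static buffers of a different shape. A stripped-down sketch of the same keying idea, with hypothetical names and a placeholder bucketing function standing in for seqlen_to_seqlen_type:

# Hypothetical stand-ins; only the (batch_size, bucket) keying is the point.
callables = {}

def seqlen_to_bucket(seqlen: int) -> int:
    # Placeholder: the real seqlen_to_seqlen_type buckets sequence lengths so
    # one captured graph can serve a whole range of lengths.
    return 0 if seqlen <= 32 else 1 if seqlen <= 2048 else 2

def get_graph(batch_size: int, seqlen: int, capture_fn):
    key = (batch_size, seqlen_to_bucket(seqlen))   # previously keyed by the bucket alone
    if key not in callables:
        callables[key] = capture_fn(batch_size, seqlen)   # capture_graph(...) in the real code
    return callables[key]

# Same length bucket, different batch size -> a separate captured graph.
g1 = get_graph(1, 100, lambda b, s: f"graph(batch={b})")
g3 = get_graph(3, 120, lambda b, s: f"graph(batch={b})")
assert g1 != g3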

tests/models/test_gpt_generation_cg.py (new file, 80 lines)
@@ -0,0 +1,80 @@
+import os
+import re
+import time
+
+import torch
+import pytest
+
+from einops import rearrange
+
+from transformers import GPT2Config
+
+from flash_attn.models.gpt import GPTLMHeadModel
+from flash_attn.utils.generation import update_graph_cache
+
+
+def get_logits(model, input_ids, max_length, teacher_outputs=None, **kwargs):
+    out = model.generate(input_ids=input_ids, max_length=max_length, fused_ft_kernel=True,
+                         teacher_outputs=teacher_outputs, return_dict_in_generate=True,
+                         output_scores=True, timing=True, **kwargs)
+    return torch.stack(out.scores, dim=1)
+
+
+@pytest.mark.parametrize('seqlen,maxlen', [(10, 20), (30, 150), (3000, 3400), (14000, 15000)])
+# @pytest.mark.parametrize('seqlen,maxlen', [(10, 20)])
+@pytest.mark.parametrize('rotary', [None, "interleaved", "block"])
+# @pytest.mark.parametrize('rotary', [None])
+@pytest.mark.parametrize('model_name', ["gpt2"])
+def test_greedy_decode_gpt2_cg(model_name, rotary, seqlen, maxlen):
+    """Check that decoding with CUDA graph is the same as decoding without CUDA graph.
+    """
+    dtype = torch.float16
+    device = 'cuda'
+    rtol, atol = 3e-3, 3e-1
+    config = GPT2Config.from_pretrained(model_name)
+    config.n_positions = 16 * 1024
+    assert seqlen <= maxlen <= config.n_positions
+    if rotary is not None:
+        config.n_positions = 0
+        config.rotary_emb_dim = 32
+        config.rotary_emb_interleaved = rotary == "interleaved"
+    config.residual_in_fp32 = True
+    config.use_flash_attn = True
+    config.fused_bias_fc = True
+    config.fused_mlp = True
+    config.fused_dropout_add_ln = True
+
+    model = GPTLMHeadModel(config, device=device, dtype=dtype)
+    model.eval()
+
+    torch.manual_seed(0)
+    batch_size = 1
+    input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), dtype=torch.long,
+                              device=device)
+    teacher_outputs = torch.randint(0, config.vocab_size, (batch_size, maxlen), dtype=torch.long,
+                                    device=device)
+
+    logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
+    logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
+    assert torch.equal(logits, logits_cg)
+
+    # Try increasing batch size and seqlen, then decrease them to see if it's still correct
+    batch_size = 3
+    maxlen += 30
+    input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), dtype=torch.long,
+                              device=device)
+    teacher_outputs = torch.randint(0, config.vocab_size, (batch_size, maxlen), dtype=torch.long,
+                                    device=device)
+    logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
+    logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
+    assert torch.equal(logits, logits_cg)
+
+    batch_size = 2
+    maxlen -= 35
+    input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), dtype=torch.long,
+                              device=device)
+    teacher_outputs = torch.randint(0, config.vocab_size, (batch_size, maxlen), dtype=torch.long,
+                                    device=device)
+    logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
+    logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
+    assert torch.equal(logits, logits_cg)
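
The new test can be run on its own, assuming the flash_attn CUDA extensions (including the ft_attention module used by fused_ft_kernel=True) are built and a CUDA GPU is available; for example from Python:

# Assumes the flash_attn CUDA extensions (including ft_attention) are installed
# and a CUDA GPU is visible.
import pytest
raise SystemExit(pytest.main(["-q", "tests/models/test_gpt_generation_cg.py"]))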