[Gen] Fix FT kernel when using CG
commit 605655bc66 (parent dceb2687c5)
@@ -495,7 +495,8 @@ class MHA(nn.Module):
                 *inference_params.key_value_memory_dict[self.layer_idx],
                 inference_params.lengths_per_sample, inference_params.sequence_len_offset,
                 self.rotary_emb_dim,
-                not self.rotary_emb.interleaved  # neox_rotary_style
+                # neox_rotary_style
+                (not self.rotary_emb.interleaved) if self.rotary_emb_dim > 0 else True
             )
             context = rearrange(context, 'b h d -> b 1 h d')
         else:
@@ -609,7 +610,8 @@ class ParallelMHA(nn.Module):
                 *inference_params.key_value_memory_dict[self.layer_idx],
                 inference_params.lengths_per_sample, inference_params.sequence_len_offset,
                 self.rotary_emb_dim,
-                not self.rotary_emb.interleaved  # neox_rotary_style
+                # neox_rotary_style
+                (not self.rotary_emb.interleaved) if self.rotary_emb_dim > 0 else True
             )
             context = rearrange(context, 'b h d -> b 1 h d')
         if seqlen is None:
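Note on the change in both hunks above: `self.rotary_emb` is only constructed when `rotary_emb_dim > 0`, so reading `self.rotary_emb.interleaved` unconditionally fails for models configured without rotary embeddings. The guarded expression falls back to `True`, which should be irrelevant to the kernel when `rotary_emb_dim == 0`. A minimal sketch of the same guard in isolation, with hypothetical class and attribute names standing in for the repository's MHA module:

# Minimal sketch of the guard (hypothetical names; not the repository's MHA module).
class _AttnFlagDemo:
    def __init__(self, rotary_emb_dim=0, interleaved=False):
        self.rotary_emb_dim = rotary_emb_dim
        if rotary_emb_dim > 0:
            # Stand-in for flash_attn's RotaryEmbedding module.
            self.rotary_emb = type("Rotary", (), {"interleaved": interleaved})()

    def neox_rotary_style(self) -> bool:
        # Mirrors the diff: only touch self.rotary_emb when rotary is actually enabled.
        return (not self.rotary_emb.interleaved) if self.rotary_emb_dim > 0 else True

assert _AttnFlagDemo(rotary_emb_dim=0).neox_rotary_style() is True
assert _AttnFlagDemo(rotary_emb_dim=64, interleaved=True).neox_rotary_style() is False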
@@ -190,9 +190,9 @@ def seqlen_to_seqlen_type(seqlen: int) -> int:
     return 0 if seqlen < 32 else (1 if seqlen < 2048 else 2)
 
 
-def seqlen_type_to_seqlen(seqlen_type: int) -> int:
+def seqlen_type_to_max_seqlen(seqlen_type: int) -> int:
     assert seqlen_type in [0, 1, 2]
-    return 1 if seqlen_type == 0 else (32 if seqlen_type == 1 else 2048)
+    return 32 if seqlen_type == 0 else (2048 if seqlen_type == 1 else 2**32)
 
 
 @dataclass
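The rename makes the contract explicit: `seqlen_to_seqlen_type` buckets a decode-time sequence length into one of three CUDA-graph buckets, and the second helper now returns each bucket's upper bound (rather than its lower bound), so the graph captured for a bucket always covers every sequence length dispatched to it. A small self-check; the two functions are copied from the diff, the loop and assertion are illustrative only:

def seqlen_to_seqlen_type(seqlen: int) -> int:
    # Bucket 0: < 32, bucket 1: [32, 2048), bucket 2: >= 2048.
    return 0 if seqlen < 32 else (1 if seqlen < 2048 else 2)

def seqlen_type_to_max_seqlen(seqlen_type: int) -> int:
    # Exclusive upper bound of each bucket.
    assert seqlen_type in [0, 1, 2]
    return 32 if seqlen_type == 0 else (2048 if seqlen_type == 1 else 2**32)

# capture_graph (below) bakes max_seqlen - 1 into the graph, which is >= any
# sequence length that seqlen_to_seqlen_type maps to the same bucket.
for seqlen in [1, 31, 32, 100, 2047, 2048, 100_000]:
    assert seqlen <= seqlen_type_to_max_seqlen(seqlen_to_seqlen_type(seqlen)) - 1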
@@ -239,9 +239,9 @@ def update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, tensor_p
         cache.mempool = torch.cuda.graphs.graph_pool_handle()
     for s_type in range(seqlen_to_seqlen_type(seqlen_og), seqlen_to_seqlen_type(max_seqlen) + 1):
         if s_type not in cache.callables:
-            seqlen = min(max(seqlen_og, seqlen_type_to_seqlen(s_type)), max_seqlen)
+            max_seqlen_ = min(max(seqlen_og, seqlen_type_to_max_seqlen(s_type)), max_seqlen)
             cache.callables[s_type] = capture_graph(
-                model, cache.inference_params, batch_size, seqlen_og, seqlen, mempool=cache.mempool,
+                model, cache.inference_params, batch_size, max_seqlen_, mempool=cache.mempool,
                 n_warmups=n_warmups
             )
 
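With the helpers above, each bucket between the prompt length and the generation limit gets exactly one captured graph, at the bucket's upper bound clamped into [seqlen_og, max_seqlen]. A worked illustration of that clamping; `plan_captures` is a hypothetical helper that reuses the two functions from the previous sketch, and `capture_graph` itself is not called:

def plan_captures(seqlen_og: int, max_seqlen: int) -> dict:
    # Mirrors the loop in update_graph_cache: one capture length per seqlen bucket.
    return {
        s_type: min(max(seqlen_og, seqlen_type_to_max_seqlen(s_type)), max_seqlen)
        for s_type in range(seqlen_to_seqlen_type(seqlen_og),
                            seqlen_to_seqlen_type(max_seqlen) + 1)
    }

assert plan_captures(100, 150) == {1: 150}          # e.g. the GPT-J test below
assert plan_captures(10, 4096) == {0: 32, 1: 2048, 2: 4096}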
@@ -249,17 +249,19 @@ def update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, tensor_p
         return cache.callables[seqlen_to_seqlen_type(seqlen)](input_ids, position_ids, seqlen)
 
     cache.run = dispatch
-    cache.inference_params.sequence_length_offset = 0  # Reset so it's not confusing
+    cache.inference_params.sequence_len_offset = 0  # Reset so it's not confusing
     return cache
 
 
-def capture_graph(model, inference_params, batch_size, seqlen_og, max_seqlen, mempool=None,
-                  n_warmups=2):
-    assert max_seqlen >= seqlen_og
+def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None, n_warmups=2):
     device = next(iter(model.parameters())).device
     input_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
     position_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
-    inference_params.lengths_per_sample[:] = seqlen_og
+    sequence_len_offset_og = inference_params.sequence_len_offset
+    # TD [2023-04-14]: important for correctness of the FT's attention kernel, as seqlen_cpu is
+    # used to determine the size of smem. Hence seqlen_cpu must be >= lengths_per_sample.
+    inference_params.sequence_len_offset = max_seqlen - 1
+    inference_params.lengths_per_sample[:] = max_seqlen - 1
 
     # Warmup before capture
     s = torch.cuda.Stream()
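This is the core of the fix: the graph is now captured with `sequence_len_offset` and `lengths_per_sample` pinned to `max_seqlen - 1`, the largest value they can take during replay, because (per the comment in the hunk) the FT attention kernel sizes its shared memory from the host-side `seqlen_cpu` that gets baked into the captured graph; capturing at the prompt length would under-allocate smem for later decoding steps. The original offset is saved in `sequence_len_offset_og` and restored after capture (next hunk). For orientation, a generic sketch of the warmup/capture/replay pattern that `capture_graph` follows, with a hypothetical `decode_step` callable standing in for the model call; this is a simplified sketch, not the repository's exact code:

import torch

def capture_decode_step(decode_step, static_inputs, n_warmups=2, mempool=None):
    """Warm up on a side stream, capture one CUDA graph, and return a replay closure."""
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(n_warmups):
            static_out = decode_step(*static_inputs)
        s.synchronize()
    torch.cuda.current_stream().wait_stream(s)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=mempool):
        static_out = decode_step(*static_inputs)

    def run(*new_inputs):
        # Copy fresh data into the buffers the graph was captured with, then replay.
        for dst, src in zip(static_inputs, new_inputs):
            dst.copy_(src)
        graph.replay()
        return static_out

    return run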
@@ -289,4 +291,5 @@ def capture_graph(model, inference_params, batch_size, seqlen_og, max_seqlen, me
         graph.replay()
         return logits
 
+    inference_params.sequence_len_offset = sequence_len_offset_og
     return run
@@ -1,4 +1,4 @@
-import re
+import time
 
 import torch
 import pytest
@@ -9,6 +9,7 @@ from transformers.models.gptj.modeling_gptj import GPTJForCausalLM
 from flash_attn.models.gpt import GPTLMHeadModel
 from flash_attn.models.gptj import remap_state_dict_hf_gptj, gptj_config_to_gpt2_config
 from flash_attn.utils.pretrained import state_dict_from_pretrained
+from flash_attn.utils.generation import update_graph_cache
 
 
 @pytest.mark.parametrize('model_name', ["EleutherAI/gpt-j-6B"])
@@ -79,3 +80,92 @@ def test_gptj_optimized(model_name):
     print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
     print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
     assert (logits - logits_ref).abs().max().item() < 3 * (logits_hf - logits_ref).abs().max().item()
+
+
+@pytest.mark.parametrize('model_name', ["EleutherAI/gpt-j-6B"])
+def test_gptj_generation(model_name):
+    """Check that our implementation of GPT-J (with all optimizations enabled) matches the
+    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
+    forward pass in fp16, when compared to the HF forward pass in fp32.
+    """
+    dtype = torch.float16
+    device = 'cuda'
+    config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))
+    config.use_flash_attn = False  # FlashAttention doesn't support hdim 256 yet
+    config.fused_bias_fc = True
+    config.fused_mlp = True
+    config.fused_dropout_add_ln = True
+    # Only prenorm supports residual_in_fp32
+    config.residual_in_fp32 = True
+
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    eos_token_id = tokenizer.eos_token_id
+
+    torch.manual_seed(0)
+    batch_size = 1
+    seqlen = 100
+    max_length = 150
+    input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), dtype=torch.long,
+                              device=device)
+
+    model_hf = GPTJForCausalLM.from_pretrained(model_name, torch_dtype=dtype,
+                                               device_map={"": device})
+    model_hf.eval()
+    print("HF fp16")
+    torch.cuda.synchronize()
+    start = time.time()
+    out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
+                               return_dict_in_generate=True, output_scores=True)
+    torch.cuda.synchronize()
+    print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
+    del model_hf
+
+    model_ref = GPTJForCausalLM.from_pretrained(model_name, device_map={"": device})
+    model_ref.eval()
+    with torch.no_grad():
+        logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1):-1]
+    del model_ref
+
+    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
+    model.eval()
+
+    print('Without CUDA graph')
+    torch.cuda.synchronize()
+    start = time.time()
+    out = model.generate(input_ids=input_ids, max_length=max_length,
+                         eos_token_id=eos_token_id, fused_ft_kernel=True,
+                         # eos_token_id=eos_token_id, fused_ft_kernel=False,
+                         return_dict_in_generate=True, output_scores=True, timing=True,
+                         teacher_outputs=out_hf.sequences)
+    torch.cuda.synchronize()
+    print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
+
+    # Capture graph outside the timing loop
+    batch_size, seqlen_og = input_ids.shape
+    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
+    print('With CUDA graph')
+    torch.cuda.synchronize()
+    start = time.time()
+    out_cg = model.generate(input_ids=input_ids, max_length=max_length,
+                            fused_ft_kernel=True, cg=True,
+                            return_dict_in_generate=True, output_scores=True, timing=True,
+                            teacher_outputs=out_hf.sequences)
+    torch.cuda.synchronize()
+    print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
+
+    with torch.no_grad():
+        logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1):-1]
+    logits_hf = torch.stack(out_hf.scores, dim=1)
+    logits = torch.stack(out.scores, dim=1)
+    logits_cg = torch.stack(out_cg.scores, dim=1)
+
+    del model
+
+    hf_error = (logits_hf - logits_ref).abs().max().item()
+    assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error
+
+    print(f'HF fp16 logits max diff: {hf_error}')
+    print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
+    assert (logits - logits_ref).abs().max().item() < 2 * hf_error
+    print(f'Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}')
+    assert torch.equal(logits_cg, logits)