[Gen] Fix FT kernel when using CG
parent dceb2687c5
commit 605655bc66
@@ -495,7 +495,8 @@ class MHA(nn.Module):
                 *inference_params.key_value_memory_dict[self.layer_idx],
                 inference_params.lengths_per_sample, inference_params.sequence_len_offset,
                 self.rotary_emb_dim,
-                not self.rotary_emb.interleaved  # neox_rotary_style
+                # neox_rotary_style
+                (not self.rotary_emb.interleaved) if self.rotary_emb_dim > 0 else True
             )
             context = rearrange(context, 'b h d -> b 1 h d')
         else:
@@ -609,7 +610,8 @@ class ParallelMHA(nn.Module):
                 *inference_params.key_value_memory_dict[self.layer_idx],
                 inference_params.lengths_per_sample, inference_params.sequence_len_offset,
                 self.rotary_emb_dim,
-                not self.rotary_emb.interleaved  # neox_rotary_style
+                # neox_rotary_style
+                (not self.rotary_emb.interleaved) if self.rotary_emb_dim > 0 else True
             )
             context = rearrange(context, 'b h d -> b 1 h d')
         if seqlen is None:
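Note on the two hunks above: as far as I can tell, these modules only construct the rotary-embedding object when `rotary_emb_dim > 0`, so the old kernel argument `not self.rotary_emb.interleaved` could fail for configurations without rotary embeddings; the new conditional falls back to `True` in that case. A minimal, hypothetical sketch of the guard (not the real `MHA` class):

# Hypothetical, stripped-down illustration of the guarded neox_rotary_style flag.
class TinyAttn:
    def __init__(self, rotary_emb_dim=0, interleaved=False):
        self.rotary_emb_dim = rotary_emb_dim
        if rotary_emb_dim > 0:
            # Mirrors the assumption that rotary_emb only exists when rotary is used.
            self.rotary_emb = type("Rotary", (), {"interleaved": interleaved})()

    def neox_rotary_style(self) -> bool:
        # The old form `not self.rotary_emb.interleaved` would raise AttributeError
        # when rotary_emb_dim == 0; the guarded form defaults to True.
        return (not self.rotary_emb.interleaved) if self.rotary_emb_dim > 0 else True


assert TinyAttn(rotary_emb_dim=64).neox_rotary_style() is True
assert TinyAttn(rotary_emb_dim=0).neox_rotary_style() is True
assert TinyAttn(rotary_emb_dim=64, interleaved=True).neox_rotary_style() is False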
@@ -190,9 +190,9 @@ def seqlen_to_seqlen_type(seqlen: int) -> int:
     return 0 if seqlen < 32 else (1 if seqlen < 2048 else 2)


-def seqlen_type_to_seqlen(seqlen_type: int) -> int:
+def seqlen_type_to_max_seqlen(seqlen_type: int) -> int:
     assert seqlen_type in [0, 1, 2]
-    return 1 if seqlen_type == 0 else (32 if seqlen_type == 1 else 2048)
+    return 32 if seqlen_type == 0 else (2048 if seqlen_type == 1 else 2**32)


 @dataclass
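The hunk above renames the helper to make explicit that it returns the upper bound of each seqlen bucket (the removed version returned 1/32/2048, roughly the bucket minimum, which under-sizes a graph captured for that bucket). A self-contained sketch, copying the two helpers from this hunk and checking the invariant the change is after:

def seqlen_to_seqlen_type(seqlen: int) -> int:
    # Bucket 0: seqlen < 32, bucket 1: 32 <= seqlen < 2048, bucket 2: seqlen >= 2048.
    return 0 if seqlen < 32 else (1 if seqlen < 2048 else 2)


def seqlen_type_to_max_seqlen(seqlen_type: int) -> int:
    # Upper bound of each bucket; 2**32 stands in for "effectively unbounded".
    assert seqlen_type in [0, 1, 2]
    return 32 if seqlen_type == 0 else (2048 if seqlen_type == 1 else 2**32)


# Every decoding length fits within the max seqlen of its own bucket.
for seqlen in [1, 31, 32, 100, 2047, 2048, 10000]:
    assert seqlen <= seqlen_type_to_max_seqlen(seqlen_to_seqlen_type(seqlen))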
@@ -239,9 +239,9 @@ def update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, tensor_p
         cache.mempool = torch.cuda.graphs.graph_pool_handle()
     for s_type in range(seqlen_to_seqlen_type(seqlen_og), seqlen_to_seqlen_type(max_seqlen) + 1):
         if s_type not in cache.callables:
-            seqlen = min(max(seqlen_og, seqlen_type_to_seqlen(s_type)), max_seqlen)
+            max_seqlen_ = min(max(seqlen_og, seqlen_type_to_max_seqlen(s_type)), max_seqlen)
             cache.callables[s_type] = capture_graph(
-                model, cache.inference_params, batch_size, seqlen_og, seqlen, mempool=cache.mempool,
+                model, cache.inference_params, batch_size, max_seqlen_, mempool=cache.mempool,
                 n_warmups=n_warmups
             )
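A worked example of the clamp in the hunk above, reusing the two helpers from the sketch higher up and assuming the values used by the new test at the bottom of this diff (prompt length 100, generation length 150):

seqlen_og, max_seqlen = 100, 150
s_types = list(range(seqlen_to_seqlen_type(seqlen_og), seqlen_to_seqlen_type(max_seqlen) + 1))
assert s_types == [1]  # both 100 and 150 land in bucket 1, so a single graph is captured
max_seqlen_ = min(max(seqlen_og, seqlen_type_to_max_seqlen(s_types[0])), max_seqlen)
assert max_seqlen_ == 150  # clamped to max_seqlen, so that one graph covers every decoding step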
@@ -249,17 +249,19 @@ def update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, tensor_p
         return cache.callables[seqlen_to_seqlen_type(seqlen)](input_ids, position_ids, seqlen)

     cache.run = dispatch
-    cache.inference_params.sequence_length_offset = 0  # Reset so it's not confusing
+    cache.inference_params.sequence_len_offset = 0  # Reset so it's not confusing
     return cache


-def capture_graph(model, inference_params, batch_size, seqlen_og, max_seqlen, mempool=None,
-                  n_warmups=2):
-    assert max_seqlen >= seqlen_og
+def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None, n_warmups=2):
     device = next(iter(model.parameters())).device
     input_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
     position_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
-    inference_params.lengths_per_sample[:] = seqlen_og
+    sequence_len_offset_og = inference_params.sequence_len_offset
+    # TD [2023-04-14]: important for correctness of the FT's attention kernel, as seqlen_cpu is
+    # used to determine the size of smem. Hence seqlen_cpu must be >= lengths_per_sample.
+    inference_params.sequence_len_offset = max_seqlen - 1
+    inference_params.lengths_per_sample[:] = max_seqlen - 1

     # Warmup before capture
     s = torch.cuda.Stream()
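For context: `capture_graph` follows the usual stream-warmup / capture / replay recipe, and (per the `TD [2023-04-14]` comment above) now primes `sequence_len_offset` and `lengths_per_sample` with `max_seqlen - 1` before capture so the FT kernel's shared-memory sizing is as large as any later replay will need. A generic sketch of that capture pattern using the public `torch.cuda.CUDAGraph` API (not the exact function body; `step_fn` and `static_inputs` are placeholders):

import torch


def capture_decode_step(step_fn, static_inputs, n_warmups=2, mempool=None):
    # Warm up on a side stream so allocator activity does not pollute the capture.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(n_warmups):
            static_out = step_fn(*static_inputs)
    torch.cuda.current_stream().wait_stream(s)

    # Capture one decoding step into a CUDA graph, optionally sharing a memory pool.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=mempool):
        static_out = step_fn(*static_inputs)

    def run(*new_inputs):
        # Copy fresh values into the captured (static) tensors, then replay the graph.
        for static, new in zip(static_inputs, new_inputs):
            static.copy_(new)
        graph.replay()
        return static_out

    return run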
@@ -289,4 +291,5 @@ def capture_graph(model, inference_params, batch_size, seqlen_og, max_seqlen, me
         graph.replay()
         return logits

+    inference_params.sequence_len_offset = sequence_len_offset_og
     return run
@@ -1,4 +1,4 @@
-import re
+import time

 import torch
 import pytest
@@ -9,6 +9,7 @@ from transformers.models.gptj.modeling_gptj import GPTJForCausalLM
 from flash_attn.models.gpt import GPTLMHeadModel
 from flash_attn.models.gptj import remap_state_dict_hf_gptj, gptj_config_to_gpt2_config
 from flash_attn.utils.pretrained import state_dict_from_pretrained
+from flash_attn.utils.generation import update_graph_cache


 @pytest.mark.parametrize('model_name', ["EleutherAI/gpt-j-6B"])
@@ -79,3 +80,92 @@ def test_gptj_optimized(model_name):
     print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
     print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
     assert (logits - logits_ref).abs().max().item() < 3 * (logits_hf - logits_ref).abs().max().item()
+
+
+@pytest.mark.parametrize('model_name', ["EleutherAI/gpt-j-6B"])
+def test_gptj_generation(model_name):
+    """Check that our implementation of GPT-J (with all optimizations enabled) matches the
+    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
+    forward pass in fp16, when compared to the HF forward pass in fp32.
+    """
+    dtype = torch.float16
+    device = 'cuda'
+    config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))
+    config.use_flash_attn = False  # FlashAttention doesn't support hdim 256 yet
+    config.fused_bias_fc = True
+    config.fused_mlp = True
+    config.fused_dropout_add_ln = True
+    # Only prenorm supports residual_in_fp32
+    config.residual_in_fp32 = True
+
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    eos_token_id = tokenizer.eos_token_id
+
+    torch.manual_seed(0)
+    batch_size = 1
+    seqlen = 100
+    max_length = 150
+    input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), dtype=torch.long,
+                              device=device)
+
+    model_hf = GPTJForCausalLM.from_pretrained(model_name, torch_dtype=dtype,
+                                               device_map={"": device})
+    model_hf.eval()
+    print("HF fp16")
+    torch.cuda.synchronize()
+    start = time.time()
+    out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
+                               return_dict_in_generate=True, output_scores=True)
+    torch.cuda.synchronize()
+    print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
+    del model_hf
+
+    model_ref = GPTJForCausalLM.from_pretrained(model_name, device_map={"": device})
+    model_ref.eval()
+    with torch.no_grad():
+        logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1):-1]
+    del model_ref
+
+    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
+    model.eval()
+
+    print('Without CUDA graph')
+    torch.cuda.synchronize()
+    start = time.time()
+    out = model.generate(input_ids=input_ids, max_length=max_length,
+                         eos_token_id=eos_token_id, fused_ft_kernel=True,
+                         # eos_token_id=eos_token_id, fused_ft_kernel=False,
+                         return_dict_in_generate=True, output_scores=True, timing=True,
+                         teacher_outputs=out_hf.sequences)
+    torch.cuda.synchronize()
+    print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
+
+    # Capture graph outside the timing loop
+    batch_size, seqlen_og = input_ids.shape
+    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
+    print('With CUDA graph')
+    torch.cuda.synchronize()
+    start = time.time()
+    out_cg = model.generate(input_ids=input_ids, max_length=max_length,
+                            fused_ft_kernel=True, cg=True,
+                            return_dict_in_generate=True, output_scores=True, timing=True,
+                            teacher_outputs=out_hf.sequences)
+    torch.cuda.synchronize()
+    print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
+
+    with torch.no_grad():
+        logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1):-1]
+    logits_hf = torch.stack(out_hf.scores, dim=1)
+    logits = torch.stack(out.scores, dim=1)
+    logits_cg = torch.stack(out_cg.scores, dim=1)
+
+    del model
+
+    hf_error = (logits_hf - logits_ref).abs().max().item()
+    assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error
+
+    print(f'HF fp16 logits max diff: {hf_error}')
+    print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
+    assert (logits - logits_ref).abs().max().item() < 2 * hf_error
+    print(f'Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}')
+    assert torch.equal(logits_cg, logits)
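The new test exercises both the non-CG and the CUDA-graph decoding paths against the same HF teacher outputs, and asserts the graph-replayed logits are bitwise equal to the eager fused-FT logits. Assuming the file sits at the repo's usual tests/models/test_gptj.py location (my guess from the imports), it can be run on a GPU machine with `pytest -k test_gptj_generation`.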