[Gen] Fix calling update_graph_cache in tests

Tri Dao 2023-09-10 17:22:37 -07:00
parent 4c91621a5e
commit 8a733cbd53
8 changed files with 29 additions and 20 deletions

View File

@@ -659,8 +659,7 @@ class MHA(nn.Module):
                 qkv = rearrange(
                     self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                 ).contiguous()
-            # qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
-            qkv = qkv.reshape(batch, seqlen, 3, self.num_heads, self.head_dim)
+            qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
             if (
                 inference_params is None
                 or inference_params.sequence_len_offset == 0
@@ -700,10 +699,8 @@ class MHA(nn.Module):
                 qkv, x = self.Wqkv(x)
             q = qkv[..., : self.num_heads * self.head_dim]
             kv = qkv[..., self.num_heads * self.head_dim :]
-            # q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
-            q = q.reshape(batch, seqlen, -1, self.head_dim)
-            # kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
-            kv = kv.reshape(batch, seqlen, 2, -1, self.head_dim)
+            q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
+            kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
             if self.dwconv:
                 q = rearrange(
                     self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
@@ -731,8 +728,7 @@ class MHA(nn.Module):
                     context = self._update_kvcache_attention(q, kv, inference_params)
             else:
                 context = self._apply_rotary_single_query_attention(q, inference_params, kv=kv)
-        # out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
-        out = self.out_proj(context.reshape(batch, seqlen, -1))
+        out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
         return out if not self.return_residual else (out, x)
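
Note (added for context, not part of the commit): the mha.py hunks above drop the explicit reshape calls and restore the einops rearrange pattern that had been left in comments. For the packed projection layout used here, the two forms split the last dimension identically and produce the same tensor. A minimal standalone check, with arbitrary example shapes:

import torch
from einops import rearrange

batch, seqlen, num_heads, head_dim = 2, 5, 4, 8
qkv = torch.randn(batch, seqlen, 3 * num_heads * head_dim)

# reshape form removed by the commit
out_reshape = qkv.reshape(batch, seqlen, 3, num_heads, head_dim)
# rearrange form restored by the commit; h is inferred from three=3 and d=head_dim
out_rearrange = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=head_dim)

assert torch.equal(out_reshape, out_rearrange)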

View File

@@ -1,3 +1,5 @@
+# Copyright (c) 2023, Tri Dao.
 from typing import Optional, Union
 import torch

View File

@@ -404,7 +404,7 @@ def test_baichuan_parallel_generation(model_name, world_size):
     # Capture graph outside the timing loop
     batch_size, seqlen_og = input_ids.shape
     model._decoding_cache = update_graph_cache(
-        model, None, batch_size, seqlen_og, max_length
+        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=False
     )
     print("With CUDA graph")
     out_cg = model.generate(

View File

@@ -253,7 +253,9 @@ def test_falcon_generation(model_name):
     # Capture graph outside the timing loop
     batch_size, seqlen_og = input_ids.shape
-    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
+    model._decoding_cache = update_graph_cache(
+        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=True
+    )
     print("With CUDA graph")
     torch.cuda.synchronize()
     start = time.time()
@@ -356,7 +358,9 @@ def test_falcon_parallel_generation(model_name, world_size):
     # Capture graph outside the timing loop
     batch_size, seqlen_og = input_ids.shape
-    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
+    model._decoding_cache = update_graph_cache(
+        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=True
+    )
     print("With CUDA graph")
     out_cg = model.generate(
         input_ids=input_ids,

View File

@@ -6,7 +6,6 @@ import pytest
 import torch
 from flash_attn.models.gpt import GPTLMHeadModel
 from flash_attn.models.gpt_neox import gpt_neox_config_to_gpt2_config, remap_state_dict_hf_gpt_neox
-from flash_attn.utils.generation import update_graph_cache
 from flash_attn.utils.pretrained import state_dict_from_pretrained
 from transformers import AutoTokenizer, GPTNeoXConfig
 from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM

View File

@@ -83,8 +83,9 @@ def test_gptj_optimized(model_name):
     ).abs().max().item()
+@pytest.mark.parametrize("fused_ft_kernel", [False, True])
 @pytest.mark.parametrize("model_name", ["EleutherAI/gpt-j-6B"])
-def test_gptj_generation(model_name):
+def test_gptj_generation(model_name, fused_ft_kernel):
     """Check that our implementation of GPT-J (with all optimizations enabled) matches the
     HF implementation: the output of our forward pass in fp16 should be around the same as the HF
     forward pass in fp32.
@@ -140,8 +141,7 @@ def test_gptj_generation(model_name):
         input_ids=input_ids,
         max_length=max_length,
         eos_token_id=eos_token_id,
-        fused_ft_kernel=True,
-        # eos_token_id=eos_token_id, fused_ft_kernel=False,
+        fused_ft_kernel=fused_ft_kernel,
         return_dict_in_generate=True,
         output_scores=True,
         enable_timing=True,
@@ -152,14 +152,16 @@ def test_gptj_generation(model_name):
     # Capture graph outside the timing loop
     batch_size, seqlen_og = input_ids.shape
-    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
+    model._decoding_cache = update_graph_cache(
+        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=fused_ft_kernel
+    )
     print("With CUDA graph")
     torch.cuda.synchronize()
     start = time.time()
     out_cg = model.generate(
         input_ids=input_ids,
         max_length=max_length,
-        fused_ft_kernel=True,
+        fused_ft_kernel=fused_ft_kernel,
         cg=True,
         return_dict_in_generate=True,
         output_scores=True,
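
Note (added for context, not a change made by the commit): the test updates all follow one calling pattern: the CUDA-graph decoding cache is now built by update_graph_cache with an explicit fused_ft_kernel argument, and in the parametrized GPT-J test above the same value is forwarded to generate together with cg=True. A sketch of that pattern, assuming the arguments shown in the diffs; the helper name generate_with_cuda_graph is invented for illustration and does not exist in the repo:

from flash_attn.utils.generation import update_graph_cache

def generate_with_cuda_graph(model, input_ids, max_length, fused_ft_kernel=True):
    # Capture the CUDA graph (and allocate the decoding cache) outside any timing loop.
    batch_size, seqlen_og = input_ids.shape
    model._decoding_cache = update_graph_cache(
        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=fused_ft_kernel
    )
    # Reuse the captured graph during generation; keep fused_ft_kernel consistent with
    # the value used when the cache was built.
    return model.generate(
        input_ids=input_ids,
        max_length=max_length,
        fused_ft_kernel=fused_ft_kernel,
        cg=True,
        return_dict_in_generate=True,
        output_scores=True,
    )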

View File

@@ -303,7 +303,9 @@ def test_llama_generation(model_name, checkpoint_format):
     # Capture graph outside the timing loop
     batch_size, seqlen_og = input_ids.shape
-    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
+    model._decoding_cache = update_graph_cache(
+        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=True
+    )
     print("With CUDA graph")
     torch.cuda.synchronize()
     start = time.time()
@@ -408,7 +410,9 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
     # Capture graph outside the timing loop
     batch_size, seqlen_og = input_ids.shape
-    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
+    model._decoding_cache = update_graph_cache(
+        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=True
+    )
     print("With CUDA graph")
     out_cg = model.generate(
         input_ids=input_ids,

View File

@@ -168,7 +168,9 @@ def test_opt_generation(model_name):
     if fused_ft_kernel:
         # Capture graph outside the timing loop
         batch_size, seqlen_og = input_ids.shape
-        model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
+        model._decoding_cache = update_graph_cache(
+            model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=True
+        )
         print("With CUDA graph")
         torch.cuda.synchronize()
         start = time.time()