[Gen] Fix calling update_graph_cache in tests

parent 4c91621a5e
commit 8a733cbd53
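The generation tests pre-capture the CUDA-graph decoding cache with update_graph_cache before their timing loop and then decode with cg=True; each of the call sites below now passes fused_ft_kernel explicitly. A minimal sketch of that pattern, assuming a model, input_ids and max_length prepared the way the tests do (the helper name is hypothetical, not part of the test suite):

    from flash_attn.utils.generation import update_graph_cache

    def capture_then_generate(model, input_ids, max_length, fused_ft_kernel=True):
        # Capture the CUDA graph outside any timing loop, passing the same
        # fused_ft_kernel flag that is later used for generation.
        batch_size, seqlen_og = input_ids.shape
        model._decoding_cache = update_graph_cache(
            model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=fused_ft_kernel
        )
        return model.generate(
            input_ids=input_ids,
            max_length=max_length,
            fused_ft_kernel=fused_ft_kernel,
            cg=True,
            return_dict_in_generate=True,
            output_scores=True,
        )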
@@ -659,8 +659,7 @@ class MHA(nn.Module):
                 qkv = rearrange(
                     self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                 ).contiguous()
-            # qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
-            qkv = qkv.reshape(batch, seqlen, 3, self.num_heads, self.head_dim)
+            qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
             if (
                 inference_params is None
                 or inference_params.sequence_len_offset == 0
@@ -700,10 +699,8 @@ class MHA(nn.Module):
                 qkv, x = self.Wqkv(x)
             q = qkv[..., : self.num_heads * self.head_dim]
             kv = qkv[..., self.num_heads * self.head_dim :]
-            # q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
-            q = q.reshape(batch, seqlen, -1, self.head_dim)
-            # kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
-            kv = kv.reshape(batch, seqlen, 2, -1, self.head_dim)
+            q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
+            kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
             if self.dwconv:
                 q = rearrange(
                     self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
@@ -731,8 +728,7 @@ class MHA(nn.Module):
                 context = self._update_kvcache_attention(q, kv, inference_params)
             else:
                 context = self._apply_rotary_single_query_attention(q, inference_params, kv=kv)
-        # out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
-        out = self.out_proj(context.reshape(batch, seqlen, -1))
+        out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
         return out if not self.return_residual else (out, x)
 
 
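The MHA hunks revert the reshape-based unpacking of the projected qkv / q / kv tensors back to einops rearrange. For these packed layouts the two produce identical tensors; a small standalone check, with arbitrary sizes that are not taken from the commit:

    import torch
    from einops import rearrange

    # Arbitrary sizes for the check; none of these come from the commit.
    batch, seqlen, num_heads, head_dim = 2, 5, 4, 8

    qkv = torch.randn(batch, seqlen, 3 * num_heads * head_dim)
    assert torch.equal(
        qkv.reshape(batch, seqlen, 3, num_heads, head_dim),
        rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=head_dim),
    )

    kv = torch.randn(batch, seqlen, 2 * num_heads * head_dim)
    assert torch.equal(
        kv.reshape(batch, seqlen, 2, -1, head_dim),
        rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=head_dim),
    )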
@@ -1,3 +1,5 @@
 # Copyright (c) 2023, Tri Dao.
 
+from typing import Optional, Union
+
 import torch
@@ -404,7 +404,7 @@ def test_baichuan_parallel_generation(model_name, world_size):
     # Capture graph outside the timing loop
     batch_size, seqlen_og = input_ids.shape
     model._decoding_cache = update_graph_cache(
-        model, None, batch_size, seqlen_og, max_length
+        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=False
     )
     print("With CUDA graph")
     out_cg = model.generate(
@@ -253,7 +253,9 @@ def test_falcon_generation(model_name):
 
     # Capture graph outside the timing loop
     batch_size, seqlen_og = input_ids.shape
-    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
+    model._decoding_cache = update_graph_cache(
+        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=True
+    )
     print("With CUDA graph")
     torch.cuda.synchronize()
     start = time.time()
@@ -356,7 +358,9 @@ def test_falcon_parallel_generation(model_name, world_size):
 
     # Capture graph outside the timing loop
     batch_size, seqlen_og = input_ids.shape
-    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
+    model._decoding_cache = update_graph_cache(
+        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=True
+    )
     print("With CUDA graph")
     out_cg = model.generate(
         input_ids=input_ids,
@@ -6,7 +6,6 @@ import pytest
 import torch
 from flash_attn.models.gpt import GPTLMHeadModel
 from flash_attn.models.gpt_neox import gpt_neox_config_to_gpt2_config, remap_state_dict_hf_gpt_neox
-from flash_attn.utils.generation import update_graph_cache
 from flash_attn.utils.pretrained import state_dict_from_pretrained
 from transformers import AutoTokenizer, GPTNeoXConfig
 from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM
@@ -83,8 +83,9 @@ def test_gptj_optimized(model_name):
     ).abs().max().item()
 
 
+@pytest.mark.parametrize("fused_ft_kernel", [False, True])
 @pytest.mark.parametrize("model_name", ["EleutherAI/gpt-j-6B"])
-def test_gptj_generation(model_name):
+def test_gptj_generation(model_name, fused_ft_kernel):
     """Check that our implementation of GPT-J (with all optimizations enabled) matches the
     HF implementation: the output of our forward pass in fp16 should be around the same as the HF
     forward pass in fp16, when compared to the HF forward pass in fp32.
@@ -140,8 +141,7 @@ def test_gptj_generation(model_name):
         input_ids=input_ids,
         max_length=max_length,
         eos_token_id=eos_token_id,
-        fused_ft_kernel=True,
-        # eos_token_id=eos_token_id, fused_ft_kernel=False,
+        fused_ft_kernel=fused_ft_kernel,
         return_dict_in_generate=True,
         output_scores=True,
         enable_timing=True,
@@ -152,14 +152,16 @@ def test_gptj_generation(model_name):
 
     # Capture graph outside the timing loop
     batch_size, seqlen_og = input_ids.shape
-    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
+    model._decoding_cache = update_graph_cache(
+        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=fused_ft_kernel
+    )
     print("With CUDA graph")
     torch.cuda.synchronize()
     start = time.time()
     out_cg = model.generate(
         input_ids=input_ids,
         max_length=max_length,
-        fused_ft_kernel=True,
+        fused_ft_kernel=fused_ft_kernel,
         cg=True,
         return_dict_in_generate=True,
         output_scores=True,
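Stacking a second parametrize decorator on test_gptj_generation runs it once per combination of fused_ft_kernel and model_name. A tiny self-contained illustration of that mechanic; the demo test below is hypothetical and not part of the suite:

    import pytest

    # Stacked parametrize decorators produce the cross product of their values,
    # so this demo test is collected twice: once with each flag value.
    @pytest.mark.parametrize("fused_ft_kernel", [False, True])
    @pytest.mark.parametrize("model_name", ["EleutherAI/gpt-j-6B"])
    def test_parametrization_demo(model_name, fused_ft_kernel):
        assert isinstance(fused_ft_kernel, bool)
        assert model_name == "EleutherAI/gpt-j-6B"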
@@ -303,7 +303,9 @@ def test_llama_generation(model_name, checkpoint_format):
 
     # Capture graph outside the timing loop
     batch_size, seqlen_og = input_ids.shape
-    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
+    model._decoding_cache = update_graph_cache(
+        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=True
+    )
     print("With CUDA graph")
     torch.cuda.synchronize()
     start = time.time()
@@ -408,7 +410,9 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
 
     # Capture graph outside the timing loop
     batch_size, seqlen_og = input_ids.shape
-    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
+    model._decoding_cache = update_graph_cache(
+        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=True
+    )
     print("With CUDA graph")
     out_cg = model.generate(
         input_ids=input_ids,
@@ -168,7 +168,9 @@ def test_opt_generation(model_name):
     if fused_ft_kernel:
         # Capture graph outside the timing loop
         batch_size, seqlen_og = input_ids.shape
-        model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
+        model._decoding_cache = update_graph_cache(
+            model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=True
+        )
         print("With CUDA graph")
         torch.cuda.synchronize()
         start = time.time()