From 8a733cbd538ef59aacfd60bc44a5262dbe8f5768 Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Sun, 10 Sep 2023 17:22:37 -0700
Subject: [PATCH] [Gen] Fix calling update_graph_cache in tests

---
 flash_attn/modules/mha.py       | 12 ++++--------
 flash_attn/ops/triton/rotary.py |  2 ++
 tests/models/test_baichuan.py   |  2 +-
 tests/models/test_falcon.py     |  8 ++++++--
 tests/models/test_gpt_neox.py   |  1 -
 tests/models/test_gptj.py       | 12 +++++++-----
 tests/models/test_llama.py      |  8 ++++++--
 tests/models/test_opt.py        |  4 +++-
 8 files changed, 29 insertions(+), 20 deletions(-)

diff --git a/flash_attn/modules/mha.py b/flash_attn/modules/mha.py
index 3d6b707..39b1254 100644
--- a/flash_attn/modules/mha.py
+++ b/flash_attn/modules/mha.py
@@ -659,8 +659,7 @@ class MHA(nn.Module):
                 qkv = rearrange(
                     self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                 ).contiguous()
-            # qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
-            qkv = qkv.reshape(batch, seqlen, 3, self.num_heads, self.head_dim)
+            qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
             if (
                 inference_params is None
                 or inference_params.sequence_len_offset == 0
@@ -700,10 +699,8 @@ class MHA(nn.Module):
                 qkv, x = self.Wqkv(x)
             q = qkv[..., : self.num_heads * self.head_dim]
             kv = qkv[..., self.num_heads * self.head_dim :]
-            # q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
-            q = q.reshape(batch, seqlen, -1, self.head_dim)
-            # kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
-            kv = kv.reshape(batch, seqlen, 2, -1, self.head_dim)
+            q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
+            kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
             if self.dwconv:
                 q = rearrange(
                     self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
@@ -731,8 +728,7 @@ class MHA(nn.Module):
                 context = self._update_kvcache_attention(q, kv, inference_params)
             else:
                 context = self._apply_rotary_single_query_attention(q, inference_params, kv=kv)
-        # out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
-        out = self.out_proj(context.reshape(batch, seqlen, -1))
+        out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
         return out if not self.return_residual else (out, x)
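
The mha.py hunks above restore the einops rearrange() calls in place of the equivalent .reshape(...) forms they had been swapped for. As a sanity check (an editorial sketch, not part of the patch; the tensor sizes below are made up), the restored pattern splits the packed QKV projection exactly the way the reshape did:

    # Standalone sketch: confirm rearrange and reshape agree on the packed-QKV split.
    import torch
    from einops import rearrange

    batch, seqlen, num_heads, head_dim = 2, 5, 4, 8
    qkv = torch.randn(batch, seqlen, 3 * num_heads * head_dim)

    a = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=head_dim)
    b = qkv.reshape(batch, seqlen, 3, num_heads, head_dim)
    assert a.shape == (batch, seqlen, 3, num_heads, head_dim)
    assert torch.equal(a, b)  # same values and layout, so the two forms are interchangeable
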
diff --git a/flash_attn/ops/triton/rotary.py b/flash_attn/ops/triton/rotary.py
index 0e9b566..8d2e09b 100644
--- a/flash_attn/ops/triton/rotary.py
+++ b/flash_attn/ops/triton/rotary.py
@@ -1,3 +1,5 @@
+# Copyright (c) 2023, Tri Dao.
+
 from typing import Optional, Union
 
 import torch
diff --git a/tests/models/test_baichuan.py b/tests/models/test_baichuan.py
index 3818f30..464e32e 100644
--- a/tests/models/test_baichuan.py
+++ b/tests/models/test_baichuan.py
@@ -404,7 +404,7 @@ def test_baichuan_parallel_generation(model_name, world_size):
     # Capture graph outside the timing loop
     batch_size, seqlen_og = input_ids.shape
     model._decoding_cache = update_graph_cache(
-        model, None, batch_size, seqlen_og, max_length
+        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=False
     )
     print("With CUDA graph")
     out_cg = model.generate(
diff --git a/tests/models/test_falcon.py b/tests/models/test_falcon.py
index ecb95fd..dfeb544 100644
--- a/tests/models/test_falcon.py
+++ b/tests/models/test_falcon.py
@@ -253,7 +253,9 @@ def test_falcon_generation(model_name):
 
     # Capture graph outside the timing loop
     batch_size, seqlen_og = input_ids.shape
-    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
+    model._decoding_cache = update_graph_cache(
+        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=True
+    )
     print("With CUDA graph")
     torch.cuda.synchronize()
     start = time.time()
@@ -356,7 +358,9 @@ def test_falcon_parallel_generation(model_name, world_size):
 
     # Capture graph outside the timing loop
     batch_size, seqlen_og = input_ids.shape
-    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
+    model._decoding_cache = update_graph_cache(
+        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=True
+    )
     print("With CUDA graph")
     out_cg = model.generate(
         input_ids=input_ids,
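
Several of the generation tests above time the CUDA-graph path the same way: the graph is captured once, outside the timed region, and the timed generate() call is bracketed by torch.cuda.synchronize() so that asynchronous GPU work is actually included in the wall-clock number. A minimal sketch of that timing pattern, assuming only PyTorch (the callable passed in is a stand-in, not something defined in this patch):

    # Standalone sketch of the synchronize-then-time pattern used in the tests above.
    import time
    import torch

    def time_on_gpu(fn):
        torch.cuda.synchronize()  # make sure prior GPU work has finished
        start = time.time()
        result = fn()
        torch.cuda.synchronize()  # wait for the timed GPU work to finish
        return result, time.time() - start

    # e.g. out, seconds = time_on_gpu(
    #     lambda: model.generate(input_ids=input_ids, max_length=max_length, cg=True)
    # )
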
diff --git a/tests/models/test_gpt_neox.py b/tests/models/test_gpt_neox.py
index 65a937b..f4e27da 100644
--- a/tests/models/test_gpt_neox.py
+++ b/tests/models/test_gpt_neox.py
@@ -6,7 +6,6 @@ import pytest
 import torch
 from flash_attn.models.gpt import GPTLMHeadModel
 from flash_attn.models.gpt_neox import gpt_neox_config_to_gpt2_config, remap_state_dict_hf_gpt_neox
-from flash_attn.utils.generation import update_graph_cache
 from flash_attn.utils.pretrained import state_dict_from_pretrained
 from transformers import AutoTokenizer, GPTNeoXConfig
 from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM
diff --git a/tests/models/test_gptj.py b/tests/models/test_gptj.py
index 8abb3b9..d31aea4 100644
--- a/tests/models/test_gptj.py
+++ b/tests/models/test_gptj.py
@@ -83,8 +83,9 @@ def test_gptj_optimized(model_name):
     ).abs().max().item()
 
 
+@pytest.mark.parametrize("fused_ft_kernel", [False, True])
 @pytest.mark.parametrize("model_name", ["EleutherAI/gpt-j-6B"])
-def test_gptj_generation(model_name):
+def test_gptj_generation(model_name, fused_ft_kernel):
     """Check that our implementation of GPT-J (with all optimizations enabled) matches the
     HF implementation: the output of our forward pass in fp16 should be around the same as the HF
     forward pass in fp16, when compared to the HF forward pass in fp32.
@@ -140,8 +141,7 @@ def test_gptj_generation(model_name):
         input_ids=input_ids,
         max_length=max_length,
         eos_token_id=eos_token_id,
-        fused_ft_kernel=True,
-        # eos_token_id=eos_token_id, fused_ft_kernel=False,
+        fused_ft_kernel=fused_ft_kernel,
         return_dict_in_generate=True,
         output_scores=True,
         enable_timing=True,
@@ -152,14 +152,16 @@
 
     # Capture graph outside the timing loop
     batch_size, seqlen_og = input_ids.shape
-    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
+    model._decoding_cache = update_graph_cache(
+        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=fused_ft_kernel
+    )
     print("With CUDA graph")
     torch.cuda.synchronize()
     start = time.time()
     out_cg = model.generate(
         input_ids=input_ids,
         max_length=max_length,
-        fused_ft_kernel=True,
+        fused_ft_kernel=fused_ft_kernel,
         cg=True,
         return_dict_in_generate=True,
         output_scores=True,
diff --git a/tests/models/test_llama.py b/tests/models/test_llama.py
index 3b162ba..e4a5674 100644
--- a/tests/models/test_llama.py
+++ b/tests/models/test_llama.py
@@ -303,7 +303,9 @@ def test_llama_generation(model_name, checkpoint_format):
 
     # Capture graph outside the timing loop
    batch_size, seqlen_og = input_ids.shape
-    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
+    model._decoding_cache = update_graph_cache(
+        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=True
+    )
     print("With CUDA graph")
     torch.cuda.synchronize()
     start = time.time()
@@ -408,7 +410,9 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
 
     # Capture graph outside the timing loop
     batch_size, seqlen_og = input_ids.shape
-    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
+    model._decoding_cache = update_graph_cache(
+        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=True
+    )
     print("With CUDA graph")
     out_cg = model.generate(
         input_ids=input_ids,
diff --git a/tests/models/test_opt.py b/tests/models/test_opt.py
index 535b76c..6378b9c 100644
--- a/tests/models/test_opt.py
+++ b/tests/models/test_opt.py
@@ -168,7 +168,9 @@ def test_opt_generation(model_name):
     if fused_ft_kernel:
         # Capture graph outside the timing loop
         batch_size, seqlen_og = input_ids.shape
-        model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
+        model._decoding_cache = update_graph_cache(
+            model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=True
+        )
         print("With CUDA graph")
         torch.cuda.synchronize()
         start = time.time()
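
Across the test files, the fix converges on one calling pattern: update_graph_cache is now given an explicit fused_ft_kernel argument when the CUDA graph is captured, and the same flag is used for the subsequent generate(..., cg=True) call. A condensed usage sketch of that pattern (assuming a flash_attn GPTLMHeadModel-style model plus input_ids and max_length already set up as in the tests; an illustration, not code added by the patch):

    # Condensed from the test hunks above; `model`, `input_ids`, and `max_length`
    # are assumed to exist already, as they do in the tests.
    from flash_attn.utils.generation import update_graph_cache

    fused_ft_kernel = True  # keep this consistent with the flag passed to generate() below
    batch_size, seqlen_og = input_ids.shape
    # Capture the CUDA graph once, outside any timing loop.
    model._decoding_cache = update_graph_cache(
        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=fused_ft_kernel
    )
    out_cg = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        fused_ft_kernel=fused_ft_kernel,
        cg=True,  # reuse the captured decoding graph
        return_dict_in_generate=True,
        output_scores=True,
    )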