From b3177dfaf696ee522a495bcb48b88d32167aa17f Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Fri, 21 Jul 2023 17:29:10 -0700
Subject: [PATCH] [GPT] Enable FlashAttention for GPT-J

---
 flash_attn/modules/block.py | 3 +++
 tests/models/test_gptj.py   | 4 ++--
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/flash_attn/modules/block.py b/flash_attn/modules/block.py
index a4ff5a2..e19742d 100644
--- a/flash_attn/modules/block.py
+++ b/flash_attn/modules/block.py
@@ -276,6 +276,9 @@ class ParallelBlock(nn.Module):
             for p in self.norm2.parameters():
                 p._shared_params = True
 
+    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
+        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
+
     def forward(self, hidden_states1: Tensor, hidden_states2: Optional[Tensor] = None,
                 residual: Optional[Tensor] = None, mixer_kwargs=None):
         r"""Pass the input through the encoder layer.
diff --git a/tests/models/test_gptj.py b/tests/models/test_gptj.py
index b27dde0..6f0c210 100644
--- a/tests/models/test_gptj.py
+++ b/tests/models/test_gptj.py
@@ -36,7 +36,7 @@ def test_gptj_optimized(model_name):
     dtype = torch.float16
     device = 'cuda'
     config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))
-    config.use_flash_attn = False  # FlashAttention doesn't support hdim 256 yet
+    config.use_flash_attn = True  # FlashAttention-2 supports headdim 256
     config.fused_bias_fc = True
     config.fused_mlp = True
     config.fused_dropout_add_ln = True
@@ -93,7 +93,7 @@ def test_gptj_generation(model_name):
     dtype = torch.float16
     device = 'cuda'
     config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))
-    config.use_flash_attn = False  # FlashAttention doesn't support hdim 256 yet
+    config.use_flash_attn = True  # FlashAttention-2 supports headdim 256
     config.fused_bias_fc = True
     config.fused_mlp = True
     config.fused_dropout_add_ln = True