diff --git a/flash_attn/modules/block.py b/flash_attn/modules/block.py
index a4ff5a2..e19742d 100644
--- a/flash_attn/modules/block.py
+++ b/flash_attn/modules/block.py
@@ -276,6 +276,9 @@ class ParallelBlock(nn.Module):
             for p in self.norm2.parameters():
                 p._shared_params = True
 
+    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
+        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
+
     def forward(self, hidden_states1: Tensor, hidden_states2: Optional[Tensor] = None,
                 residual: Optional[Tensor] = None, mixer_kwargs=None):
         r"""Pass the input through the encoder layer.
diff --git a/tests/models/test_gptj.py b/tests/models/test_gptj.py
index b27dde0..6f0c210 100644
--- a/tests/models/test_gptj.py
+++ b/tests/models/test_gptj.py
@@ -36,7 +36,7 @@ def test_gptj_optimized(model_name):
     dtype = torch.float16
     device = 'cuda'
     config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))
-    config.use_flash_attn = False  # FlashAttention doesn't support hdim 256 yet
+    config.use_flash_attn = True  # FlashAttention-2 supports headdim 256
     config.fused_bias_fc = True
     config.fused_mlp = True
     config.fused_dropout_add_ln = True
@@ -93,7 +93,7 @@ def test_gptj_generation(model_name):
     dtype = torch.float16
     device = 'cuda'
    config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))
-    config.use_flash_attn = False  # FlashAttention doesn't support hdim 256 yet
+    config.use_flash_attn = True  # FlashAttention-2 supports headdim 256
     config.fused_bias_fc = True
     config.fused_mlp = True
     config.fused_dropout_add_ln = True
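
For context, the new `ParallelBlock.allocate_inference_cache` just delegates to the block's mixer, so a model built from parallel blocks can preallocate a per-layer KV cache before decoding (which is what the GPT-J generation test exercises). The sketch below only illustrates that delegation pattern; `ToyMixer`, `ToyParallelBlock`, and `build_inference_cache` are hypothetical names and not part of the flash-attn generation API.

```python
import torch
import torch.nn as nn


class ToyMixer(nn.Module):
    """Hypothetical attention mixer that owns the KV-cache layout."""

    def __init__(self, num_heads=16, head_dim=256):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        # One (key, value) buffer per layer: (batch, seqlen, 2, heads, head_dim).
        dtype = torch.float16 if dtype is None else dtype
        shape = (batch_size, max_seqlen, 2, self.num_heads, self.head_dim)
        return torch.empty(shape, dtype=dtype)


class ToyParallelBlock(nn.Module):
    """Block that, like ParallelBlock in the diff, forwards the request to its mixer."""

    def __init__(self):
        super().__init__()
        self.mixer = ToyMixer()

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)


def build_inference_cache(blocks, batch_size, max_seqlen, dtype=None):
    # Hypothetical model-level helper: collect one cache entry per layer.
    return {
        i: block.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype)
        for i, block in enumerate(blocks)
    }


if __name__ == "__main__":
    layers = nn.ModuleList([ToyParallelBlock() for _ in range(2)])
    cache = build_inference_cache(layers, batch_size=4, max_seqlen=2048, dtype=torch.float16)
    print({k: tuple(v.shape) for k, v in cache.items()})
```

Keeping the allocation logic in the mixer means each attention implementation can choose its own cache layout, while blocks and the model only forward the request.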