[GPT] Enable FlashAttention for GPT-J

Tri Dao 2023-07-21 17:29:10 -07:00
parent 6fc1e07da2
commit b3177dfaf6
2 changed files with 5 additions and 2 deletions


@@ -276,6 +276,9 @@ class ParallelBlock(nn.Module):
                 for p in self.norm2.parameters():
                     p._shared_params = True

+    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
+        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
+
     def forward(self, hidden_states1: Tensor, hidden_states2: Optional[Tensor] = None,
                 residual: Optional[Tensor] = None, mixer_kwargs=None):
         r"""Pass the input through the encoder layer.

@@ -36,7 +36,7 @@ def test_gptj_optimized(model_name):
     dtype = torch.float16
     device = 'cuda'
     config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))
-    config.use_flash_attn = False  # FlashAttention doesn't support hdim 256 yet
+    config.use_flash_attn = True  # FlashAttention-2 supports headdim 256
     config.fused_bias_fc = True
     config.fused_mlp = True
     config.fused_dropout_add_ln = True
@@ -93,7 +93,7 @@ def test_gptj_generation(model_name):
     dtype = torch.float16
     device = 'cuda'
     config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))
-    config.use_flash_attn = False  # FlashAttention doesn't support hdim 256 yet
+    config.use_flash_attn = True  # FlashAttention-2 supports headdim 256
     config.fused_bias_fc = True
     config.fused_mlp = True
     config.fused_dropout_add_ln = True
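
Taken together, the test change amounts to the following usage pattern. This is a minimal sketch mirroring the updated test configuration; the checkpoint name and the GPTLMHeadModel.from_pretrained call are assumptions based on the test-style API, not part of this diff:

import torch
from transformers import GPTJConfig
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.gptj import gptj_config_to_gpt2_config

model_name = "EleutherAI/gpt-j-6B"  # hypothetical checkpoint choice
config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))
config.use_flash_attn = True  # FlashAttention-2 handles head dim 256
config.fused_bias_fc = True
config.fused_mlp = True
config.fused_dropout_add_ln = True

# Assumed signature: from_pretrained(model_name, config, device=..., dtype=...)
model = GPTLMHeadModel.from_pretrained(model_name, config, device="cuda", dtype=torch.float16)
model.eval()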