[GPT] Enable FlashAttention for GPT-J
This commit is contained in:
parent
6fc1e07da2
commit
b3177dfaf6
@ -276,6 +276,9 @@ class ParallelBlock(nn.Module):
|
||||
for p in self.norm2.parameters():
|
||||
p._shared_params = True
|
||||
|
||||
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
||||
return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
|
||||
|
||||
def forward(self, hidden_states1: Tensor, hidden_states2: Optional[Tensor] = None,
|
||||
residual: Optional[Tensor] = None, mixer_kwargs=None):
|
||||
r"""Pass the input through the encoder layer.
|
||||
|
||||
@ -36,7 +36,7 @@ def test_gptj_optimized(model_name):
|
||||
dtype = torch.float16
|
||||
device = 'cuda'
|
||||
config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))
|
||||
config.use_flash_attn = False # FlashAttention doesn't support hdim 256 yet
|
||||
config.use_flash_attn = True # FlashAttention-2 supports headdim 256
|
||||
config.fused_bias_fc = True
|
||||
config.fused_mlp = True
|
||||
config.fused_dropout_add_ln = True
|
||||
@ -93,7 +93,7 @@ def test_gptj_generation(model_name):
|
||||
dtype = torch.float16
|
||||
device = 'cuda'
|
||||
config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))
|
||||
config.use_flash_attn = False # FlashAttention doesn't support hdim 256 yet
|
||||
config.use_flash_attn = True # FlashAttention-2 supports headdim 256
|
||||
config.fused_bias_fc = True
|
||||
config.fused_mlp = True
|
||||
config.fused_dropout_add_ln = True
|
||||
|
||||
Loading…
Reference in New Issue
Block a user