[GPT] Enable FlashAttention for GPT-J
parent 6fc1e07da2
commit b3177dfaf6
@@ -276,6 +276,9 @@ class ParallelBlock(nn.Module):
             for p in self.norm2.parameters():
                 p._shared_params = True
 
+    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
+        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
+
     def forward(self, hidden_states1: Tensor, hidden_states2: Optional[Tensor] = None,
                 residual: Optional[Tensor] = None, mixer_kwargs=None):
         r"""Pass the input through the encoder layer.
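The hunk above adds an allocate_inference_cache method to ParallelBlock that simply forwards to the block's mixer (the attention module), which owns the KV cache used during generation. Below is a minimal sketch of that delegation pattern, using hypothetical Mixer/Block stand-ins rather than the real flash-attn classes:

import torch
import torch.nn as nn

class Mixer(nn.Module):
    """Stand-in attention module that owns the decoding KV cache."""
    def __init__(self, num_heads=16, head_dim=256):
        super().__init__()
        self.num_heads, self.head_dim = num_heads, head_dim

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        dtype = dtype or torch.float16
        # One (K, V) buffer per layer: (batch, seqlen, 2, heads, head_dim)
        return torch.empty(batch_size, max_seqlen, 2, self.num_heads, self.head_dim, dtype=dtype)

class Block(nn.Module):
    """Stand-in for ParallelBlock: keeps no cache of its own, defers to the mixer."""
    def __init__(self):
        super().__init__()
        self.mixer = Mixer()

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

print(Block().allocate_inference_cache(batch_size=2, max_seqlen=128).shape)
# torch.Size([2, 128, 2, 16, 256])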
@@ -36,7 +36,7 @@ def test_gptj_optimized(model_name):
     dtype = torch.float16
     device = 'cuda'
     config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))
-    config.use_flash_attn = False  # FlashAttention doesn't support hdim 256 yet
+    config.use_flash_attn = True  # FlashAttention-2 supports headdim 256
     config.fused_bias_fc = True
     config.fused_mlp = True
     config.fused_dropout_add_ln = True
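Flipping use_flash_attn to True is what actually routes GPT-J's 256-dim heads through FlashAttention-2 in the optimized-forward test. A rough sketch of how such a config is assembled and turned into a model (the model name and the exact from_pretrained call are assumptions about the surrounding test code, not part of this diff):

import torch
from transformers import GPTJConfig
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.gptj import gptj_config_to_gpt2_config

model_name = 'EleutherAI/gpt-j-6B'  # assumed test parameter
config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))
config.use_flash_attn = True  # headdim 256 is supported by FlashAttention-2
config.fused_bias_fc = True
config.fused_mlp = True
config.fused_dropout_add_ln = True

model = GPTLMHeadModel.from_pretrained(model_name, config, device='cuda', dtype=torch.float16)
model.eval()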
@@ -93,7 +93,7 @@ def test_gptj_generation(model_name):
     dtype = torch.float16
     device = 'cuda'
     config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))
-    config.use_flash_attn = False  # FlashAttention doesn't support hdim 256 yet
+    config.use_flash_attn = True  # FlashAttention-2 supports headdim 256
     config.fused_bias_fc = True
     config.fused_mlp = True
     config.fused_dropout_add_ln = True
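In test_gptj_generation the same flag change exercises FlashAttention-2 together with the KV-cache path that the new allocate_inference_cache method feeds. A hedged sketch of a greedy decoding call built on the model from the previous sketch (the prompt, tokenizer choice, and generate arguments are assumptions, not the test's exact code):

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-j-6B')  # assumed
input_ids = tokenizer('Hello, my dog is cute and', return_tensors='pt').input_ids.to('cuda')

with torch.inference_mode():
    # flash-attn's GenerationMixin allocates the per-layer inference cache
    # via allocate_inference_cache() before decoding token by token.
    out = model.generate(input_ids=input_ids, max_length=input_ids.shape[1] + 20,
                         return_dict_in_generate=True, output_scores=True)
print(tokenizer.decode(out.sequences[0]))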