flash-attention/flash_attn/modules/mlp.py

# Copyright (c) 2022, Tri Dao.
import torch
import torch.nn as nn
import torch.nn.functional as F

# The fused kernels come from the optional fused_dense_lib extension; fall back to None
# when it is not installed so that the plain Mlp module still works.
try:
    from flash_attn.ops.fused_dense import fused_dense_gelu_dense_function_td
    from flash_attn.ops.fused_dense import fused_dense_res_gelu_dense_function_td
except ImportError:
    fused_dense_gelu_dense_function_td = None
    fused_dense_res_gelu_dense_function_td = None


class Mlp(nn.Module):

    def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.gelu,
                 device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, **factory_kwargs)
        self.activation = activation
        self.fc2 = nn.Linear(hidden_features, out_features, **factory_kwargs)

    def forward(self, x):
        x = self.fc1(x)
        x = self.activation(x)
        x = self.fc2(x)
        return x
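

# Illustrative usage (not part of the original module): a minimal sketch of applying Mlp to a
# batch of token embeddings. The shapes and feature sizes below are arbitrary assumptions
# chosen for demonstration only.
def _example_mlp_usage():
    mlp = Mlp(in_features=512, hidden_features=2048)
    x = torch.randn(8, 128, 512)  # (batch, seqlen, in_features)
    out = mlp(x)                  # fc1 -> GELU -> fc2; leading dims are preserved
    assert out.shape == (8, 128, 512)  # out_features defaults to in_features
    return out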


class FusedDenseGeluDense(nn.Module):

    def __init__(self, in_features, hidden_features=None, out_features=None, bias=True,
                 checkpoint_lvl=0, heuristic=0, return_residual=False, device=None, dtype=None):
        """
        checkpoint_lvl (increasing lvl means slower but more memory saving):
            0: no recomputation in the bwd
            1: recompute gelu_out in the bwd
            2: recompute gelu_in and gelu_out in the bwd
        heuristic:
            -1: don't fuse gemm + gelu (separate kernel)
            0..4: use this heuristic for the algo selection in the fused gemm + gelu
            For CUDA >= 11.8, you'd want heuristic=0 for both fp16 and bf16 for best perf.
            For CUDA <= 11.7, you'd want heuristic=1 for fp16 and heuristic=-1 for bf16.
        return_residual: whether to return the input x along with the output. This is for
            performance reasons: for a post-norm architecture, returning the input allows us
            to fuse the backward of nn.Linear with the residual connection.
        """
        assert checkpoint_lvl in [0, 1, 2]
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        assert bias == True, "DenseGeluDense module without bias is currently not supported"
        assert (fused_dense_gelu_dense_function_td is not None
                and fused_dense_res_gelu_dense_function_td is not None), 'fused_dense_lib is not installed'
        self.checkpoint_lvl = checkpoint_lvl
        self.heuristic = heuristic
        self.return_residual = return_residual
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, **factory_kwargs)
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, **factory_kwargs)

    def forward(self, x):
        # The fused kernel only supports fp16/bf16 inputs on CUDA.
        assert x.dtype in [torch.float16, torch.bfloat16]
        assert x.is_cuda
        # Pick the variant that also returns the input when return_residual is set.
        fn = (fused_dense_gelu_dense_function_td if not self.return_residual
              else fused_dense_res_gelu_dense_function_td)
        return fn(x, self.fc1.weight, self.fc1.bias, self.fc2.weight, self.fc2.bias,
                  self.checkpoint_lvl, self.heuristic)
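

# Illustrative usage (not part of the original module): a minimal sketch of instantiating
# FusedDenseGeluDense. It requires a CUDA device and the fused_dense_lib extension; the
# checkpoint_lvl and heuristic values follow the docstring's guidance for CUDA >= 11.8,
# and the sizes are otherwise arbitrary assumptions for demonstration.
def _example_fused_dense_gelu_dense_usage():
    layer = FusedDenseGeluDense(
        in_features=1024, hidden_features=4096,
        checkpoint_lvl=1,   # recompute gelu_out in the backward to save memory
        heuristic=0,        # algo heuristic for the fused gemm + gelu (see docstring)
        device='cuda', dtype=torch.float16,
    )
    x = torch.randn(4, 256, 1024, device='cuda', dtype=torch.float16)
    return layer(x)  # shape (4, 256, 1024), since out_features defaults to in_features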