# flash-attention/flash_attn/ops/fused_dense.py
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/fused_dense/fused_dense.py
# We make it work with PyTorch AMP and with bfloat16.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import custom_bwd, custom_fwd

# import fused_dense_cuda  # from apex
import fused_dense_lib as fused_dense_cuda

# from src.ops.triton.triton_matmul import matmul_dgelu
from flash_attn.ops.gelu_activation import gelu_bwd
# from src.ops.gelu_activation import gelu_bwd, bias_gelu, bias_gelu_back


# implements fused GEMM+bias in forward pass using mlp_cuda from apex
class FusedDenseFuncTD(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(ctx, x, weight, bias):
        if torch.is_autocast_enabled():
            dtype = torch.get_autocast_gpu_dtype()
            x, weight, bias = [a.to(dtype=dtype) for a in [x, weight, bias]]
        x = x.contiguous()
        weight = weight.contiguous()
        bias = bias.contiguous()
        ctx.save_for_backward(x, weight)
        batch_shape, n = x.shape[:-1], x.shape[-1]
        batch_dim = batch_shape.numel()
        assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
        output = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight, bias)
        return output.reshape(*batch_shape, output.shape[-1])

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        grad_output = grad_output.contiguous()
        x, weight = ctx.saved_tensors
        batch_shape, n = x.shape[:-1], x.shape[-1]
        batch_dim = batch_shape.numel()
        if ctx.needs_input_grad[0]:
            grad_input, grad_weight, grad_bias = fused_dense_cuda.linear_bias_backward(
                x.reshape(batch_dim, n), weight,
                grad_output.reshape(batch_dim, grad_output.shape[-1])
            )
            grad_input = grad_input.reshape_as(x)
        else:
            grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad(
                x.reshape(batch_dim, n), grad_output.reshape(batch_dim, grad_output.shape[-1])
            )
            grad_input = None
        # print((grad_bias - grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)).abs().max())
        return grad_input, grad_weight, grad_bias
        # grad_input, grad_weight = None, None
        # grad_output_reshaped = grad_output.reshape(batch_dim, grad_output.shape[-1])
        # if ctx.needs_input_grad[0]:
        #     grad_input = (grad_output_reshaped @ weight.conj()).reshape(*batch_shape, n)
        # if ctx.needs_input_grad[1]:
        #     grad_weight = grad_output_reshaped.t() @ x.conj().reshape(batch_dim, n)
        # # We don't need to compute grad_bias explicitly, when we return grad_output PyTorch
        # # will sum over the batch dimension to get grad_bias.
        # return grad_input, grad_weight, grad_output


fused_dense_function_td = FusedDenseFuncTD.apply


class FusedDenseTD(nn.Linear):

    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 device=None, dtype=None) -> None:
        super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)

    def forward(self, x):
        if x.is_cuda and self.bias is not None:
            return fused_dense_function_td(x, self.weight, self.bias)
        else:
            return F.linear(x, self.weight, self.bias)
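

# Minimal usage sketch for FusedDenseTD (illustrative shapes, dtypes, and helper name only;
# requires a CUDA device and the fused_dense_lib extension). The layer is a drop-in for
# nn.Linear: the fused CUDA path is taken only for CUDA inputs with a bias, otherwise it
# falls back to F.linear.
def _example_fused_dense_td():
    layer = FusedDenseTD(512, 2048, device='cuda', dtype=torch.float16)
    x = torch.randn(4, 128, 512, device='cuda', dtype=torch.float16)
    with torch.cuda.amp.autocast():
        out = layer(x)  # dispatches to fused_dense_function_td (fused GEMM + bias)
    return out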


class FusedDenseResidualFunc(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(ctx, x, weight, bias):
        if torch.is_autocast_enabled():
            dtype = torch.get_autocast_gpu_dtype()
            x, weight, bias = [a.to(dtype=dtype) for a in [x, weight, bias]]
        x = x.contiguous()
        weight = weight.contiguous()
        bias = bias.contiguous()
        ctx.save_for_backward(x, weight)
        batch_shape, n = x.shape[:-1], x.shape[-1]
        batch_dim = batch_shape.numel()
        assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
        output = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight, bias)
        return output.reshape(*batch_shape, output.shape[-1]), x

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output, grad_input):
        grad_output = grad_output.contiguous()
        grad_input = grad_input.contiguous()
        x, weight = ctx.saved_tensors
        batch_shape, n = x.shape[:-1], x.shape[-1]
        batch_dim = batch_shape.numel()
        grad_input, grad_weight, grad_bias = fused_dense_cuda.linear_bias_residual_backward(
            x.reshape(batch_dim, n), weight,
            grad_output.reshape(batch_dim, grad_output.shape[-1]),
            grad_input.reshape(batch_dim, n)
        )
        return grad_input.reshape_as(x), grad_weight, grad_bias


fused_dense_residual_function = FusedDenseResidualFunc.apply


class FusedDenseResidual(nn.Linear):
    """Similar to FusedDense, but we return both the output and the input.
    This is so that in the backward pass, we can combine the input gradient from the residual
    branch with the input gradient from the matrix multiply, without having to do a separate
    addition.
    """

    def forward(self, x):
        if x.is_cuda and self.bias is not None:
            return fused_dense_residual_function(x, self.weight, self.bias)
        else:
            return F.linear(x, self.weight, self.bias), x
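

# Sketch of how FusedDenseResidual is meant to be used in a residual branch (example shapes,
# dtypes, and helper name only): the layer returns (out, x) so the residual add consumes the
# same tensor that the fused backward uses to combine the two input gradients.
def _example_fused_dense_residual():
    layer = FusedDenseResidual(512, 512, device='cuda', dtype=torch.float16)
    x = torch.randn(4, 128, 512, device='cuda', dtype=torch.float16, requires_grad=True)
    out, x_returned = layer(x)
    y = out + x_returned  # residual connection uses the tensor returned by the layer
    y.sum().backward()    # input gradients from both paths are combined in one fused backward
    return y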


class FusedDenseGeluDenseFuncTD(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(ctx, x, weight1, bias1, weight2, bias2, checkpoint_lvl=0, heuristic=0):
        """checkpoint_lvl:
        0: no recomputation in the bwd
        1: recompute gelu_out in the bwd
        2: recompute gelu_in and gelu_out in the bwd
        """
        assert -1 <= heuristic <= 4
        if torch.is_autocast_enabled():
            dtype = torch.get_autocast_gpu_dtype()
            x, weight1, bias1, weight2, bias2 = [a.to(dtype=dtype)
                                                 for a in [x, weight1, bias1, weight2, bias2]]
        assert checkpoint_lvl in [0, 1, 2]
        x = x.contiguous()
        weight1 = weight1.contiguous()
        bias1 = bias1.contiguous()
        weight2 = weight2.contiguous()
        bias2 = bias2.contiguous()
        batch_shape, n = x.shape[:-1], x.shape[-1]
        batch_dim = batch_shape.numel()
        assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
        # output1, output2, gelu_in = fused_dense_cuda.linear_gelu_linear_forward(
        #     x.reshape(batch_dim, n), weight1, bias1, weight2, bias2
        # )
        if heuristic == -1:
            gelu_in = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight1, bias1)
            output1 = F.gelu(gelu_in, approximate='tanh')
            # gelu_in = F.linear(x.reshape(batch_dim, n), weight1)  # This is before adding bias1
            # with torch.jit.fuser('fuser2'):
            #     output1 = bias_gelu(gelu_in, bias1)
        else:
            save_gelu_in = checkpoint_lvl != 2
            output1, *rest = fused_dense_cuda.linear_gelu_forward(x.reshape(batch_dim, n), weight1,
                                                                  bias1, save_gelu_in, heuristic)
            if save_gelu_in:
                gelu_in = rest[0]
        output2 = fused_dense_cuda.linear_bias_forward(output1, weight2, bias2)
        ctx.checkpoint_lvl = checkpoint_lvl
        ctx.heuristic = heuristic
        if checkpoint_lvl == 0:
            ctx.save_for_backward(x, weight1, bias1, weight2, gelu_in, output1)
        elif checkpoint_lvl == 1:
            ctx.save_for_backward(x, weight1, bias1, weight2, gelu_in)
        elif checkpoint_lvl == 2:
            ctx.save_for_backward(x, weight1, bias1, weight2)
        return output2.reshape(*batch_shape, output2.shape[-1])
    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        grad_output = grad_output.contiguous()
        checkpoint_lvl = ctx.checkpoint_lvl
        x, weight1, bias1, weight2, *rest = ctx.saved_tensors
        batch_shape, n = x.shape[:-1], x.shape[-1]
        batch_dim = batch_shape.numel()
        if checkpoint_lvl == 0:
            gelu_in, output1 = rest
        elif checkpoint_lvl == 1:
            gelu_in, = rest
            output1 = F.gelu(gelu_in, approximate='tanh')
        elif checkpoint_lvl == 2:
            # bias1, = rest
            if ctx.heuristic == -1:
                gelu_in = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight1,
                                                               bias1)
                output1 = F.gelu(gelu_in, approximate='tanh')
            else:
                output1, gelu_in = fused_dense_cuda.linear_gelu_forward(x.reshape(batch_dim, n),
                                                                        weight1, bias1, True,
                                                                        ctx.heuristic)
        if ctx.heuristic == -1:
            grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
            # grad_output1, grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_backward(output1, weight2, grad_output)
            grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output)
            # grad_gelu = matmul_dgelu(grad_output, weight2, gelu_in)
            grad_output1 = grad_output @ weight2
            with torch.jit.fuser('fuser2'):
                grad_gelu = gelu_bwd(grad_output1, gelu_in)
            grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward(
                x.reshape(batch_dim, n), weight1, grad_gelu
            )
            # with torch.jit.fuser('fuser2'):
            #     grad_gelu, grad_bias1 = bias_gelu_back(grad_output1, gelu_in, bias1)
            # grad_input = grad_gelu @ weight1
            # grad_weight1 = grad_gelu.reshape(batch_dim, -1).T @ x.reshape(batch_dim, n)
            # grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward(
            #     x.reshape(batch_dim, n), weight1, grad_gelu
            # )
        else:
            grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2 = fused_dense_cuda.linear_gelu_linear_backward(
                x.reshape(batch_dim, n), gelu_in, output1, weight1, weight2,
                grad_output.reshape(batch_dim, grad_output.shape[-1]),
                ctx.heuristic
            )
            # grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
            # # grad_output1, grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_backward(output1, weight2, grad_output)
            # grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output)
            # grad_gelu = matmul_dgelu(grad_output, weight2, gelu_in)
            # grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward(
            #     x.reshape(batch_dim, n), weight1, grad_gelu
            # )
        return grad_input.reshape_as(x), grad_weight1, grad_bias1, grad_weight2, grad_bias2, None, None


fused_dense_gelu_dense_function_td = FusedDenseGeluDenseFuncTD.apply
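

# Unfused reference for what fused_dense_gelu_dense_function_td computes, as a sketch for
# correctness checks (the helper name and the weight/bias arguments here stand in for the
# fc1/fc2 parameters of the module below):
def _reference_dense_gelu_dense(x, weight1, bias1, weight2, bias2):
    gelu_in = F.linear(x, weight1, bias1)
    output1 = F.gelu(gelu_in, approximate='tanh')  # tanh approximation, matching the fused path
    return F.linear(output1, weight2, bias2)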


class FusedDenseGeluDenseTD(nn.Module):

    def __init__(self, in_features, intermediate_features, out_features=None, bias=True,
                 checkpoint_lvl=0, heuristic=0, device=None, dtype=None):
        """
        checkpoint_lvl (increasing lvl means slower but more memory saving):
            0: no recomputation in the bwd
            1: recompute gelu_out in the bwd
            2: recompute gelu_in and gelu_out in the bwd
        heuristic:
            -1: don't fuse gemm + gelu (separate kernel)
            0..4: use this heuristic for the algo section in the fused gemm + gelu
        """
        assert checkpoint_lvl in [0, 1, 2]
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        if out_features is None:
            out_features = in_features
        assert bias == True, "DenseGeluDense module without bias is currently not supported"
        self.checkpoint_lvl = checkpoint_lvl
        self.heuristic = heuristic
        self.fc1 = nn.Linear(in_features, intermediate_features, bias=bias, **factory_kwargs)
        self.fc2 = nn.Linear(intermediate_features, out_features, bias=bias, **factory_kwargs)

    def forward(self, x):
        return fused_dense_gelu_dense_function_td(x, self.fc1.weight, self.fc1.bias,
                                                  self.fc2.weight, self.fc2.bias,
                                                  self.checkpoint_lvl, self.heuristic)
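

# Minimal usage sketch for FusedDenseGeluDenseTD (illustrative sizes and helper name only;
# requires a CUDA device and the fused_dense_lib extension). checkpoint_lvl trades
# recomputation for memory and heuristic selects the fused GEMM+gelu algorithm, as documented
# in __init__ above.
def _example_fused_dense_gelu_dense():
    mlp = FusedDenseGeluDenseTD(1024, 4096, checkpoint_lvl=1, heuristic=0,
                                device='cuda', dtype=torch.float16)
    x = torch.randn(2, 256, 1024, device='cuda', dtype=torch.float16)
    return mlp(x)  # fc2(gelu(fc1(x))), computed with the fused kernels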


class FusedDenseResGeluDenseFunc(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(ctx, x, weight1, bias1, weight2, bias2, checkpoint_lvl=0, heuristic=0):
        """checkpoint_lvl:
        0: no recomputation in the bwd
        1: recompute gelu_out in the bwd
        2: recompute gelu_in and gelu_out in the bwd
        """
        assert -1 <= heuristic <= 4
        if torch.is_autocast_enabled():
            dtype = torch.get_autocast_gpu_dtype()
            x, weight1, bias1, weight2, bias2 = [a.to(dtype=dtype)
                                                 for a in [x, weight1, bias1, weight2, bias2]]
        assert checkpoint_lvl in [0, 1, 2]
        x = x.contiguous()
        weight1 = weight1.contiguous()
        bias1 = bias1.contiguous()
        weight2 = weight2.contiguous()
        bias2 = bias2.contiguous()
        batch_shape, n = x.shape[:-1], x.shape[-1]
        batch_dim = batch_shape.numel()
        assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
        # output1, output2, gelu_in = fused_dense_cuda.linear_gelu_linear_forward(
        #     x.reshape(batch_dim, n), weight1, bias1, weight2, bias2
        # )
        # gelu_in = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight1, bias1)
        # output1 = F.gelu(gelu_in, approximate='tanh')
        save_gelu_in = checkpoint_lvl != 2
        output1, *rest = fused_dense_cuda.linear_gelu_forward(x.reshape(batch_dim, n), weight1,
                                                              bias1, save_gelu_in, heuristic)
        if save_gelu_in:
            gelu_in = rest[0]
        output2 = fused_dense_cuda.linear_bias_forward(output1, weight2, bias2)
        ctx.checkpoint_lvl = checkpoint_lvl
        ctx.heuristic = heuristic
        if checkpoint_lvl == 0:
            ctx.save_for_backward(x, weight1, weight2, gelu_in, output1)
        elif checkpoint_lvl == 1:
            ctx.save_for_backward(x, weight1, weight2, gelu_in)
        elif checkpoint_lvl == 2:
            ctx.save_for_backward(x, weight1, weight2, bias1)
        return output2.reshape(*batch_shape, output2.shape[-1]), x
    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output, grad_input):
        grad_output = grad_output.contiguous()
        grad_input = grad_input.contiguous()
        checkpoint_lvl = ctx.checkpoint_lvl
        x, weight1, weight2, *rest = ctx.saved_tensors
        batch_shape, n = x.shape[:-1], x.shape[-1]
        batch_dim = batch_shape.numel()
        if checkpoint_lvl == 0:
            gelu_in, output1 = rest
        elif checkpoint_lvl == 1:
            gelu_in, = rest
            output1 = F.gelu(gelu_in, approximate='tanh')
        elif checkpoint_lvl == 2:
            bias1, = rest
            output1, gelu_in = fused_dense_cuda.linear_gelu_forward(x.reshape(batch_dim, n),
                                                                    weight1, bias1, True,
                                                                    ctx.heuristic)
        grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2 = fused_dense_cuda.linear_residual_gelu_linear_backward(
            x.reshape(batch_dim, n), gelu_in, output1, weight1, weight2,
            grad_output.reshape(batch_dim, grad_output.shape[-1]),
            grad_input.reshape(batch_dim, n),
            ctx.heuristic
        )
        # grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
        # # grad_output1, grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_backward(output1, weight2, grad_output)
        # grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output)
        # grad_gelu = matmul_dgelu(grad_output, weight2, gelu_in)
        # grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_residual_backward(
        #     x.reshape(batch_dim, n), weight1, grad_gelu,
        #     grad_input.reshape(batch_dim, n)
        # )
        return grad_input.reshape_as(x), grad_weight1, grad_bias1, grad_weight2, grad_bias2, None, None


fused_dense_res_gelu_dense_function_td = FusedDenseResGeluDenseFunc.apply


class FusedDenseResGeluDense(FusedDenseGeluDenseTD):

    def forward(self, x):
        return fused_dense_res_gelu_dense_function_td(x, self.fc1.weight, self.fc1.bias,
                                                      self.fc2.weight, self.fc2.bias,
                                                      self.checkpoint_lvl, self.heuristic)
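

# Sketch of the residual-MLP pattern FusedDenseResGeluDense targets (example sizes and helper
# name only): like FusedDenseGeluDenseTD, but the input is returned alongside the output so the
# residual add and the fused backward share the same tensor.
def _example_fused_dense_res_gelu_dense():
    mlp = FusedDenseResGeluDense(1024, 4096, checkpoint_lvl=1, device='cuda', dtype=torch.float16)
    x = torch.randn(2, 256, 1024, device='cuda', dtype=torch.float16, requires_grad=True)
    out, x_returned = mlp(x)
    return out + x_returned  # residual connection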