Merge branch 'HazyResearch:main' into enable_cuda_graph_capture

commit 081c2b012a
Kirthi Shankar Sivamani, 2023-04-13 19:36:45 -07:00 (committed by GitHub)
5 changed files with 46 additions and 34 deletions

View File

@@ -93,7 +93,7 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
    inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
    fused_mlp = getattr(config, 'fused_mlp', False)
    if fused_mlp:
        assert config.activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx', 'relu']
        assert config.activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx', 'relu', 'sqrelu']
    fused_dense_sqrelu_dense = getattr(config, 'fused_dense_sqrelu_dense', False)
    if fused_dense_sqrelu_dense:
        assert config.activation_function == 'sqrelu', ('fused_dense_sqrelu_dense only '
@@ -123,7 +123,7 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
        if FusedMLP is None:
            raise ImportError('fused_dense is not installed')
        activation = ('gelu_approx' if config.activation_function
                      in ['gelu_new', 'gelu_fast', 'gelu_approx'] else 'relu')
                      in ['gelu_new', 'gelu_fast', 'gelu_approx'] else config.activation_function)
        mlp_cls = FusedMLP if process_group is None else ParallelFusedMLP
        parallel_kwargs = ({'process_group': process_group,
                            'sequence_parallel': getattr(config, 'sequence_parallel', True)}
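
With this change, an activation_function of 'sqrelu' is no longer coerced to 'relu' but passed through to FusedMLP verbatim. Squared ReLU, used in the Primer paper, is simply the elementwise square of ReLU. A minimal reference sketch of what the new option computes; the SquaredReLU module below is hypothetical and for illustration only (the real code path uses the jitted sqrelu_fwd/sqrelu_bwd from flash_attn.ops.activations):

    import torch
    import torch.nn.functional as F

    class SquaredReLU(torch.nn.Module):
        """Reference (unfused) squared ReLU: x -> relu(x)**2."""
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            r = F.relu(x)
            return r * r

    x = torch.randn(4, 8)
    assert torch.allclose(SquaredReLU()(x), F.relu(x) ** 2)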

View File

@@ -2,7 +2,8 @@
import math
import torch
from torch import nn
import torch.nn as nn
import torch.nn.functional as F

# 1/sqrt(2*pi) -> 0.3989423
@@ -80,3 +81,19 @@ class FastGeLUFunction(torch.autograd.Function):
        return tmp


fast_gelu_impl = FastGeLUFunction.apply


@torch.jit.script
def relu_bwd(g, x):
    return torch.where(x >= 0, g, 0.0).to(dtype=x.dtype)


@torch.jit.script
def sqrelu_fwd(x):
    r = F.relu(x)
    return (r * r).to(dtype=x.dtype)


@torch.jit.script
def sqrelu_bwd(g, x):
    return (2.0 * g * F.relu(x)).to(dtype=x.dtype)
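
A quick autograd cross-check of the new jitted backward: since d/dx relu(x)^2 = 2 * relu(x), sqrelu_bwd(g, x) should match what autograd produces for relu(x)**2 (a sketch; CPU is fine). Note also that relu_bwd passes the gradient through at x == 0 (it tests x >= 0) while autograd's ReLU backward does not (x > 0); the two conventions differ only on that measure-zero set.

    import torch
    import torch.nn.functional as F

    x = torch.randn(16, requires_grad=True)
    g = torch.randn(16)                    # incoming gradient

    (F.relu(x) ** 2).backward(g)           # autograd reference
    manual = 2.0 * g * F.relu(x.detach())  # what sqrelu_bwd(g, x) computes
    assert torch.allclose(x.grad, manual)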

View File

@@ -15,16 +15,11 @@ from torch.cuda.amp import custom_bwd, custom_fwd
# import fused_dense_cuda  # from apex
import fused_dense_lib as fused_dense_cuda
from flash_attn.ops.gelu_activation import gelu_bwd
from flash_attn.ops.activations import gelu_bwd, relu_bwd, sqrelu_fwd, sqrelu_bwd
from flash_attn.utils.distributed import all_gather_raw, reduce_scatter_raw, all_reduce_raw
from flash_attn.utils.distributed import reduce_scatter, all_reduce


@torch.jit.script
def relu_bwd(g, x):
    return torch.where(x >= 0, g, 0.0).to(dtype=x.dtype)


class FusedDenseFunc(torch.autograd.Function):
@staticmethod
@@ -209,7 +204,9 @@ class FusedMLPFunc(torch.autograd.Function):
        2: recompute pre_act and gelu_out / relu_out in the bwd
        """
        assert -1 <= heuristic <= 4
        assert activation in ['gelu_approx', 'relu']
        assert activation in ['gelu_approx', 'relu', 'sqrelu']
        if activation == 'sqrelu':
            assert heuristic == -1
        if not save_pre_act:
            checkpoint_lvl = 2
        assert checkpoint_lvl in [0, 1, 2]
@@ -248,7 +245,8 @@ class FusedMLPFunc(torch.autograd.Function):
        if heuristic == -1:
            pre_act = F.linear(total_x, weight1, bias1)
            activation_fn = (partial(F.gelu, approximate='tanh') if activation == 'gelu_approx'
                             else F.relu)
                             else (sqrelu_fwd if activation == 'sqrelu' else F.relu))
            with torch.jit.fuser('fuser2'):
                output1 = activation_fn(pre_act)
            # This is before adding bias1
            # pre_act = F.linear(total_x.reshape(batch_dim, n), weight1)
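
The heuristic == -1 branch is the only one that supports 'sqrelu': it runs a plain GEMM and then the TorchScript activation under the nvFuser backend, whereas the other heuristics go through the cublasLt epilogue, which (per the comment in the backward below) only fuses GELU/ReLU. A standalone sketch of this dispatch, assuming a CUDA device:

    import torch
    import torch.nn.functional as F

    @torch.jit.script
    def sqrelu_fwd(x):
        r = F.relu(x)
        return (r * r).to(dtype=x.dtype)

    x = torch.randn(8, 64, device='cuda', dtype=torch.float16)
    w1 = torch.randn(128, 64, device='cuda', dtype=torch.float16)

    pre_act = F.linear(x, w1)        # unfused GEMM
    with torch.jit.fuser('fuser2'):  # route the scripted op to nvFuser
        output1 = sqrelu_fwd(pre_act)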
@@ -279,7 +277,7 @@ class FusedMLPFunc(torch.autograd.Function):
        checkpoint_lvl = ctx.checkpoint_lvl
        activation = ctx.activation
        activation_fn = (partial(F.gelu, approximate='tanh') if activation == 'gelu_approx'
                         else F.relu)
                         else (sqrelu_fwd if activation == 'sqrelu' else F.relu))
        if ctx.return_residual:
            grad_input, = args
            grad_input = grad_input.contiguous()
@@ -297,6 +295,7 @@ class FusedMLPFunc(torch.autograd.Function):
            pre_act, output1 = rest
        elif checkpoint_lvl == 1:
            pre_act, = rest
            with torch.jit.fuser('fuser2'):
                output1 = activation_fn(pre_act)
        elif checkpoint_lvl == 2:
            bias1, = rest
@@ -304,6 +303,7 @@ class FusedMLPFunc(torch.autograd.Function):
                total_x, _ = all_gather_raw(x, process_group)
            if ctx.heuristic == -1:
                pre_act = F.linear(total_x, weight1, bias1)
                with torch.jit.fuser('fuser2'):
                    output1 = activation_fn(pre_act)
            else:
                output1, pre_act = fused_dense_cuda.linear_act_forward(
@@ -324,8 +324,9 @@ class FusedMLPFunc(torch.autograd.Function):
        if ctx.heuristic == -1:
            # grad_pre_act = matmul_dgelu(grad_output, weight2, pre_act)
            grad_output1 = F.linear(grad_output, weight2.t())
            activation_grad_fn = (gelu_bwd if activation == 'gelu_approx'
                                  else (sqrelu_bwd if activation == 'sqrelu' else relu_bwd))
            with torch.jit.fuser('fuser2'):
                activation_grad_fn = gelu_bwd if activation == 'gelu_approx' else relu_bwd
                grad_pre_act = activation_grad_fn(grad_output1, pre_act)
        else:
            # The cublasLt epilogue has to compute both gelu/relu grad and bias grad, we can't
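
For reference, gelu_bwd (imported from flash_attn.ops.activations above) is the closed-form derivative of the tanh-approximate GELU, with 0.79788456 ≈ sqrt(2/π) and 0.1070322243 = 3 · 0.044715 · 0.79788456. A sketch cross-checking that closed form against autograd:

    import torch
    import torch.nn.functional as F

    x = torch.randn(32, requires_grad=True)
    g = torch.randn(32)

    F.gelu(x, approximate='tanh').backward(g)  # autograd reference

    xd = x.detach()
    t = torch.tanh(0.79788456 * xd * (1 + 0.044715 * xd * xd))
    ff = 0.5 * xd * ((1 - t * t) * (0.79788456 + 0.1070322243 * xd * xd)) + 0.5 * (1 + t)
    assert torch.allclose(x.grad, g * ff, atol=1e-4)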
@@ -380,7 +381,7 @@ def fused_mlp_func(
    process_group: Optional[ProcessGroup] = None,
    sequence_parallel: bool = True
):
    assert activation in ['gelu_approx', 'relu']
    assert activation in ['gelu_approx', 'relu', 'sqrelu']
    dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
                      or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
    # If we save pre-activation, dimension must be divisible by 128 (relu) or 8 (gelu)
@@ -428,7 +429,7 @@ class FusedMLP(nn.Module):
        to fuse the backward of nn.Linear with the residual connection.
        """
        assert checkpoint_lvl in [0, 1, 2]
        assert activation in ['gelu_approx', 'relu']
        assert activation in ['gelu_approx', 'relu', 'sqrelu']
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        if out_features is None:
@@ -436,7 +437,7 @@ class FusedMLP(nn.Module):
        self.activation = activation
        self.return_residual = return_residual
        self.checkpoint_lvl = checkpoint_lvl
        self.heuristic = heuristic
        self.heuristic = heuristic if activation != 'sqrelu' else -1
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
@@ -489,7 +490,7 @@ class ParallelFusedMLP(nn.Module):
        For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16.
        """
        assert checkpoint_lvl in [0, 1, 2]
        assert activation in ['gelu_approx', 'relu']
        assert activation in ['gelu_approx', 'relu', 'sqrelu']
        assert process_group is not None
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
@@ -499,7 +500,7 @@ class ParallelFusedMLP(nn.Module):
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel
        self.checkpoint_lvl = checkpoint_lvl
        self.heuristic = heuristic
        self.heuristic = heuristic if activation != 'sqrelu' else -1
        self.fc1 = ColumnParallelLinear(in_features, hidden_features, process_group,
                                        bias=bias1, **factory_kwargs)
        self.fc2 = RowParallelLinear(hidden_features, out_features, process_group,
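
Both constructors now pin heuristic to -1 when activation == 'sqrelu', regardless of the value passed in. A hedged usage sketch (assumes a CUDA build with fused_dense_lib installed; any constructor arguments beyond those visible in this diff are left at the library's defaults):

    import torch
    from flash_attn.ops.fused_dense import FusedMLP

    mlp = FusedMLP(1024, 4096, activation='sqrelu', heuristic=1,
                   device='cuda', dtype=torch.float16)
    assert mlp.heuristic == -1  # 'sqrelu' forces the jit/nvFuser path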

View File

@@ -8,17 +8,7 @@ from torch.cuda.amp import custom_bwd, custom_fwd
import fused_dense_lib as fused_dense_cuda
from flash_attn.ops.triton.linear import triton_linear_act, triton_dgrad_act


@torch.jit.script
def sqrelu_fwd(x):
    r = F.relu(x)
    return (r * r).to(dtype=x.dtype)


@torch.jit.script
def sqrelu_bwd(g, x):
    return (2.0 * g * F.relu(x)).to(dtype=x.dtype)

from flash_attn.ops.activations import sqrelu_fwd, sqrelu_bwd


class FusedDenseSqreluDenseFunc(torch.autograd.Function):

View File

@@ -107,10 +107,12 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
                                       fused_ft_kernel=fused_ft_kernel)
    scores = []
    with torch.inference_mode():
        logits = model(input_ids, inference_params=inference_params).logits[:, -1]
        if timing:
            if tensor_parallel > 1:
                torch.distributed.barrier()
            torch.cuda.synchronize()
            start = time.time()
        logits = model(input_ids, inference_params=inference_params).logits[:, -1]
        if vocab_size is not None:
            logits = logits[..., :vocab_size]
        scores.append(logits if not cg else logits.clone())
@@ -143,8 +145,10 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
            if inference_params.sequence_len_offset >= max_length - 1:
                break
        if timing:
            if tensor_parallel > 1:
                torch.distributed.barrier()
            torch.cuda.synchronize()
            print(f'Decoding time: {(time.time() - start) * 1000:.0f}ms')
            print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
    output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
    return output_cls(
        sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1),
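
The timing change starts the clock before the first forward pass (prompt processing) rather than after it, and synchronizes across tensor-parallel ranks on both sides, so the printed number now covers prompt processing plus decoding and is consistent across ranks. The same pattern as a standalone sketch; timed is a hypothetical helper, not part of the library:

    import time
    import torch

    def timed(fn, tensor_parallel: int = 1):
        # Sync (and barrier across ranks) before starting and before reading
        # the clock, so host-side time.time() brackets all queued GPU work.
        if tensor_parallel > 1:
            torch.distributed.barrier()
        torch.cuda.synchronize()
        start = time.time()
        out = fn()
        if tensor_parallel > 1:
            torch.distributed.barrier()
        torch.cuda.synchronize()
        print(f'Elapsed: {(time.time() - start) * 1000:.0f}ms')
        return out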