Merge branch 'HazyResearch:main' into enable_cuda_graph_capture
commit 081c2b012a
@@ -93,7 +93,7 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
     inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
     fused_mlp = getattr(config, 'fused_mlp', False)
     if fused_mlp:
-        assert config.activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx', 'relu']
+        assert config.activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx', 'relu', 'sqrelu']
     fused_dense_sqrelu_dense = getattr(config, 'fused_dense_sqrelu_dense', False)
     if fused_dense_sqrelu_dense:
         assert config.activation_function == 'sqrelu', ('fused_dense_sqrelu_dense only '
@@ -123,7 +123,7 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
         if FusedMLP is None:
             raise ImportError('fused_dense is not installed')
         activation = ('gelu_approx' if config.activation_function
-                      in ['gelu_new', 'gelu_fast', 'gelu_approx'] else 'relu')
+                      in ['gelu_new', 'gelu_fast', 'gelu_approx'] else config.activation_function)
         mlp_cls = FusedMLP if process_group is None else ParallelFusedMLP
         parallel_kwargs = ({'process_group': process_group,
                             'sequence_parallel': getattr(config, 'sequence_parallel', True)}
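A quick illustration (mine, not part of the commit): with these changes a GPT config can request the fused squared-ReLU MLP path. The sketch below assumes flash-attn is installed with this commit and uses transformers' GPT2Config; treating fused_mlp and fused_dense_sqrelu_dense as ad-hoc config attributes read by create_mlp_cls is an assumption based on the hunk above.

    # Hedged sketch: config flags consumed by create_mlp_cls (see hunk above).
    from transformers import GPT2Config

    config = GPT2Config(n_embd=1024, n_head=16, n_layer=24)
    config.activation_function = 'sqrelu'      # now accepted by the fused_mlp assert
    config.fused_mlp = True                    # route through FusedMLP / ParallelFusedMLP
    # config.fused_dense_sqrelu_dense = True   # alternative: the dedicated sqrelu kernel path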
@@ -2,7 +2,8 @@
 import math

 import torch
-from torch import nn
+import torch.nn as nn
+import torch.nn.functional as F


 # 1/sqrt(2*pi)-> 0.3989423
@@ -80,3 +81,19 @@ class FastGeLUFunction(torch.autograd.Function):
         return tmp

 fast_gelu_impl = FastGeLUFunction.apply
+
+
+@torch.jit.script
+def relu_bwd(g, x):
+    return torch.where(x >= 0, g, 0.0).to(dtype=x.dtype)
+
+
+@torch.jit.script
+def sqrelu_fwd(x):
+    r = F.relu(x)
+    return (r * r).to(dtype=x.dtype)
+
+
+@torch.jit.script
+def sqrelu_bwd(g, x):
+    return (2.0 * g * F.relu(x)).to(dtype=x.dtype)
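A quick sanity check (mine, not from the commit): sqrelu_bwd should be the analytic gradient of sqrelu_fwd, since d/dx relu(x)^2 = 2 * relu(x). The sketch assumes a flash-attn install that already ships flash_attn.ops.activations, as introduced above.

    # Sketch: compare the hand-written backward against autograd on a small CPU tensor.
    import torch
    from flash_attn.ops.activations import sqrelu_fwd, sqrelu_bwd

    x = torch.randn(8, dtype=torch.float32, requires_grad=True)
    g = torch.randn(8)
    (sqrelu_fwd(x) * g).sum().backward()                       # autograd reference
    assert torch.allclose(x.grad, sqrelu_bwd(g, x.detach()))   # matches 2 * g * relu(x)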
@@ -15,16 +15,11 @@ from torch.cuda.amp import custom_bwd, custom_fwd
 # import fused_dense_cuda  # from apex
 import fused_dense_lib as fused_dense_cuda

-from flash_attn.ops.gelu_activation import gelu_bwd
+from flash_attn.ops.activations import gelu_bwd, relu_bwd, sqrelu_fwd, sqrelu_bwd
 from flash_attn.utils.distributed import all_gather_raw, reduce_scatter_raw, all_reduce_raw
 from flash_attn.utils.distributed import reduce_scatter, all_reduce


-@torch.jit.script
-def relu_bwd(g, x):
-    return torch.where(x >= 0, g, 0.0).to(dtype=x.dtype)
-
-
 class FusedDenseFunc(torch.autograd.Function):

     @staticmethod
@@ -209,7 +204,9 @@ class FusedMLPFunc(torch.autograd.Function):
         2: recompute pre_act and gelu_out / relu_out in the bwd
         """
         assert -1 <= heuristic <= 4
-        assert activation in ['gelu_approx', 'relu']
+        assert activation in ['gelu_approx', 'relu', 'sqrelu']
+        if activation == 'sqrelu':
+            assert heuristic == -1
         if not save_pre_act:
             checkpoint_lvl = 2
         assert checkpoint_lvl in [0, 1, 2]
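The new guard captures the key constraint of this change: the fused cuBLASLt epilogues only cover gelu/relu, so squared ReLU has to take the heuristic == -1 path (plain F.linear plus the torch.jit fuser). A hypothetical standalone helper (not in the repo) restating that validation:

    # Hypothetical helper mirroring the argument checks added above.
    def check_fused_mlp_args(activation: str, heuristic: int, checkpoint_lvl: int) -> None:
        assert -1 <= heuristic <= 4
        assert activation in ['gelu_approx', 'relu', 'sqrelu']
        if activation == 'sqrelu':
            assert heuristic == -1   # sqrelu is only fused via the jit fuser path
        assert checkpoint_lvl in [0, 1, 2]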
@@ -248,7 +245,8 @@ class FusedMLPFunc(torch.autograd.Function):
         if heuristic == -1:
             pre_act = F.linear(total_x, weight1, bias1)
             activation_fn = (partial(F.gelu, approximate='tanh') if activation == 'gelu_approx'
-                             else F.relu)
+                             else (sqrelu_fwd if activation == 'sqrelu' else F.relu))
             with torch.jit.fuser('fuser2'):
                 output1 = activation_fn(pre_act)
             # This is before adding bias1
             # pre_act = F.linear(total_x.reshape(batch_dim, n), weight1)
@@ -279,7 +277,7 @@ class FusedMLPFunc(torch.autograd.Function):
         checkpoint_lvl = ctx.checkpoint_lvl
         activation = ctx.activation
         activation_fn = (partial(F.gelu, approximate='tanh') if activation == 'gelu_approx'
-                         else F.relu)
+                         else (sqrelu_fwd if activation == 'sqrelu' else F.relu))
         if ctx.return_residual:
             grad_input, = args
             grad_input = grad_input.contiguous()
@@ -297,6 +295,7 @@ class FusedMLPFunc(torch.autograd.Function):
             pre_act, output1 = rest
         elif checkpoint_lvl == 1:
             pre_act, = rest
-            output1 = activation_fn(pre_act)
+            with torch.jit.fuser('fuser2'):
+                output1 = activation_fn(pre_act)
         elif checkpoint_lvl == 2:
             bias1, = rest
@@ -304,6 +303,7 @@ class FusedMLPFunc(torch.autograd.Function):
                 total_x, _ = all_gather_raw(x, process_group)
             if ctx.heuristic == -1:
                 pre_act = F.linear(total_x, weight1, bias1)
-                output1 = activation_fn(pre_act)
+                with torch.jit.fuser('fuser2'):
+                    output1 = activation_fn(pre_act)
             else:
                 output1, pre_act = fused_dense_cuda.linear_act_forward(
@@ -324,8 +324,9 @@ class FusedMLPFunc(torch.autograd.Function):
         if ctx.heuristic == -1:
             # grad_pre_act = matmul_dgelu(grad_output, weight2, pre_act)
             grad_output1 = F.linear(grad_output, weight2.t())
+            activation_grad_fn = (gelu_bwd if activation == 'gelu_approx'
+                                  else (sqrelu_bwd if activation == 'sqrelu' else relu_bwd))
             with torch.jit.fuser('fuser2'):
-                activation_grad_fn = gelu_bwd if activation == 'gelu_approx' else relu_bwd
                 grad_pre_act = activation_grad_fn(grad_output1, pre_act)
         else:
             # The cublasLt epilogue has to compute both gelu/relu grad and bias grad, we can't
@@ -380,7 +381,7 @@ def fused_mlp_func(
     process_group: Optional[ProcessGroup] = None,
     sequence_parallel: bool = True
 ):
-    assert activation in ['gelu_approx', 'relu']
+    assert activation in ['gelu_approx', 'relu', 'sqrelu']
     dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
                       or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
     # If we save pre-activation, dimension must be divisible by 128 (relu) or 8 (gelu)
@@ -428,7 +429,7 @@ class FusedMLP(nn.Module):
         to fuse the backward of nn.Linear with the residual connection.
         """
         assert checkpoint_lvl in [0, 1, 2]
-        assert activation in ['gelu_approx', 'relu']
+        assert activation in ['gelu_approx', 'relu', 'sqrelu']
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
         if out_features is None:
@@ -436,7 +437,7 @@ class FusedMLP(nn.Module):
         self.activation = activation
         self.return_residual = return_residual
         self.checkpoint_lvl = checkpoint_lvl
-        self.heuristic = heuristic
+        self.heuristic = heuristic if activation != 'sqrelu' else -1
         self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
         self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

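A usage sketch (my own, assuming a CUDA install of flash-attn with the fused_dense extension and the constructor keywords visible in the hunk above: activation, device, dtype): FusedMLP now accepts activation='sqrelu' and quietly forces heuristic to -1.

    # Hedged sketch: FusedMLP with squared ReLU; heuristic is overridden to -1 internally.
    import torch
    from flash_attn.ops.fused_dense import FusedMLP

    mlp = FusedMLP(1024, 4096, activation='sqrelu', device='cuda', dtype=torch.float16)
    assert mlp.heuristic == -1   # forced by the sqrelu path
    y = mlp(torch.randn(2, 128, 1024, device='cuda', dtype=torch.float16))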
@@ -489,7 +490,7 @@ class ParallelFusedMLP(nn.Module):
         For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16.
         """
         assert checkpoint_lvl in [0, 1, 2]
-        assert activation in ['gelu_approx', 'relu']
+        assert activation in ['gelu_approx', 'relu', 'sqrelu']
         assert process_group is not None
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
@@ -499,7 +500,7 @@ class ParallelFusedMLP(nn.Module):
         self.process_group = process_group
         self.sequence_parallel = sequence_parallel
         self.checkpoint_lvl = checkpoint_lvl
-        self.heuristic = heuristic
+        self.heuristic = heuristic if activation != 'sqrelu' else -1
         self.fc1 = ColumnParallelLinear(in_features, hidden_features, process_group,
                                         bias=bias1, **factory_kwargs)
         self.fc2 = RowParallelLinear(hidden_features, out_features, process_group,
@@ -8,17 +8,7 @@ from torch.cuda.amp import custom_bwd, custom_fwd
 import fused_dense_lib as fused_dense_cuda

 from flash_attn.ops.triton.linear import triton_linear_act, triton_dgrad_act
-
-
-@torch.jit.script
-def sqrelu_fwd(x):
-    r = F.relu(x)
-    return (r * r).to(dtype=x.dtype)
-
-
-@torch.jit.script
-def sqrelu_bwd(g, x):
-    return (2.0 * g * F.relu(x)).to(dtype=x.dtype)
+from flash_attn.ops.activations import sqrelu_fwd, sqrelu_bwd


 class FusedDenseSqreluDenseFunc(torch.autograd.Function):
@@ -107,10 +107,12 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
                                        fused_ft_kernel=fused_ft_kernel)
     scores = []
     with torch.inference_mode():
-        logits = model(input_ids, inference_params=inference_params).logits[:, -1]
         if timing:
+            if tensor_parallel > 1:
+                torch.distributed.barrier()
             torch.cuda.synchronize()
             start = time.time()
+        logits = model(input_ids, inference_params=inference_params).logits[:, -1]
         if vocab_size is not None:
             logits = logits[..., :vocab_size]
         scores.append(logits if not cg else logits.clone())
@@ -143,8 +145,10 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
             if inference_params.sequence_len_offset >= max_length - 1:
                 break
     if timing:
+        if tensor_parallel > 1:
+            torch.distributed.barrier()
         torch.cuda.synchronize()
-        print(f'Decoding time: {(time.time() - start) * 1000:.0f}ms')
+        print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
     output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
     return output_cls(
         sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1),
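The decode changes follow the usual pattern for timing GPU work under tensor parallelism: align the ranks with a barrier and drain the CUDA stream before reading the wall clock on both sides of the region, so the reported interval covers prompt processing plus decoding. A standalone sketch of that pattern (mine, not code from the repo; the timed helper and the commented call are hypothetical):

    # Hedged sketch of the barrier + synchronize timing pattern used in decode().
    import time
    import torch

    def timed(fn, tensor_parallel: int = 1):
        if tensor_parallel > 1:
            torch.distributed.barrier()   # align ranks before starting the clock
        torch.cuda.synchronize()          # make sure queued kernels have finished
        start = time.time()
        out = fn()
        torch.cuda.synchronize()          # include async work launched by fn
        return out, (time.time() - start) * 1000

    # out, ms = timed(lambda: run_decoding())   # hypothetical usage
    # print(f'Prompt processing + decoding time: {ms:.0f}ms')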