Merge branch 'HazyResearch:main' into enable_cuda_graph_capture
Commit 081c2b012a
@@ -93,7 +93,7 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
     inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
     fused_mlp = getattr(config, 'fused_mlp', False)
     if fused_mlp:
-        assert config.activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx', 'relu']
+        assert config.activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx', 'relu', 'sqrelu']
     fused_dense_sqrelu_dense = getattr(config, 'fused_dense_sqrelu_dense', False)
     if fused_dense_sqrelu_dense:
         assert config.activation_function == 'sqrelu', ('fused_dense_sqrelu_dense only '
@@ -123,7 +123,7 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
         if FusedMLP is None:
             raise ImportError('fused_dense is not installed')
         activation = ('gelu_approx' if config.activation_function
-                      in ['gelu_new', 'gelu_fast', 'gelu_approx'] else 'relu')
+                      in ['gelu_new', 'gelu_fast', 'gelu_approx'] else config.activation_function)
         mlp_cls = FusedMLP if process_group is None else ParallelFusedMLP
         parallel_kwargs = ({'process_group': process_group,
                             'sequence_parallel': getattr(config, 'sequence_parallel', True)}
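The change in the hunk above is what lets the new activation reach the fused module: instead of collapsing every non-gelu activation_function to 'relu', the config value is now passed through, so 'sqrelu' arrives at FusedMLP / ParallelFusedMLP intact. A standalone sketch of the resulting mapping (not part of the commit; cfg is just a stand-in for the GPT config object):

from types import SimpleNamespace

def resolve_fused_mlp_activation(cfg):
    # gelu variants collapse to 'gelu_approx'; 'relu' and 'sqrelu' pass through unchanged
    return ('gelu_approx' if cfg.activation_function
            in ['gelu_new', 'gelu_fast', 'gelu_approx'] else cfg.activation_function)

assert resolve_fused_mlp_activation(SimpleNamespace(activation_function='gelu_new')) == 'gelu_approx'
assert resolve_fused_mlp_activation(SimpleNamespace(activation_function='sqrelu')) == 'sqrelu'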
@@ -2,7 +2,8 @@
 import math

 import torch
-from torch import nn
+import torch.nn as nn
+import torch.nn.functional as F


 # 1/sqrt(2*pi)-> 0.3989423
@@ -80,3 +81,19 @@ class FastGeLUFunction(torch.autograd.Function):
         return tmp

 fast_gelu_impl = FastGeLUFunction.apply
+
+
+@torch.jit.script
+def relu_bwd(g, x):
+    return torch.where(x >= 0, g, 0.0).to(dtype=x.dtype)
+
+
+@torch.jit.script
+def sqrelu_fwd(x):
+    r = F.relu(x)
+    return (r * r).to(dtype=x.dtype)
+
+
+@torch.jit.script
+def sqrelu_bwd(g, x):
+    return (2.0 * g * F.relu(x)).to(dtype=x.dtype)
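The squared-ReLU pair added above is easy to sanity-check against autograd, since d/dx relu(x)^2 = 2 * relu(x). A minimal check, as a sketch that is not part of the commit (it re-declares the two functions without @torch.jit.script so it runs anywhere):

import torch
import torch.nn.functional as F

def sqrelu_fwd(x):
    r = F.relu(x)
    return (r * r).to(dtype=x.dtype)

def sqrelu_bwd(g, x):
    # gradient of relu(x)**2 is 2 * relu(x); the relu already zeroes the x < 0 region
    return (2.0 * g * F.relu(x)).to(dtype=x.dtype)

x = torch.randn(16, requires_grad=True)
g = torch.randn(16)
sqrelu_fwd(x).backward(g)
assert torch.allclose(x.grad, sqrelu_bwd(g, x.detach()))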
@@ -15,16 +15,11 @@ from torch.cuda.amp import custom_bwd, custom_fwd
 # import fused_dense_cuda # from apex
 import fused_dense_lib as fused_dense_cuda

-from flash_attn.ops.gelu_activation import gelu_bwd
+from flash_attn.ops.activations import gelu_bwd, relu_bwd, sqrelu_fwd, sqrelu_bwd
 from flash_attn.utils.distributed import all_gather_raw, reduce_scatter_raw, all_reduce_raw
 from flash_attn.utils.distributed import reduce_scatter, all_reduce


-@torch.jit.script
-def relu_bwd(g, x):
-    return torch.where(x >= 0, g, 0.0).to(dtype=x.dtype)
-
-
 class FusedDenseFunc(torch.autograd.Function):

     @staticmethod
@@ -209,7 +204,9 @@ class FusedMLPFunc(torch.autograd.Function):
         2: recompute pre_act and gelu_out / relu_out in the bwd
         """
         assert -1 <= heuristic <= 4
-        assert activation in ['gelu_approx', 'relu']
+        assert activation in ['gelu_approx', 'relu', 'sqrelu']
+        if activation == 'sqrelu':
+            assert heuristic == -1
         if not save_pre_act:
             checkpoint_lvl = 2
         assert checkpoint_lvl in [0, 1, 2]
@@ -248,7 +245,8 @@ class FusedMLPFunc(torch.autograd.Function):
         if heuristic == -1:
             pre_act = F.linear(total_x, weight1, bias1)
             activation_fn = (partial(F.gelu, approximate='tanh') if activation == 'gelu_approx'
-                             else F.relu)
-            output1 = activation_fn(pre_act)
+                             else (sqrelu_fwd if activation == 'sqrelu' else F.relu))
+            with torch.jit.fuser('fuser2'):
+                output1 = activation_fn(pre_act)
         # This is before adding bias1
         # pre_act = F.linear(total_x.reshape(batch_dim, n), weight1)
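The new with torch.jit.fuser('fuser2') blocks route the TorchScript-ed activation through nvFuser, so the two element-wise ops inside sqrelu_fwd can execute as a single fused kernel instead of materializing relu(x) separately. The pattern in isolation, as a sketch (assumes a CUDA build of PyTorch in which nvFuser is available; shapes are illustrative):

import torch
import torch.nn.functional as F

@torch.jit.script
def sqrelu_fwd(x):
    r = F.relu(x)
    return (r * r).to(dtype=x.dtype)

x = torch.randn(4096, 4096, device='cuda', dtype=torch.float16)
with torch.jit.fuser('fuser2'):   # 'fuser2' selects nvFuser for scripted graphs inside this block
    y = sqrelu_fwd(x)             # after a couple of warm-up calls, relu + square run as one kernel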
@@ -279,7 +277,7 @@ class FusedMLPFunc(torch.autograd.Function):
         checkpoint_lvl = ctx.checkpoint_lvl
         activation = ctx.activation
         activation_fn = (partial(F.gelu, approximate='tanh') if activation == 'gelu_approx'
-                         else F.relu)
+                         else (sqrelu_fwd if activation == 'sqrelu' else F.relu))
         if ctx.return_residual:
             grad_input, = args
             grad_input = grad_input.contiguous()
@@ -297,6 +295,7 @@ class FusedMLPFunc(torch.autograd.Function):
             pre_act, output1 = rest
         elif checkpoint_lvl == 1:
             pre_act, = rest
-            output1 = activation_fn(pre_act)
+            with torch.jit.fuser('fuser2'):
+                output1 = activation_fn(pre_act)
         elif checkpoint_lvl == 2:
             bias1, = rest
@@ -304,6 +303,7 @@ class FusedMLPFunc(torch.autograd.Function):
                 total_x, _ = all_gather_raw(x, process_group)
             if ctx.heuristic == -1:
                 pre_act = F.linear(total_x, weight1, bias1)
-                output1 = activation_fn(pre_act)
+                with torch.jit.fuser('fuser2'):
+                    output1 = activation_fn(pre_act)
             else:
                 output1, pre_act = fused_dense_cuda.linear_act_forward(
@@ -324,8 +324,9 @@ class FusedMLPFunc(torch.autograd.Function):
         if ctx.heuristic == -1:
             # grad_pre_act = matmul_dgelu(grad_output, weight2, pre_act)
             grad_output1 = F.linear(grad_output, weight2.t())
+            activation_grad_fn = (gelu_bwd if activation == 'gelu_approx'
+                                  else (sqrelu_bwd if activation == 'sqrelu' else relu_bwd))
             with torch.jit.fuser('fuser2'):
-                activation_grad_fn = gelu_bwd if activation == 'gelu_approx' else relu_bwd
                 grad_pre_act = activation_grad_fn(grad_output1, pre_act)
         else:
             # The cublasLt epilogue has to compute both gelu/relu grad and bias grad, we can't
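With heuristic == -1, the backward through fc2 and the activation is done in plain PyTorch: a linear backward followed by the hand-written activation gradient selected above, now chosen outside the fuser context so the fused region contains only tensor ops. A rough numerical check of that path for sqrelu, as a sketch (names and shapes are illustrative, not the function's actual locals):

import torch
import torch.nn.functional as F

def sqrelu_bwd(g, x):
    return (2.0 * g * F.relu(x)).to(dtype=x.dtype)

batch, d_in, d_hidden, d_out = 4, 8, 16, 8
x = torch.randn(batch, d_in)
w1 = torch.randn(d_hidden, d_in, requires_grad=True)
w2 = torch.randn(d_out, d_hidden)

pre_act = F.linear(x, w1)
out = F.linear(F.relu(pre_act) ** 2, w2)
grad_output = torch.randn_like(out)
out.backward(grad_output)

# Manual backward mirroring the heuristic == -1 branch above
grad_output1 = F.linear(grad_output, w2.t())               # gradient w.r.t. output1
grad_pre_act = sqrelu_bwd(grad_output1, pre_act.detach())  # gradient w.r.t. pre_act
grad_w1 = grad_pre_act.t() @ x                             # gradient w.r.t. weight1
assert torch.allclose(grad_w1, w1.grad, atol=1e-4)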
@@ -380,7 +381,7 @@ def fused_mlp_func(
     process_group: Optional[ProcessGroup] = None,
     sequence_parallel: bool = True
 ):
-    assert activation in ['gelu_approx', 'relu']
+    assert activation in ['gelu_approx', 'relu', 'sqrelu']
     dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
                       or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
     # If we save pre-activation, dimension must be divisible by 128 (relu) or 8 (gelu)
@@ -428,7 +429,7 @@ class FusedMLP(nn.Module):
         to fuse the backward of nn.Linear with the residual connection.
         """
         assert checkpoint_lvl in [0, 1, 2]
-        assert activation in ['gelu_approx', 'relu']
+        assert activation in ['gelu_approx', 'relu', 'sqrelu']
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
         if out_features is None:
@@ -436,7 +437,7 @@ class FusedMLP(nn.Module):
         self.activation = activation
         self.return_residual = return_residual
         self.checkpoint_lvl = checkpoint_lvl
-        self.heuristic = heuristic
+        self.heuristic = heuristic if activation != 'sqrelu' else -1
         self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
         self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

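For users of the module, the practical effect of this hunk is that passing activation='sqrelu' overrides any requested heuristic and takes the -1 path (plain F.linear plus the scripted sqrelu), since the cuBLASLt-epilogue heuristics only cover the gelu/relu cases. A hedged usage sketch; the import path and keywords are inferred from this diff's context and the exact constructor signature may differ:

import torch
from flash_attn.ops.fused_dense import FusedMLP  # module path assumed

# Hypothetical sizes; activation='sqrelu' forces self.heuristic = -1 internally.
mlp = FusedMLP(1024, 4096, 1024, activation='sqrelu',
               checkpoint_lvl=0, device='cuda', dtype=torch.float16)
x = torch.randn(2, 512, 1024, device='cuda', dtype=torch.float16)
y = mlp(x)
assert mlp.heuristic == -1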
@@ -489,7 +490,7 @@ class ParallelFusedMLP(nn.Module):
         For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16.
         """
         assert checkpoint_lvl in [0, 1, 2]
-        assert activation in ['gelu_approx', 'relu']
+        assert activation in ['gelu_approx', 'relu', 'sqrelu']
         assert process_group is not None
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
@@ -499,7 +500,7 @@ class ParallelFusedMLP(nn.Module):
         self.process_group = process_group
         self.sequence_parallel = sequence_parallel
         self.checkpoint_lvl = checkpoint_lvl
-        self.heuristic = heuristic
+        self.heuristic = heuristic if activation != 'sqrelu' else -1
         self.fc1 = ColumnParallelLinear(in_features, hidden_features, process_group,
                                         bias=bias1, **factory_kwargs)
         self.fc2 = RowParallelLinear(hidden_features, out_features, process_group,
@@ -8,17 +8,7 @@ from torch.cuda.amp import custom_bwd, custom_fwd
 import fused_dense_lib as fused_dense_cuda

 from flash_attn.ops.triton.linear import triton_linear_act, triton_dgrad_act
-
-
-@torch.jit.script
-def sqrelu_fwd(x):
-    r = F.relu(x)
-    return (r * r).to(dtype=x.dtype)
-
-
-@torch.jit.script
-def sqrelu_bwd(g, x):
-    return (2.0 * g * F.relu(x)).to(dtype=x.dtype)
+from flash_attn.ops.activations import sqrelu_fwd, sqrelu_bwd


 class FusedDenseSqreluDenseFunc(torch.autograd.Function):
@@ -107,10 +107,12 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
                                          fused_ft_kernel=fused_ft_kernel)
     scores = []
     with torch.inference_mode():
-        logits = model(input_ids, inference_params=inference_params).logits[:, -1]
         if timing:
+            if tensor_parallel > 1:
+                torch.distributed.barrier()
             torch.cuda.synchronize()
             start = time.time()
+        logits = model(input_ids, inference_params=inference_params).logits[:, -1]
         if vocab_size is not None:
             logits = logits[..., :vocab_size]
         scores.append(logits if not cg else logits.clone())
@@ -143,8 +145,10 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
         if inference_params.sequence_len_offset >= max_length - 1:
             break
     if timing:
+        if tensor_parallel > 1:
+            torch.distributed.barrier()
         torch.cuda.synchronize()
-        print(f'Decoding time: {(time.time() - start) * 1000:.0f}ms')
+        print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
     output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
     return output_cls(
         sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1),
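Together, the two decode hunks move the model's prompt forward pass inside the timed window (hence the new label) and add a torch.distributed.barrier() so that, under tensor parallelism, every rank starts and stops the clock together rather than timing whichever rank reaches the print first. The general pattern, as a sketch (assumes the default process group is already initialized whenever tensor_parallel > 1):

import time
import torch

def timed_region(fn, tensor_parallel: int = 1):
    # Align all ranks, then drain queued CUDA work so earlier async kernels don't skew the start.
    if tensor_parallel > 1:
        torch.distributed.barrier()
    torch.cuda.synchronize()
    start = time.time()
    out = fn()
    # Same alignment before reading the clock at the end.
    if tensor_parallel > 1:
        torch.distributed.barrier()
    torch.cuda.synchronize()
    print(f'Elapsed: {(time.time() - start) * 1000:.0f}ms')
    return out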