[LayerNorm] Implement dropout in fused residual + LN/RMSNorm
parent 713bd3aa9a
commit cd089597fd
@@ -1,5 +1,5 @@
# Copyright (c) 2023, Tri Dao.
# Implement residual + layer_norm / rms_norm.
# Implement dropout + residual + layer_norm / rms_norm.

# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
@@ -16,7 +16,17 @@ import triton
import triton.language as tl


def layer_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False):
def layer_norm_ref(
    x,
    weight,
    bias,
    residual=None,
    eps=1e-6,
    dropout_p=0.0,
    prenorm=False,
    dropout_mask=None,
    upcast=False,
):
    dtype = x.dtype
    if upcast:
        weight = weight.float()
@@ -24,6 +34,11 @@ def layer_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upca
    if upcast:
        x = x.float()
        residual = residual.float() if residual is not None else residual
    if dropout_p > 0.0:
        if dropout_mask is not None:
            x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
        else:
            x = F.dropout(x, p=dropout_p)
    if residual is not None:
        x = (x + residual).to(x.dtype)
    out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(
@@ -32,7 +47,17 @@ def layer_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upca
    return out if not prenorm else (out, x)


def rms_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False):
def rms_norm_ref(
    x,
    weight,
    bias,
    residual=None,
    eps=1e-6,
    dropout_p=0.0,
    prenorm=False,
    dropout_mask=None,
    upcast=False,
):
    dtype = x.dtype
    if upcast:
        weight = weight.float()
@@ -40,6 +65,11 @@ def rms_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast
    if upcast:
        x = x.float()
        residual = residual.float() if residual is not None else residual
    if dropout_p > 0.0:
        if dropout_mask is not None:
            x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
        else:
            x = F.dropout(x, p=dropout_p)
    if residual is not None:
        x = (x + residual).to(x.dtype)
    rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
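The reference implementations above apply dropout before the residual add; when a precomputed dropout_mask is supplied they reproduce inverted dropout by zeroing the dropped positions and scaling the kept ones by 1 / (1 - dropout_p). A small self-contained sketch (not part of this commit) illustrating that masked_fill-based scaling:

import torch

torch.manual_seed(0)
x = torch.randn(4, 8)
p = 0.27
keep = torch.rand_like(x) > p               # True = keep, same convention as dropout_mask above
y = x.masked_fill(~keep, 0.0) / (1.0 - p)
# Kept entries are scaled by 1/(1-p), dropped entries are exactly zero,
# so the expectation of y matches that of x (inverted dropout).
assert torch.allclose(y[keep], x[keep] / (1.0 - p))
assert (y[~keep] == 0).all()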
@@ -69,6 +99,8 @@ def _layer_norm_fwd_1pass_kernel(
    B,  # pointer to the biases
    RESIDUAL,  # pointer to the residual
    RESIDUAL_OUT,  # pointer to the residual
    SEEDS,  # Dropout seeds for each row
    DROPOUT_MASK,
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
@@ -77,11 +109,14 @@ def _layer_norm_fwd_1pass_kernel(
    stride_res_out_row,
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    dropout_p,  # Dropout probability
    IS_RMS_NORM: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HAS_RESIDUAL: tl.constexpr,
    STORE_RESIDUAL_OUT: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_DROPOUT: tl.constexpr,
    STORE_DROPOUT_MASK: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
@@ -94,6 +129,13 @@ def _layer_norm_fwd_1pass_kernel(
    # Compute mean and variance
    cols = tl.arange(0, BLOCK_N)
    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
    if HAS_DROPOUT:
        # Compute dropout mask
        # 7 rounds is good enough, and reduces register pressure
        keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
        x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
        if STORE_DROPOUT_MASK:
            tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
    if HAS_RESIDUAL:
        residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
        x += residual
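The forward kernel does not keep the full dropout mask around (it is only written out when STORE_DROPOUT_MASK is set, e.g. for the tests); it stores one seed per row. Since tl.rand is a counter-based generator, the backward kernel can regenerate the identical keep_mask from the same seed and column offsets. A minimal sketch of that property (not part of this commit; assumes a CUDA device and a Triton version exposing tl.rand(..., n_rounds=...), as used above):

import torch
import triton
import triton.language as tl

@triton.jit
def _rand_twice(OUT1, OUT2, seed, N, BLOCK_N: tl.constexpr):
    cols = tl.arange(0, BLOCK_N)
    r1 = tl.rand(seed, cols, n_rounds=7)  # "forward" draw
    r2 = tl.rand(seed, cols, n_rounds=7)  # "backward" draw, same seed and offsets
    tl.store(OUT1 + cols, r1, mask=cols < N)
    tl.store(OUT2 + cols, r2, mask=cols < N)

N = 1024
out1 = torch.empty(N, device="cuda")
out2 = torch.empty(N, device="cuda")
_rand_twice[(1,)](out1, out2, 1234, N, BLOCK_N=N)
assert torch.equal(out1, out2)  # same seed/offsets -> same stream -> same dropout mask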
@@ -121,7 +163,16 @@ def _layer_norm_fwd_1pass_kernel(


def _layer_norm_fwd(
    x, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False
    x,
    weight,
    bias,
    eps,
    residual=None,
    dropout_p=0.0,
    out_dtype=None,
    residual_dtype=None,
    is_rms_norm=False,
    return_dropout_mask=False,
):
    if residual is not None:
        residual_dtype = residual.dtype
@@ -138,13 +189,27 @@ def _layer_norm_fwd(
    # allocate output
    y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
    assert y.stride(-1) == 1
    if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):
        residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)
    if (
        residual is not None
        or (residual_dtype is not None and residual_dtype != x.dtype)
        or dropout_p > 0.0
    ):
        residual_out = torch.empty(
            M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
        )
        assert residual_out.stride(-1) == 1
    else:
        residual_out = None
    mean = torch.empty((M,), dtype=torch.float32, device="cuda") if not is_rms_norm else None
    rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
    if dropout_p > 0.0:
        seeds = torch.randint(2**32, (M,), device=x.device, dtype=torch.int64)
    else:
        seeds = None
    if return_dropout_mask and dropout_p > 0.0:
        dropout_mask = torch.empty_like(x, dtype=torch.bool)
    else:
        dropout_mask = None
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
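Worked example of the block-size computation above (an illustration, not part of the commit): for fp16 inputs each element is 2 bytes, so a single row can hold at most 65536 // 2 = 32768 columns, and BLOCK_N is the next power of two of the feature dimension capped at that limit.

import triton

element_size = 2                          # bytes per fp16 element (assumption)
MAX_FUSED_SIZE = 65536 // element_size    # 32768
for N in (192, 2048, 3000, 4096):         # hidden sizes used in the tests below
    print(N, min(MAX_FUSED_SIZE, triton.next_power_of_2(N)))
# 192 -> 256, 2048 -> 2048, 3000 -> 4096, 4096 -> 4096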
@@ -159,6 +224,8 @@ def _layer_norm_fwd(
        bias,
        residual,
        residual_out,
        seeds,
        dropout_mask,
        mean,
        rstd,
        x.stride(0),
@@ -167,14 +234,17 @@ def _layer_norm_fwd(
        residual_out.stride(0) if residual_out is not None else 0,
        N,
        eps,
        dropout_p,
        is_rms_norm,
        BLOCK_N,
        residual is not None,
        residual_out is not None,
        bias is not None,
        dropout_p > 0.0,
        dropout_mask is not None,
    )
    # residual_out is None if residual is None and residual_dtype == input_dtype
    return y, mean, rstd, residual_out if residual_out is not None else x
    # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
    return y, mean, rstd, residual_out if residual_out is not None else x, seeds, dropout_mask


@triton.autotune(
@@ -186,7 +256,7 @@ def _layer_norm_fwd(
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS"],
    key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"],
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
@@ -204,6 +274,7 @@ def _layer_norm_bwd_kernel(
    DB,  # pointer to the partial sum of biases gradient
    DRESIDUAL,
    DRESIDUAL_IN,
    SEEDS,
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
@@ -215,12 +286,14 @@ def _layer_norm_bwd_kernel(
    M,  # number of rows in X
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    dropout_p,
    rows_per_program,
    IS_RMS_NORM: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HAS_DRESIDUAL: tl.constexpr,
    STORE_DRESIDUAL: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_DROPOUT: tl.constexpr,
    RECOMPUTE_OUTPUT: tl.constexpr,
):
    # Map the program id to the elements of X, DX, and DY it should compute.
@@ -274,6 +347,9 @@ def _layer_norm_bwd_kernel(
        # Write dx
        if STORE_DRESIDUAL:
            tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
        if HAS_DROPOUT:
            keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
            dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
        tl.store(DX + cols, dx, mask=mask)

        X += stride_x_row
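In the backward kernel, dx is multiplied by the regenerated keep_mask and rescaled by 1 / (1 - dropout_p), which is exactly the Jacobian of the inverted dropout applied in the forward pass. A tiny autograd check of that identity (not part of the commit):

import torch

torch.manual_seed(0)
p = 0.27
x = torch.randn(16, requires_grad=True)
keep = torch.rand(16) > p
y = x.masked_fill(~keep, 0.0) / (1.0 - p)   # forward dropout with a fixed mask
y.sum().backward()
# The gradient is keep/(1-p): kept positions pass the upstream grad scaled by 1/(1-p),
# dropped positions get zero -- the same rescaling the kernel applies to dx.
assert torch.allclose(x.grad, keep.float() / (1.0 - p))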
@@ -299,6 +375,8 @@ def _layer_norm_bwd(
    mean,
    rstd,
    dresidual=None,
    seeds=None,
    dropout_p=0.0,
    has_residual=False,
    is_rms_norm=False,
    x_dtype=None,
@@ -316,13 +394,18 @@ def _layer_norm_bwd(
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N,)
    if seeds is not None:
        assert seeds.is_contiguous()
        assert seeds.shape == (M,)
    # allocate output
    dx = (
        torch.empty_like(x)
        if x_dtype is None
        else torch.empty(M, N, dtype=x_dtype, device=x.device)
    )
    dresidual_in = torch.empty_like(x) if has_residual and dx.dtype != x.dtype else None
    dresidual_in = (
        torch.empty_like(x) if has_residual and (dx.dtype != x.dtype or dropout_p > 0.0) else None
    )
    y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None

    # Less than 64KB per feature: enqueue fused kernel
@@ -351,6 +434,7 @@ def _layer_norm_bwd(
        _db,
        dresidual,
        dresidual_in,
        seeds,
        mean,
        rstd,
        x.stride(0),
@@ -362,17 +446,19 @@ def _layer_norm_bwd(
        M,
        N,
        eps,
        dropout_p,
        rows_per_program,
        is_rms_norm,
        BLOCK_N,
        dresidual is not None,
        dresidual_in is not None,
        bias is not None,
        dropout_p > 0.0,
    )
    dw = _dw.sum(0).to(weight.dtype)
    db = _db.sum(0).to(bias.dtype) if bias is not None else None
    # Don't need to compute dresidual_in separately in this case
    if has_residual and dx.dtype == x.dtype:
    if has_residual and dx.dtype == x.dtype and dropout_p == 0.0:
        dresidual_in = dx
    return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y)

@@ -386,9 +472,11 @@ class LayerNormFn(torch.autograd.Function):
        bias,
        residual=None,
        eps=1e-6,
        dropout_p=0.0,
        prenorm=False,
        residual_in_fp32=False,
        is_rms_norm=False,
        return_dropout_mask=False,
    ):
        x_shape_og = x.shape
        # reshape input data into 2D tensor
@@ -408,22 +496,36 @@ class LayerNormFn(torch.autograd.Function):
            if residual is not None
            else (torch.float32 if residual_in_fp32 else None)
        )
        y, mean, rstd, residual_out = _layer_norm_fwd(
            x, weight, bias, eps, residual, residual_dtype=residual_dtype, is_rms_norm=is_rms_norm
        y, mean, rstd, residual_out, seeds, dropout_mask = _layer_norm_fwd(
            x,
            weight,
            bias,
            eps,
            residual,
            dropout_p=dropout_p,
            residual_dtype=residual_dtype,
            is_rms_norm=is_rms_norm,
            return_dropout_mask=return_dropout_mask,
        )
        ctx.save_for_backward(residual_out, weight, bias, mean, rstd)
        ctx.save_for_backward(residual_out, weight, bias, seeds, mean, rstd)
        ctx.x_shape_og = x_shape_og
        ctx.eps = eps
        ctx.dropout_p = dropout_p
        ctx.is_rms_norm = is_rms_norm
        ctx.has_residual = residual is not None
        ctx.prenorm = prenorm
        ctx.x_dtype = x.dtype
        y = y.reshape(x_shape_og)
        return y if not prenorm else (y, residual_out.reshape(x_shape_og))
        residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None
        dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None
        if not return_dropout_mask:
            return y if not prenorm else (y, residual_out)
        else:
            return (y, dropout_mask) if not prenorm else (y, residual_out, dropout_mask)

    @staticmethod
    def backward(ctx, dy, *args):
        x, weight, bias, mean, rstd = ctx.saved_tensors
        x, weight, bias, seeds, mean, rstd = ctx.saved_tensors
        dy = dy.reshape(-1, dy.shape[-1])
        if dy.stride(-1) != 1:
            dy = dy.contiguous()
@@ -445,6 +547,8 @@ class LayerNormFn(torch.autograd.Function):
            mean,
            rstd,
            dresidual,
            seeds,
            ctx.dropout_p,
            ctx.has_residual,
            ctx.is_rms_norm,
            x_dtype=ctx.x_dtype,
@@ -458,6 +562,8 @@ class LayerNormFn(torch.autograd.Function):
            None,
            None,
            None,
            None,
            None,
        )


@@ -467,22 +573,57 @@ def layer_norm_fn(
    bias,
    residual=None,
    eps=1e-6,
    dropout_p=0.0,
    prenorm=False,
    residual_in_fp32=False,
    is_rms_norm=False,
    return_dropout_mask=False,
):
    return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, is_rms_norm)
    return LayerNormFn.apply(
        x,
        weight,
        bias,
        residual,
        eps,
        dropout_p,
        prenorm,
        residual_in_fp32,
        is_rms_norm,
        return_dropout_mask,
    )


def rms_norm_fn(x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False, eps=1e-6):
    return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)
def rms_norm_fn(
    x,
    weight,
    bias,
    residual=None,
    eps=1e-6,
    dropout_p=0.0,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    return LayerNormFn.apply(
        x,
        weight,
        bias,
        residual,
        eps,
        dropout_p,
        prenorm,
        residual_in_fp32,
        True,
        return_dropout_mask,
    )


class RMSNorm(torch.nn.Module):
    def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
    def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.dropout_p = dropout_p
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter("bias", None)
        self.reset_parameters()
@@ -497,9 +638,9 @@ class RMSNorm(torch.nn.Module):
            self.bias,
            residual=residual,
            eps=self.eps,
            dropout_p=self.dropout_p if self.training else 0.0,
            prenorm=prenorm,
            residual_in_fp32=residual_in_fp32,
            is_rms_norm=True,
        )
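Usage sketch of the extended public API (illustrative shapes and values, not part of the commit; assumes a CUDA device and a flash-attn build that includes this change):

import torch
from flash_attn.ops.triton.layernorm import layer_norm_fn, RMSNorm

x = torch.randn(8, 512, 2048, device="cuda", dtype=torch.float16)
residual = torch.randn_like(x)
weight = torch.ones(2048, device="cuda")

# Fused dropout + residual + LayerNorm; with prenorm=True and return_dropout_mask=True
# this returns (normed output, dropout(x) + residual, boolean keep mask).
y, res_out, mask = layer_norm_fn(
    x, weight, None, residual=residual, dropout_p=0.1,
    prenorm=True, return_dropout_mask=True,
)

# RMSNorm module with built-in dropout; dropout is only applied in training mode.
norm = RMSNorm(2048, dropout_p=0.1, device="cuda")
out = norm(x)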

@@ -16,12 +16,14 @@ from flash_attn.ops.triton.layernorm import (
is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8


@pytest.mark.parametrize("dropout_p", [0.0, 0.27])
# @pytest.mark.parametrize("dropout_p", [0.27])
@pytest.mark.parametrize("prenorm", [True, False])
# @pytest.mark.parametrize("prenorm", [True])
# @pytest.mark.parametrize("prenorm", [False])
@pytest.mark.parametrize("is_rms_norm", [False, True])
# @pytest.mark.parametrize("is_rms_norm", [True])
@pytest.mark.parametrize("has_residual", [True, False])
# @pytest.mark.parametrize("has_residual", [False])
# @pytest.mark.parametrize("has_residual", [True])
@pytest.mark.parametrize(
    "weight_dtype", [torch.float32, torch.float16] + ([torch.bfloat16] if is_sm8x else [])
)
@@ -31,11 +33,18 @@ is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
    [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
# @pytest.mark.parametrize("input_dtype,residual_dtype", [(torch.bfloat16, torch.float32)])
@pytest.mark.parametrize("hidden_size", [192, 2048, 2560, 3000, 8192])
# @pytest.mark.parametrize("input_dtype,residual_dtype", [(torch.float16, torch.float16)])
@pytest.mark.parametrize("hidden_size", [192, 2048, 2560, 3000, 4096])
# @pytest.mark.parametrize("hidden_size", [256])
def test_layer_norm(
    hidden_size, input_dtype, residual_dtype, weight_dtype, has_residual, is_rms_norm, prenorm
    hidden_size,
    input_dtype,
    residual_dtype,
    weight_dtype,
    has_residual,
    is_rms_norm,
    prenorm,
    dropout_p,
):
    device = "cuda"
    if any(x == torch.bfloat16 for x in [input_dtype, residual_dtype, weight_dtype]):
@@ -48,8 +57,6 @@ def test_layer_norm(
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    # batch_size = 1
    # seqlen = 1
    layer_norm_ref_fn = layer_norm_ref if not is_rms_norm else rms_norm_ref
    allclose = (
        lambda x, x_pt, x_ref, atol=atol: (x - x_ref).abs().max()
@@ -83,25 +90,46 @@ def test_layer_norm(
        bias,
        residual=res,
        eps=1e-6,
        dropout_p=dropout_p,
        prenorm=prenorm,
        residual_in_fp32=residual_in_fp32,
        is_rms_norm=is_rms_norm,
        return_dropout_mask=True,
    )
    out_pt, *rest_pt = layer_norm_ref_fn(
        x0_pt, weight_pt, bias_pt, residual=res_pt, eps=1e-6, prenorm=prenorm
    dropout_mask = rest[-1] if dropout_p > 0.0 else None
    out_pt = layer_norm_ref_fn(
        x0_pt,
        weight_pt,
        bias_pt,
        residual=res_pt,
        eps=1e-6,
        dropout_p=dropout_p,
        prenorm=prenorm,
        dropout_mask=dropout_mask,
    )
    out_ref, *rest_ref = layer_norm_ref_fn(
        x0_ref, weight_ref, bias_ref, residual=res_ref, eps=1e-6, prenorm=prenorm, upcast=True
    out_ref = layer_norm_ref_fn(
        x0_ref,
        weight_ref,
        bias_ref,
        residual=res_ref,
        eps=1e-6,
        dropout_p=dropout_p,
        prenorm=prenorm,
        dropout_mask=dropout_mask,
        upcast=True,
    )
    if prenorm:
        residual = rest[0]
        residual_pt = rest_pt[0]
        residual_ref = rest_ref[0]
        out_pt, residual_pt = out_pt
        out_ref, residual_ref = out_ref
    assert out.dtype == input_dtype
    if prenorm:
        assert residual.dtype == residual_dtype
        assert allclose(residual, residual_pt, residual_ref)
    assert allclose(out, out_pt, out_ref)
    if dropout_mask is not None:
        dropout_fraction = 1.0 - dropout_mask.float().mean()
        assert abs(dropout_fraction - dropout_p) < 0.01

    g = torch.randn_like(out) / batch_size
    if not prenorm:
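The dropout_fraction check compares the empirical drop rate against dropout_p with a 0.01 tolerance; with batch_size=8, seqlen=512 and the smallest hidden_size of 192 there are roughly 786k Bernoulli draws per configuration, so the spread of the empirical fraction is far below that tolerance. A quick back-of-the-envelope check (not part of the test):

import math

n = 8 * 512 * 192           # smallest test configuration: batch * seqlen * hidden
p = 0.27
print(math.sqrt(p * (1 - p) / n))   # ~5e-4, comfortably inside the 0.01 tolerance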
@@ -128,9 +156,9 @@ def test_layer_norm(
# @pytest.mark.parametrize("has_residual", [False])
@pytest.mark.parametrize("weight_dtype", [torch.float32])
@pytest.mark.parametrize(
    "input_dtype,residual_dtype",
    [(torch.float16, torch.float16), (torch.float16, torch.float32)]
    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
    "input_dtype,residual_dtype",
    [(torch.float16, torch.float16), (torch.float16, torch.float32)]
    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
# @pytest.mark.parametrize("input_dtype,residual_dtype", [(torch.bfloat16, torch.float32)])
@pytest.mark.parametrize("hidden_size", [192, 2048, 2560, 3000])