[LayerNorm] Implement layer_norm_linear
This commit is contained in:
parent
92dd5703ec
commit
9356a1c038
@ -10,6 +10,7 @@ import math
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.cuda.amp import custom_fwd, custom_bwd
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
@ -119,7 +120,9 @@ def _layer_norm_fwd_1pass_kernel(
|
||||
tl.store(Y + cols, y, mask=mask)
|
||||
|
||||
|
||||
def _layer_norm_fwd(x, weight, bias, eps, residual=None, residual_dtype=None, is_rms_norm=False):
|
||||
def _layer_norm_fwd(
|
||||
x, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False
|
||||
):
|
||||
if residual is not None:
|
||||
residual_dtype = residual.dtype
|
||||
M, N = x.shape
|
||||
@ -133,7 +136,7 @@ def _layer_norm_fwd(x, weight, bias, eps, residual=None, residual_dtype=None, is
|
||||
assert bias.stride(-1) == 1
|
||||
assert bias.shape == (N,)
|
||||
# allocate output
|
||||
y = torch.empty_like(x)
|
||||
y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
|
||||
assert y.stride(-1) == 1
|
||||
if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):
|
||||
residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)
|
||||
@ -498,3 +501,136 @@ class RMSNorm(torch.nn.Module):
|
||||
residual_in_fp32=residual_in_fp32,
|
||||
is_rms_norm=True,
|
||||
)
|
||||
|
||||
|
||||
class LayerNormLinearFn(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@custom_fwd
|
||||
def forward(
|
||||
ctx,
|
||||
x,
|
||||
norm_weight,
|
||||
norm_bias,
|
||||
linear_weight,
|
||||
linear_bias,
|
||||
residual=None,
|
||||
eps=1e-6,
|
||||
prenorm=False,
|
||||
residual_in_fp32=False,
|
||||
is_rms_norm=False,
|
||||
):
|
||||
x_shape_og = x.shape
|
||||
# reshape input data into 2D tensor
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
if x.stride(-1) != 1:
|
||||
x = x.contiguous()
|
||||
if residual is not None:
|
||||
assert residual.shape == x_shape_og
|
||||
residual = residual.reshape(-1, residual.shape[-1])
|
||||
if residual.stride(-1) != 1:
|
||||
residual = residual.contiguous()
|
||||
norm_weight = norm_weight.contiguous()
|
||||
if norm_bias is not None:
|
||||
norm_bias = norm_bias.contiguous()
|
||||
residual_dtype = (
|
||||
residual.dtype
|
||||
if residual is not None
|
||||
else (torch.float32 if residual_in_fp32 else None)
|
||||
)
|
||||
y, mean, rstd, residual_out = _layer_norm_fwd(
|
||||
x,
|
||||
norm_weight,
|
||||
norm_bias,
|
||||
eps,
|
||||
residual,
|
||||
out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(),
|
||||
residual_dtype=residual_dtype,
|
||||
is_rms_norm=is_rms_norm,
|
||||
)
|
||||
y = y.reshape(x_shape_og)
|
||||
dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
|
||||
linear_weight = linear_weight.to(dtype)
|
||||
linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
|
||||
out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
|
||||
# We don't store y, will be recomputed in the backward pass to save memory
|
||||
ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd)
|
||||
ctx.x_shape_og = x_shape_og
|
||||
ctx.eps = eps
|
||||
ctx.is_rms_norm = is_rms_norm
|
||||
ctx.has_residual = residual is not None
|
||||
ctx.prenorm = prenorm
|
||||
ctx.x_dtype = x.dtype
|
||||
ctx.linear_bias_is_none = linear_bias is None
|
||||
return out if not prenorm else (out, residual_out.reshape(x_shape_og))
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx, dout, *args):
|
||||
x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
|
||||
dout = dout.reshape(-1, dout.shape[-1])
|
||||
dy = F.linear(dout, linear_weight.t())
|
||||
dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
|
||||
if dy.stride(-1) != 1:
|
||||
dy = dy.contiguous()
|
||||
assert dy.shape == x.shape
|
||||
if ctx.prenorm:
|
||||
dresidual = args[0]
|
||||
dresidual = dresidual.reshape(-1, dresidual.shape[-1])
|
||||
if dresidual.stride(-1) != 1:
|
||||
dresidual = dresidual.contiguous()
|
||||
assert dresidual.shape == x.shape
|
||||
else:
|
||||
dresidual = None
|
||||
dx, dnorm_weight, dnorm_bias, dresidual_in, y = _layer_norm_bwd(
|
||||
dy,
|
||||
x,
|
||||
norm_weight,
|
||||
norm_bias,
|
||||
ctx.eps,
|
||||
mean,
|
||||
rstd,
|
||||
dresidual,
|
||||
ctx.has_residual,
|
||||
ctx.is_rms_norm,
|
||||
x_dtype=ctx.x_dtype,
|
||||
recompute_output=True,
|
||||
)
|
||||
dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
|
||||
return (
|
||||
dx.reshape(ctx.x_shape_og),
|
||||
dnorm_weight,
|
||||
dnorm_bias,
|
||||
dlinear_weight,
|
||||
dlinear_bias,
|
||||
dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
def layer_norm_linear_fn(
|
||||
x,
|
||||
norm_weight,
|
||||
norm_bias,
|
||||
linear_weight,
|
||||
linear_bias,
|
||||
residual=None,
|
||||
eps=1e-6,
|
||||
prenorm=False,
|
||||
residual_in_fp32=False,
|
||||
is_rms_norm=False,
|
||||
):
|
||||
return LayerNormLinearFn.apply(
|
||||
x,
|
||||
norm_weight,
|
||||
norm_bias,
|
||||
linear_weight,
|
||||
linear_bias,
|
||||
residual,
|
||||
eps,
|
||||
prenorm,
|
||||
residual_in_fp32,
|
||||
is_rms_norm,
|
||||
)
|
||||
|
||||
@ -1,11 +1,16 @@
|
||||
import math
|
||||
from functools import partial
|
||||
# Copyright (c) 2023, Tri Dao.
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from flash_attn.ops.triton.layernorm import layer_norm_fn, layer_norm_ref, rms_norm_ref
|
||||
|
||||
from flash_attn.ops.triton.layernorm import (
|
||||
layer_norm_fn,
|
||||
layer_norm_ref,
|
||||
rms_norm_ref,
|
||||
layer_norm_linear_fn,
|
||||
)
|
||||
|
||||
|
||||
is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
|
||||
@ -18,13 +23,13 @@ is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
|
||||
@pytest.mark.parametrize("has_residual", [True, False])
|
||||
# @pytest.mark.parametrize("has_residual", [False])
|
||||
@pytest.mark.parametrize(
|
||||
"weight_dtype", [torch.float32, torch.float16] + ([torch.bfloat16] if is_sm8x else [])
|
||||
"weight_dtype", [torch.float32, torch.float16] + ([torch.bfloat16] if is_sm8x else [])
|
||||
)
|
||||
# @pytest.mark.parametrize("weight_dtype", [torch.float32])
|
||||
@pytest.mark.parametrize(
|
||||
"input_dtype,residual_dtype",
|
||||
[(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
|
||||
+ ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
|
||||
"input_dtype,residual_dtype",
|
||||
[(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
|
||||
+ ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
|
||||
)
|
||||
# @pytest.mark.parametrize("input_dtype,residual_dtype", [(torch.bfloat16, torch.float32)])
|
||||
@pytest.mark.parametrize("hidden_size", [192, 2048, 2560, 3000, 8192])
|
||||
@ -113,3 +118,132 @@ def test_layer_norm(
|
||||
assert allclose(weight.grad, weight_pt.grad, weight_ref.grad)
|
||||
if bias is not None:
|
||||
assert allclose(bias.grad, bias_pt.grad, bias_ref.grad)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("prenorm", [True, False])
|
||||
# @pytest.mark.parametrize("prenorm", [True])
|
||||
@pytest.mark.parametrize("is_rms_norm", [False, True])
|
||||
# @pytest.mark.parametrize("is_rms_norm", [True])
|
||||
@pytest.mark.parametrize("has_residual", [True, False])
|
||||
# @pytest.mark.parametrize("has_residual", [False])
|
||||
@pytest.mark.parametrize("weight_dtype", [torch.float32])
|
||||
@pytest.mark.parametrize(
|
||||
"input_dtype,residual_dtype",
|
||||
[(torch.float16, torch.float16), (torch.float16, torch.float32)]
|
||||
+ ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
|
||||
)
|
||||
# @pytest.mark.parametrize("input_dtype,residual_dtype", [(torch.bfloat16, torch.float32)])
|
||||
@pytest.mark.parametrize("hidden_size", [192, 2048, 2560, 3000])
|
||||
# @pytest.mark.parametrize("hidden_size", [256])
|
||||
def test_layer_norm_linear(
|
||||
hidden_size, input_dtype, residual_dtype, weight_dtype, has_residual, is_rms_norm, prenorm
|
||||
):
|
||||
device = "cuda"
|
||||
if any(x == torch.bfloat16 for x in [input_dtype, residual_dtype, weight_dtype]):
|
||||
atol = 5e-2
|
||||
elif any(x == torch.float16 for x in [input_dtype, residual_dtype, weight_dtype]):
|
||||
atol = 1e-2
|
||||
else:
|
||||
atol = 1e-4
|
||||
# set seed
|
||||
torch.random.manual_seed(0)
|
||||
batch_size = 4
|
||||
seqlen = 512
|
||||
# batch_size = 1
|
||||
# seqlen = 1
|
||||
layer_norm_ref_fn = layer_norm_ref if not is_rms_norm else rms_norm_ref
|
||||
allclose = (
|
||||
lambda x, x_pt, x_ref, atol=atol: (x - x_ref).abs().max()
|
||||
<= 2 * (x_pt - x_ref).abs().max() + atol
|
||||
)
|
||||
x0 = torch.randn(
|
||||
batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
|
||||
)
|
||||
x0_pt = x0.detach().clone().requires_grad_()
|
||||
x0_ref = x0.detach().clone().requires_grad_()
|
||||
if has_residual:
|
||||
res = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
|
||||
res_pt = res.detach().clone().requires_grad_()
|
||||
res_ref = res.detach().clone().requires_grad_()
|
||||
else:
|
||||
res, res_pt, res_ref = None, None, None
|
||||
norm_weight = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
|
||||
if not is_rms_norm:
|
||||
norm_bias = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
|
||||
else:
|
||||
norm_bias = None
|
||||
norm_weight_pt = norm_weight.detach().clone().requires_grad_()
|
||||
norm_weight_ref = norm_weight.detach().clone().requires_grad_()
|
||||
norm_bias_pt = norm_bias.detach().clone().requires_grad_() if norm_bias is not None else None
|
||||
norm_bias_ref = norm_bias.detach().clone().requires_grad_() if norm_bias is not None else None
|
||||
linear_weight = torch.empty(
|
||||
2 * hidden_size, hidden_size, device=device, dtype=weight_dtype, requires_grad=True
|
||||
)
|
||||
torch.nn.init.xavier_uniform_(linear_weight)
|
||||
if not is_rms_norm:
|
||||
linear_bias = torch.randn(
|
||||
2 * hidden_size, device=device, dtype=weight_dtype, requires_grad=True
|
||||
)
|
||||
else:
|
||||
linear_bias = None
|
||||
linear_weight_pt = linear_weight.detach().clone().requires_grad_()
|
||||
linear_weight_ref = linear_weight.detach().clone().requires_grad_()
|
||||
linear_bias_pt = (
|
||||
linear_bias.detach().clone().requires_grad_() if linear_bias is not None else None
|
||||
)
|
||||
linear_bias_ref = (
|
||||
linear_bias.detach().clone().requires_grad_() if linear_bias is not None else None
|
||||
)
|
||||
|
||||
residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
|
||||
with torch.autocast(device_type="cuda", dtype=input_dtype):
|
||||
out, *rest = layer_norm_linear_fn(
|
||||
x0,
|
||||
norm_weight,
|
||||
norm_bias,
|
||||
linear_weight,
|
||||
linear_bias,
|
||||
residual=res,
|
||||
eps=1e-6,
|
||||
prenorm=prenorm,
|
||||
residual_in_fp32=residual_in_fp32,
|
||||
is_rms_norm=is_rms_norm,
|
||||
)
|
||||
out_pt, *rest_pt = layer_norm_ref_fn(
|
||||
x0_pt, norm_weight_pt, norm_bias_pt, residual=res_pt, eps=1e-6, prenorm=prenorm
|
||||
)
|
||||
with torch.autocast(device_type="cuda", dtype=input_dtype):
|
||||
out_pt = F.linear(out_pt, linear_weight_pt, linear_bias_pt)
|
||||
out_ref, *rest_ref = layer_norm_ref_fn(
|
||||
x0_ref,
|
||||
norm_weight_ref,
|
||||
norm_bias_ref,
|
||||
residual=res_ref,
|
||||
eps=1e-6,
|
||||
prenorm=prenorm,
|
||||
upcast=True,
|
||||
)
|
||||
out_ref = F.linear(out_ref.to(linear_weight_ref.dtype), linear_weight_ref, linear_bias_ref)
|
||||
if prenorm:
|
||||
residual = rest[0]
|
||||
residual_pt = rest_pt[0]
|
||||
residual_ref = rest_ref[0]
|
||||
assert out.dtype == input_dtype
|
||||
if prenorm:
|
||||
assert residual.dtype == residual_dtype
|
||||
assert allclose(residual, residual_pt, residual_ref)
|
||||
assert allclose(out, out_pt, out_ref)
|
||||
|
||||
g = torch.randn_like(out) / batch_size
|
||||
out.backward(g)
|
||||
out_pt.backward(g)
|
||||
out_ref.backward(g)
|
||||
assert allclose(x0.grad, x0_pt.grad, x0_ref.grad)
|
||||
if has_residual:
|
||||
assert allclose(res.grad, res_pt.grad, res_ref.grad)
|
||||
assert allclose(norm_weight.grad, norm_weight_pt.grad, norm_weight_ref.grad)
|
||||
if norm_bias is not None:
|
||||
assert allclose(norm_bias.grad, norm_bias_pt.grad, norm_bias_ref.grad)
|
||||
assert allclose(linear_weight.grad, linear_weight_pt.grad, linear_weight_ref.grad)
|
||||
if linear_bias is not None:
|
||||
assert allclose(linear_bias.grad, linear_bias_pt.grad, linear_bias_ref.grad)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user