diff --git a/flash_attn/ops/triton/layernorm.py b/flash_attn/ops/triton/layernorm.py
index 807e959..8df9d04 100644
--- a/flash_attn/ops/triton/layernorm.py
+++ b/flash_attn/ops/triton/layernorm.py
@@ -10,6 +10,7 @@ import math
 
 import torch
 import torch.nn.functional as F
+from torch.cuda.amp import custom_fwd, custom_bwd
 
 import triton
 import triton.language as tl
@@ -119,7 +120,9 @@ def _layer_norm_fwd_1pass_kernel(
     tl.store(Y + cols, y, mask=mask)
 
 
-def _layer_norm_fwd(x, weight, bias, eps, residual=None, residual_dtype=None, is_rms_norm=False):
+def _layer_norm_fwd(
+    x, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False
+):
     if residual is not None:
         residual_dtype = residual.dtype
     M, N = x.shape
@@ -133,7 +136,7 @@ def _layer_norm_fwd(x, weight, bias, eps, residual=None, residual_dtype=None, is
         assert bias.stride(-1) == 1
         assert bias.shape == (N,)
     # allocate output
-    y = torch.empty_like(x)
+    y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
     assert y.stride(-1) == 1
     if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):
         residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)
@@ -498,3 +501,136 @@ class RMSNorm(torch.nn.Module):
             residual_in_fp32=residual_in_fp32,
             is_rms_norm=True,
         )
+
+
+class LayerNormLinearFn(torch.autograd.Function):
+    @staticmethod
+    @custom_fwd
+    def forward(
+        ctx,
+        x,
+        norm_weight,
+        norm_bias,
+        linear_weight,
+        linear_bias,
+        residual=None,
+        eps=1e-6,
+        prenorm=False,
+        residual_in_fp32=False,
+        is_rms_norm=False,
+    ):
+        x_shape_og = x.shape
+        # reshape input data into 2D tensor
+        x = x.reshape(-1, x.shape[-1])
+        if x.stride(-1) != 1:
+            x = x.contiguous()
+        if residual is not None:
+            assert residual.shape == x_shape_og
+            residual = residual.reshape(-1, residual.shape[-1])
+            if residual.stride(-1) != 1:
+                residual = residual.contiguous()
+        norm_weight = norm_weight.contiguous()
+        if norm_bias is not None:
+            norm_bias = norm_bias.contiguous()
+        residual_dtype = (
+            residual.dtype
+            if residual is not None
+            else (torch.float32 if residual_in_fp32 else None)
+        )
+        y, mean, rstd, residual_out = _layer_norm_fwd(
+            x,
+            norm_weight,
+            norm_bias,
+            eps,
+            residual,
+            out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(),
+            residual_dtype=residual_dtype,
+            is_rms_norm=is_rms_norm,
+        )
+        y = y.reshape(x_shape_og)
+        dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
+        linear_weight = linear_weight.to(dtype)
+        linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
+        out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
+        # We don't store y, will be recomputed in the backward pass to save memory
+        ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd)
+        ctx.x_shape_og = x_shape_og
+        ctx.eps = eps
+        ctx.is_rms_norm = is_rms_norm
+        ctx.has_residual = residual is not None
+        ctx.prenorm = prenorm
+        ctx.x_dtype = x.dtype
+        ctx.linear_bias_is_none = linear_bias is None
+        return out if not prenorm else (out, residual_out.reshape(x_shape_og))
+
+    @staticmethod
+    @custom_bwd
+    def backward(ctx, dout, *args):
+        x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
+        dout = dout.reshape(-1, dout.shape[-1])
+        dy = F.linear(dout, linear_weight.t())
+        dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
+        if dy.stride(-1) != 1:
+            dy = dy.contiguous()
+        assert dy.shape == x.shape
+        if ctx.prenorm:
+            dresidual = args[0]
+            dresidual = dresidual.reshape(-1, dresidual.shape[-1])
+            if dresidual.stride(-1) != 1:
+                dresidual = dresidual.contiguous()
+            assert dresidual.shape == x.shape
+        else:
+            dresidual = None
+        dx, dnorm_weight, dnorm_bias, dresidual_in, y = _layer_norm_bwd(
+            dy,
+            x,
+            norm_weight,
+            norm_bias,
+            ctx.eps,
+            mean,
+            rstd,
+            dresidual,
+            ctx.has_residual,
+            ctx.is_rms_norm,
+            x_dtype=ctx.x_dtype,
+            recompute_output=True,
+        )
+        dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
+        return (
+            dx.reshape(ctx.x_shape_og),
+            dnorm_weight,
+            dnorm_bias,
+            dlinear_weight,
+            dlinear_bias,
+            dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
+            None,
+            None,
+            None,
+            None,
+        )
+
+
+def layer_norm_linear_fn(
+    x,
+    norm_weight,
+    norm_bias,
+    linear_weight,
+    linear_bias,
+    residual=None,
+    eps=1e-6,
+    prenorm=False,
+    residual_in_fp32=False,
+    is_rms_norm=False,
+):
+    return LayerNormLinearFn.apply(
+        x,
+        norm_weight,
+        norm_bias,
+        linear_weight,
+        linear_bias,
+        residual,
+        eps,
+        prenorm,
+        residual_in_fp32,
+        is_rms_norm,
+    )
diff --git a/tests/ops/triton/test_layer_norm.py b/tests/ops/triton/test_layer_norm.py
index abacbe9..0e375d5 100644
--- a/tests/ops/triton/test_layer_norm.py
+++ b/tests/ops/triton/test_layer_norm.py
@@ -1,11 +1,16 @@
-import math
-from functools import partial
+# Copyright (c) 2023, Tri Dao.
 
 import pytest
 import torch
 import torch.nn.functional as F
 from einops import rearrange, repeat
-from flash_attn.ops.triton.layernorm import layer_norm_fn, layer_norm_ref, rms_norm_ref
+
+from flash_attn.ops.triton.layernorm import (
+    layer_norm_fn,
+    layer_norm_ref,
+    rms_norm_ref,
+    layer_norm_linear_fn,
+)
 
 is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
 
@@ -18,13 +23,13 @@ is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
 @pytest.mark.parametrize("has_residual", [True, False])
 # @pytest.mark.parametrize("has_residual", [False])
 @pytest.mark.parametrize(
-"weight_dtype", [torch.float32, torch.float16] + ([torch.bfloat16] if is_sm8x else [])
+    "weight_dtype", [torch.float32, torch.float16] + ([torch.bfloat16] if is_sm8x else [])
 )
 # @pytest.mark.parametrize("weight_dtype", [torch.float32])
 @pytest.mark.parametrize(
-"input_dtype,residual_dtype",
-[(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
-+ ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
+    "input_dtype,residual_dtype",
+    [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
+    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
 )
 # @pytest.mark.parametrize("input_dtype,residual_dtype", [(torch.bfloat16, torch.float32)])
 @pytest.mark.parametrize("hidden_size", [192, 2048, 2560, 3000, 8192])
@@ -113,3 +118,132 @@ def test_layer_norm(
     assert allclose(weight.grad, weight_pt.grad, weight_ref.grad)
     if bias is not None:
         assert allclose(bias.grad, bias_pt.grad, bias_ref.grad)
+
+
+@pytest.mark.parametrize("prenorm", [True, False])
+# @pytest.mark.parametrize("prenorm", [True])
+@pytest.mark.parametrize("is_rms_norm", [False, True])
+# @pytest.mark.parametrize("is_rms_norm", [True])
+@pytest.mark.parametrize("has_residual", [True, False])
+# @pytest.mark.parametrize("has_residual", [False])
+@pytest.mark.parametrize("weight_dtype", [torch.float32])
+@pytest.mark.parametrize(
+    "input_dtype,residual_dtype",
+    [(torch.float16, torch.float16), (torch.float16, torch.float32)]
+    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
+)
+# @pytest.mark.parametrize("input_dtype,residual_dtype", [(torch.bfloat16, torch.float32)])
+@pytest.mark.parametrize("hidden_size", [192, 2048, 2560, 3000])
+# @pytest.mark.parametrize("hidden_size", [256])
+def test_layer_norm_linear(
+    hidden_size, input_dtype, residual_dtype, weight_dtype, has_residual, is_rms_norm, prenorm
+):
+    device = "cuda"
+    if any(x == torch.bfloat16 for x in [input_dtype, residual_dtype, weight_dtype]):
+        atol = 5e-2
+    elif any(x == torch.float16 for x in [input_dtype, residual_dtype, weight_dtype]):
+        atol = 1e-2
+    else:
+        atol = 1e-4
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 4
+    seqlen = 512
+    # batch_size = 1
+    # seqlen = 1
+    layer_norm_ref_fn = layer_norm_ref if not is_rms_norm else rms_norm_ref
+    allclose = (
+        lambda x, x_pt, x_ref, atol=atol: (x - x_ref).abs().max()
+        <= 2 * (x_pt - x_ref).abs().max() + atol
+    )
+    x0 = torch.randn(
+        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
+    )
+    x0_pt = x0.detach().clone().requires_grad_()
+    x0_ref = x0.detach().clone().requires_grad_()
+    if has_residual:
+        res = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
+        res_pt = res.detach().clone().requires_grad_()
+        res_ref = res.detach().clone().requires_grad_()
+    else:
+        res, res_pt, res_ref = None, None, None
+    norm_weight = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
+    if not is_rms_norm:
+        norm_bias = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
+    else:
+        norm_bias = None
+    norm_weight_pt = norm_weight.detach().clone().requires_grad_()
+    norm_weight_ref = norm_weight.detach().clone().requires_grad_()
+    norm_bias_pt = norm_bias.detach().clone().requires_grad_() if norm_bias is not None else None
+    norm_bias_ref = norm_bias.detach().clone().requires_grad_() if norm_bias is not None else None
+    linear_weight = torch.empty(
+        2 * hidden_size, hidden_size, device=device, dtype=weight_dtype, requires_grad=True
+    )
+    torch.nn.init.xavier_uniform_(linear_weight)
+    if not is_rms_norm:
+        linear_bias = torch.randn(
+            2 * hidden_size, device=device, dtype=weight_dtype, requires_grad=True
+        )
+    else:
+        linear_bias = None
+    linear_weight_pt = linear_weight.detach().clone().requires_grad_()
+    linear_weight_ref = linear_weight.detach().clone().requires_grad_()
+    linear_bias_pt = (
+        linear_bias.detach().clone().requires_grad_() if linear_bias is not None else None
+    )
+    linear_bias_ref = (
+        linear_bias.detach().clone().requires_grad_() if linear_bias is not None else None
+    )
+
+    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
+    with torch.autocast(device_type="cuda", dtype=input_dtype):
+        out, *rest = layer_norm_linear_fn(
+            x0,
+            norm_weight,
+            norm_bias,
+            linear_weight,
+            linear_bias,
+            residual=res,
+            eps=1e-6,
+            prenorm=prenorm,
+            residual_in_fp32=residual_in_fp32,
+            is_rms_norm=is_rms_norm,
+        )
+    out_pt, *rest_pt = layer_norm_ref_fn(
+        x0_pt, norm_weight_pt, norm_bias_pt, residual=res_pt, eps=1e-6, prenorm=prenorm
+    )
+    with torch.autocast(device_type="cuda", dtype=input_dtype):
+        out_pt = F.linear(out_pt, linear_weight_pt, linear_bias_pt)
+    out_ref, *rest_ref = layer_norm_ref_fn(
+        x0_ref,
+        norm_weight_ref,
+        norm_bias_ref,
+        residual=res_ref,
+        eps=1e-6,
+        prenorm=prenorm,
+        upcast=True,
+    )
+    out_ref = F.linear(out_ref.to(linear_weight_ref.dtype), linear_weight_ref, linear_bias_ref)
+    if prenorm:
+        residual = rest[0]
+        residual_pt = rest_pt[0]
+        residual_ref = rest_ref[0]
+    assert out.dtype == input_dtype
+    if prenorm:
+        assert residual.dtype == residual_dtype
+        assert allclose(residual, residual_pt, residual_ref)
+    assert allclose(out, out_pt, out_ref)
+
+    g = torch.randn_like(out) / batch_size
+    out.backward(g)
+    out_pt.backward(g)
+    out_ref.backward(g)
+    assert allclose(x0.grad, x0_pt.grad, x0_ref.grad)
+    if has_residual:
+        assert allclose(res.grad, res_pt.grad, res_ref.grad)
+    assert allclose(norm_weight.grad, norm_weight_pt.grad, norm_weight_ref.grad)
+    if norm_bias is not None:
+        assert allclose(norm_bias.grad, norm_bias_pt.grad, norm_bias_ref.grad)
+    assert allclose(linear_weight.grad, linear_weight_pt.grad, linear_weight_ref.grad)
+    if linear_bias is not None:
+        assert allclose(linear_bias.grad, linear_bias_pt.grad, linear_bias_ref.grad)
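
Usage sketch (not part of the patch): a minimal example of how the fused layer_norm_linear_fn added above could be called. The call signature and the prenorm return behavior come from the patch itself; tensor shapes, dtypes, and the output width are illustrative assumptions, and CUDA is required since the kernel is a Triton kernel.

# Hedged usage sketch of layer_norm_linear_fn; shapes/dtypes are assumptions.
import torch

from flash_attn.ops.triton.layernorm import layer_norm_linear_fn

batch, seqlen, hidden, out_features = 2, 128, 2048, 4096
x = torch.randn(batch, seqlen, hidden, device="cuda", dtype=torch.float16)
norm_weight = torch.ones(hidden, device="cuda", dtype=torch.float32)
norm_bias = torch.zeros(hidden, device="cuda", dtype=torch.float32)
linear_weight = torch.randn(out_features, hidden, device="cuda", dtype=torch.float32) * 0.02

# Fused LayerNorm -> Linear. With prenorm=True the updated residual is returned as well;
# residual_in_fp32=True keeps it in fp32 since no residual tensor is passed in here.
out, residual = layer_norm_linear_fn(
    x,
    norm_weight,
    norm_bias,
    linear_weight,
    None,  # linear_bias
    residual=None,
    eps=1e-6,
    prenorm=True,
    residual_in_fp32=True,
    is_rms_norm=False,
)
print(out.shape, residual.shape)  # (2, 128, 4096) and (2, 128, 2048)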