diff --git a/flash_attn/ops/triton/layernorm.py b/flash_attn/ops/triton/layernorm.py new file mode 100644 index 0000000..b7bde5a --- /dev/null +++ b/flash_attn/ops/triton/layernorm.py @@ -0,0 +1,395 @@ +# Copyright (c) 2023, Tri Dao. +# Implement residual + layer_norm / rms_norm. + +# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html +# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate. +# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling. +# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine. + +import math + +import torch +import torch.nn.functional as F + +import triton +import triton.language as tl + + +def layer_norm_ref(x, weight, bias, residual=None, eps=1e-6, upcast=False): + dtype = x.dtype + if upcast: + weight = weight.float() + bias = bias.float() if bias is not None else None + if upcast: + x = x.float() + residual = residual.float() if residual is not None else residual + if residual is not None: + x = (x + residual).to(x.dtype) + out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(dtype) + return out if residual is None else (out, x) + + +def rms_norm_ref(x, weight, bias, residual=None, eps=1e-6, upcast=False): + dtype = x.dtype + if upcast: + weight = weight.float() + bias = bias.float() if bias is not None else None + if upcast: + x = x.float() + residual = residual.float() if residual is not None else residual + if residual is not None: + x = (x + residual).to(x.dtype) + rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps) + out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight) + out = out.to(dtype) + return out if residual is None else (out, x) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["N", "HAS_RESIDUAL", "IS_RMS_NORM", "HAS_BIAS"], +) +# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) +# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None}) +@triton.jit +def _layer_norm_fwd_1pass_kernel( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + RESIDUAL, # pointer to the residual + RESIDUAL_OUT, # pointer to the residual + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_res_row, + stride_res_out_row, + N, # number of columns in X + eps, # epsilon to avoid division by zero + IS_RMS_NORM: tl.constexpr, + BLOCK_N: tl.constexpr, + HAS_RESIDUAL: tl.constexpr, + HAS_BIAS: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. + row = tl.program_id(0) + X += row * stride_x_row + Y += row * stride_y_row + if HAS_RESIDUAL: + RESIDUAL += row * stride_res_row + RESIDUAL_OUT += row * stride_res_out_row + # Compute mean and variance + cols = tl.arange(0, BLOCK_N) + x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) + if HAS_RESIDUAL: + residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.).to(tl.float32) + x += residual + tl.store(RESIDUAL_OUT + cols, x, mask=cols < N) + if not IS_RMS_NORM: + mean = tl.sum(x, axis=0) / N + tl.store(Mean + row, mean) + xbar = tl.where(cols < N, x - mean, 0.) + var = tl.sum(xbar * xbar, axis=0) / N + else: + xbar = tl.where(cols < N, x, 0.) + var = tl.sum(xbar * xbar, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + mask = cols < N + w = tl.load(W + cols, mask=mask).to(tl.float32) + if HAS_BIAS: + b = tl.load(B + cols, mask=mask).to(tl.float32) + x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + y = x_hat * w + b if HAS_BIAS else x_hat * w + # Write output + tl.store(Y + cols, y, mask=mask) + + +def _layer_norm_fwd(x, weight, bias, eps, residual=None, is_rms_norm=False): + M, N = x.shape + assert x.stride(-1) == 1 + if residual is not None: + assert residual.stride(-1) == 1 + assert residual.shape == (M, N) + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + # allocate output + y = torch.empty_like(x) + assert y.stride(-1) == 1 + if residual is not None: + residual_out = torch.empty_like(residual) + assert residual_out.stride(-1) == 1 + else: + residual_out = None + mean = torch.empty((M, ), dtype=torch.float32, device='cuda') if not is_rms_norm else None + rstd = torch.empty((M, ), dtype=torch.float32, device='cuda') + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + # heuristics for number of warps + with torch.cuda.device(x.device.index): + _layer_norm_fwd_1pass_kernel[(M,)](x, y, weight, bias, residual, residual_out, + mean, rstd, + x.stride(0), y.stride(0), + residual.stride(0) if residual is not None else 0, + residual_out.stride(0) if residual is not None else 0, + N, eps, + is_rms_norm, + BLOCK_N, + residual is not None, + bias is not None, + ) + return y, mean, rstd, residual_out + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS"], +) +# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) +# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None}) +# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None}) +@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None}) +@triton.jit +def _layer_norm_bwd_kernel( + X, # pointer to the input + W, # pointer to the weights + B, # pointer to the biases + Y, # pointer to the output to be recomputed + DY, # pointer to the output gradient + DX, # pointer to the input gradient + DW, # pointer to the partial sum of weights gradient + DB, # pointer to the partial sum of biases gradient + DRESIDUAL, + DRESIDUAL_IN, + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_dy_row, + stride_dx_row, + stride_dres_row, + stride_dres_in_row, + M, # number of rows in X + N, # number of columns in X + eps, # epsilon to avoid division by zero + rows_per_program, + IS_RMS_NORM: tl.constexpr, + BLOCK_N: tl.constexpr, + HAS_DRESIDUAL: tl.constexpr, + STORE_DRESIDUAL: tl.constexpr, + HAS_BIAS: tl.constexpr, + RECOMPUTE_OUTPUT: tl.constexpr, +): + # Map the program id to the elements of X, DX, and DY it should compute. + row_block_id = tl.program_id(0) + row_start = row_block_id * rows_per_program + cols = tl.arange(0, BLOCK_N) + mask = cols < N + X += row_start * stride_x_row + if HAS_DRESIDUAL: + DRESIDUAL += row_start * stride_dres_row + if STORE_DRESIDUAL: + DRESIDUAL_IN += row_start * stride_dres_in_row + DY += row_start * stride_dy_row + DX += row_start * stride_dx_row + if RECOMPUTE_OUTPUT: + Y += row_start * stride_y_row + w = tl.load(W + cols, mask=mask).to(tl.float32) + if RECOMPUTE_OUTPUT and HAS_BIAS: + b = tl.load(B + cols, mask=mask, other=0.).to(tl.float32) + dw = tl.zeros((BLOCK_N,), dtype=tl.float32) + if HAS_BIAS: + db = tl.zeros((BLOCK_N,), dtype=tl.float32) + row_end = min((row_block_id + 1) * rows_per_program, M) + for row in range(row_start, row_end): + # Load data to SRAM + x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) + if not IS_RMS_NORM: + mean = tl.load(Mean + row) + rstd = tl.load(Rstd + row) + # Compute dx + xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + xhat = tl.where(mask, xhat, 0.) + if RECOMPUTE_OUTPUT: + y = xhat * w + b if HAS_BIAS else xhat * w + tl.store(Y + cols, y, mask=mask) + wdy = w * dy + dw += dy * xhat + if HAS_BIAS: + db += dy + if not IS_RMS_NORM: + c1 = tl.sum(xhat * wdy, axis=0) / N + c2 = tl.sum(wdy, axis=0) / N + dx = (wdy - (xhat * c1 + c2)) * rstd + else: + c1 = tl.sum(xhat * wdy, axis=0) / N + dx = (wdy - xhat * c1) * rstd + if HAS_DRESIDUAL: + dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32) + dx += dres + # Write dx + if STORE_DRESIDUAL: + tl.store(DRESIDUAL_IN + cols, dx, mask=mask) + tl.store(DX + cols, dx, mask=mask) + + X += stride_x_row + if HAS_DRESIDUAL: + DRESIDUAL += stride_dres_row + if STORE_DRESIDUAL: + DRESIDUAL_IN += stride_dres_in_row + if RECOMPUTE_OUTPUT: + Y += stride_y_row + DY += stride_dy_row + DX += stride_dx_row + tl.store(DW + row_block_id * N + cols, dw, mask=mask) + if HAS_BIAS: + tl.store(DB + row_block_id * N + cols, db, mask=mask) + + +def _layer_norm_bwd(dy, x, weight, bias, eps, mean, rstd, dresidual=None, is_rms_norm=False, x_dtype=None, + recompute_output=False): + M, N = x.shape + assert x.stride(-1) == 1 + assert dy.stride(-1) == 1 + assert dy.shape == (M, N) + if dresidual is not None: + assert dresidual.stride(-1) == 1 + assert dresidual.shape == (M, N) + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + # allocate output + dx = torch.empty_like(x) if x_dtype is None else torch.empty(M, N, dtype=x_dtype, device=x.device) + dresidual_in = torch.empty_like(dresidual) if dresidual is not None and dx.dtype != dresidual.dtype else None + y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None + + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count + _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device) + _db = torch.empty((sm_count, N), dtype=torch.float32, device=bias.device) if bias is not None else None + rows_per_program = math.ceil(M / sm_count) + grid = (sm_count,) + with torch.cuda.device(x.device.index): + _layer_norm_bwd_kernel[grid](x, weight, bias, y, + dy, dx, _dw, _db, dresidual, dresidual_in, + mean, rstd, + x.stride(0), + 0 if not recompute_output else y.stride(0), + dy.stride(0), dx.stride(0), + dresidual.stride(0) if dresidual is not None else 0, + dresidual_in.stride(0) if dresidual_in is not None else 0, + M, N, eps, + rows_per_program, + is_rms_norm, + BLOCK_N, + dresidual is not None, + dresidual_in is not None, + bias is not None) + dw = _dw.sum(0).to(weight.dtype) + db = _db.sum(0).to(bias.dtype) if bias is not None else None + # Don't need to compute dresidual_in separately in this case + if dresidual is not None and dx.dtype == dresidual.dtype: + dresidual_in = dx + return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y) + + +class LayerNormFn(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, weight, bias, residual=None, eps=1e-6, is_rms_norm=False): + x_shape_og = x.shape + # reshape input data into 2D tensor + x = x.reshape(-1, x.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if residual is not None: + assert residual.shape == x_shape_og + residual = residual.reshape(-1, residual.shape[-1]) + if residual.stride(-1) != 1: + residual = residual.contiguous() + weight = weight.contiguous() + if bias is not None: + bias = bias.contiguous() + y, mean, rstd, *rest = _layer_norm_fwd(x, weight, bias, eps, residual, is_rms_norm) + if residual is not None: + residual_out = rest[0] + ctx.save_for_backward(x if residual is None else residual_out, weight, bias, mean, rstd) + ctx.x_shape_og = x_shape_og + ctx.eps = eps + ctx.is_rms_norm = is_rms_norm + ctx.has_residual = residual is not None + ctx.x_dtype = x.dtype + y = y.reshape(x_shape_og) + return y if residual is None else (y, residual_out.reshape(x_shape_og)) + + @staticmethod + def backward(ctx, dy, *args): + x, weight, bias, mean, rstd = ctx.saved_tensors + dy = dy.reshape(-1, dy.shape[-1]) + if dy.stride(-1) != 1: + dy = dy.contiguous() + assert dy.shape == x.shape + if ctx.has_residual: + dresidual = args[0] + dresidual = dresidual.reshape(-1, dresidual.shape[-1]) + if dresidual.stride(-1) != 1: + dresidual = dresidual.contiguous() + assert dresidual.shape == x.shape + else: + dresidual = None + dx, dw, db, dresidual_in = _layer_norm_bwd(dy, x, weight, bias, ctx.eps, mean, rstd, dresidual, + ctx.is_rms_norm, x_dtype=ctx.x_dtype) + return dx.reshape(ctx.x_shape_og), dw, db, dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, None, None + + +def layer_norm_fn(x, weight, bias, residual=None, eps=1e-6, is_rms_norm=False): + return LayerNormFn.apply(x, weight, bias, residual, eps, is_rms_norm) + + +def rms_norm_fn(x, weight, bias, residual=None, eps=1e-6): + return LayerNormFn.apply(x, weight, bias, residual, eps, True) + + +class RMSNorm(torch.nn.Module): + def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self): + torch.nn.init.ones_(self.weight) + + def forward(self, x, residual=None): + return layer_norm_fn(x, self.weight, self.bias, residual=residual, eps=self.eps, is_rms_norm=True) diff --git a/flash_attn/utils/benchmark.py b/flash_attn/utils/benchmark.py index e691b75..15b3040 100644 --- a/flash_attn/utils/benchmark.py +++ b/flash_attn/utils/benchmark.py @@ -213,7 +213,10 @@ def pytorch_profiler( """Wrap benchmark functions in Pytorch profiler to see CUDA information.""" if backward: with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): - g = torch.randn_like(fn(*inputs, **kwinputs)) + out = fn(*inputs, **kwinputs) + if type(out) is tuple: + out = out[0] + g = torch.randn_like(out) for _ in range(30): # Warm up if backward: for x in inputs: @@ -221,6 +224,8 @@ def pytorch_profiler( x.grad = None with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): out = fn(*inputs, **kwinputs) + if type(out) is tuple: + out = out[0] # Backward should be done outside autocast if backward: out.backward(g, retain_graph=True) @@ -239,6 +244,8 @@ def pytorch_profiler( x.grad = None with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): out = fn(*inputs, **kwinputs) + if type(out) is tuple: + out = out[0] if backward: out.backward(g, retain_graph=True) if verbose: diff --git a/tests/ops/triton/test_layer_norm.py b/tests/ops/triton/test_layer_norm.py new file mode 100644 index 0000000..f51352e --- /dev/null +++ b/tests/ops/triton/test_layer_norm.py @@ -0,0 +1,103 @@ +import math +from functools import partial + +import pytest +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from flash_attn.ops.triton.layernorm import layer_norm_fn, layer_norm_ref, rms_norm_ref + + +is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8 + + +@pytest.mark.parametrize("is_rms_norm", [False, True]) +# @pytest.mark.parametrize("is_rms_norm", [True]) +@pytest.mark.parametrize("has_residual", [True, False]) +# @pytest.mark.parametrize("has_residual", [True]) +@pytest.mark.parametrize( + "weight_dtype", [torch.float32, torch.float16] + ([torch.bfloat16] if is_sm8x else []) +) +# @pytest.mark.parametrize("weight_dtype", [torch.float32]) +@pytest.mark.parametrize( + "input_dtype,residual_dtype", + [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)] + + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []), +) +# @pytest.mark.parametrize("input_dtype,residual_dtype", [(torch.bfloat16, torch.float32)]) +@pytest.mark.parametrize("hidden_size", [192, 2048, 2560, 3000, 8192]) +# @pytest.mark.parametrize("hidden_size", [256]) +def test_layer_norm( + hidden_size, input_dtype, residual_dtype, weight_dtype, has_residual, is_rms_norm +): + device = "cuda" + if any(x == torch.bfloat16 for x in [input_dtype, residual_dtype, weight_dtype]): + atol = 5e-2 + elif any(x == torch.float16 for x in [input_dtype, residual_dtype, weight_dtype]): + atol = 5e-3 + else: + atol = 1e-4 + # set seed + torch.random.manual_seed(0) + batch_size = 8 + seqlen = 512 + # batch_size = 1 + # seqlen = 1 + layer_norm_ref_fn = layer_norm_ref if not is_rms_norm else rms_norm_ref + allclose = ( + lambda x, x_pt, x_ref, atol=atol: (x - x_ref).abs().max() + <= 2 * (x_pt - x_ref).abs().max() + atol + ) + x0 = torch.randn( + batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True + ) + x0_pt = x0.detach().clone().requires_grad_() + x0_ref = x0.detach().clone().requires_grad_() + if has_residual: + res = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True) + res_pt = res.detach().clone().requires_grad_() + res_ref = res.detach().clone().requires_grad_() + else: + res, res_pt, res_ref = None, None, None + weight = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True) + if not is_rms_norm: + bias = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True) + else: + bias = None + weight_pt = weight.detach().clone().requires_grad_() + weight_ref = weight.detach().clone().requires_grad_() + bias_pt = bias.detach().clone().requires_grad_() if bias is not None else None + bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None + residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32 + + out, *rest = layer_norm_fn(x0, weight, bias, residual=res, eps=1e-6, is_rms_norm=is_rms_norm) + out_pt, *rest_pt = layer_norm_ref_fn(x0_pt, weight_pt, bias_pt, residual=res_pt, eps=1e-6) + out_ref, *rest_ref = layer_norm_ref_fn( + x0_ref, weight_ref, bias_ref, residual=res_ref, eps=1e-6, upcast=True + ) + if has_residual: + residual = rest[0] + residual_pt = rest_pt[0] + residual_ref = rest_ref[0] + residual_ref = x0_ref + res_ref + assert out.dtype == input_dtype + if has_residual: + assert residual.dtype == residual_dtype + assert allclose(residual, residual_pt, residual_ref) + assert allclose(out, out_pt, out_ref) + + g = torch.randn_like(out) / batch_size + if not has_residual: + out.backward(g) + out_pt.backward(g) + out_ref.backward(g) + else: + (out * F.sigmoid(residual)).backward(g) + (out_pt * F.sigmoid(residual_pt)).backward(g) + (out_ref * F.sigmoid(residual_ref.to(dtype=residual_dtype))).backward(g) + assert allclose(x0.grad, x0_pt.grad, x0_ref.grad) + if has_residual: + assert allclose(res.grad, res_pt.grad, res_ref.grad) + assert allclose(weight.grad, weight_pt.grad, weight_ref.grad) + if bias is not None: + assert allclose(bias.grad, bias_pt.grad, bias_ref.grad)