diff --git a/flash_attn/ops/triton/layernorm.py b/flash_attn/ops/triton/layernorm.py index 8df9d04..63fc165 100644 --- a/flash_attn/ops/triton/layernorm.py +++ b/flash_attn/ops/triton/layernorm.py @@ -1,5 +1,5 @@ # Copyright (c) 2023, Tri Dao. -# Implement residual + layer_norm / rms_norm. +# Implement dropout + residual + layer_norm / rms_norm. # Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html # For the backward pass, we keep weight_grad and bias_grad in registers and accumulate. @@ -16,7 +16,17 @@ import triton import triton.language as tl -def layer_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False): +def layer_norm_ref( + x, + weight, + bias, + residual=None, + eps=1e-6, + dropout_p=0.0, + prenorm=False, + dropout_mask=None, + upcast=False, +): dtype = x.dtype if upcast: weight = weight.float() @@ -24,6 +34,11 @@ def layer_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upca if upcast: x = x.float() residual = residual.float() if residual is not None else residual + if dropout_p > 0.0: + if dropout_mask is not None: + x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p) + else: + x = F.dropout(x, p=dropout_p) if residual is not None: x = (x + residual).to(x.dtype) out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to( @@ -32,7 +47,17 @@ def layer_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upca return out if not prenorm else (out, x) -def rms_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False): +def rms_norm_ref( + x, + weight, + bias, + residual=None, + eps=1e-6, + dropout_p=0.0, + prenorm=False, + dropout_mask=None, + upcast=False, +): dtype = x.dtype if upcast: weight = weight.float() @@ -40,6 +65,11 @@ def rms_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast if upcast: x = x.float() residual = residual.float() if residual is not None else residual + if dropout_p > 0.0: + if dropout_mask is not None: + x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p) + else: + x = F.dropout(x, p=dropout_p) if residual is not None: x = (x + residual).to(x.dtype) rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps) @@ -69,6 +99,8 @@ def _layer_norm_fwd_1pass_kernel( B, # pointer to the biases RESIDUAL, # pointer to the residual RESIDUAL_OUT, # pointer to the residual + SEEDS, # Dropout seeds for each row + DROPOUT_MASK, Mean, # pointer to the mean Rstd, # pointer to the 1/std stride_x_row, # how much to increase the pointer when moving by 1 row @@ -77,11 +109,14 @@ def _layer_norm_fwd_1pass_kernel( stride_res_out_row, N, # number of columns in X eps, # epsilon to avoid division by zero + dropout_p, # Dropout probability IS_RMS_NORM: tl.constexpr, BLOCK_N: tl.constexpr, HAS_RESIDUAL: tl.constexpr, STORE_RESIDUAL_OUT: tl.constexpr, HAS_BIAS: tl.constexpr, + HAS_DROPOUT: tl.constexpr, + STORE_DROPOUT_MASK: tl.constexpr, ): # Map the program id to the row of X and Y it should compute. row = tl.program_id(0) @@ -94,6 +129,13 @@ def _layer_norm_fwd_1pass_kernel( # Compute mean and variance cols = tl.arange(0, BLOCK_N) x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + if HAS_DROPOUT: + # Compute dropout mask + # 7 rounds is good enough, and reduces register pressure + keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p + x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0) + if STORE_DROPOUT_MASK: + tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N) if HAS_RESIDUAL: residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32) x += residual @@ -121,7 +163,16 @@ def _layer_norm_fwd_1pass_kernel( def _layer_norm_fwd( - x, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False + x, + weight, + bias, + eps, + residual=None, + dropout_p=0.0, + out_dtype=None, + residual_dtype=None, + is_rms_norm=False, + return_dropout_mask=False, ): if residual is not None: residual_dtype = residual.dtype @@ -138,13 +189,27 @@ def _layer_norm_fwd( # allocate output y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype) assert y.stride(-1) == 1 - if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype): - residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype) + if ( + residual is not None + or (residual_dtype is not None and residual_dtype != x.dtype) + or dropout_p > 0.0 + ): + residual_out = torch.empty( + M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype + ) assert residual_out.stride(-1) == 1 else: residual_out = None mean = torch.empty((M,), dtype=torch.float32, device="cuda") if not is_rms_norm else None rstd = torch.empty((M,), dtype=torch.float32, device="cuda") + if dropout_p > 0.0: + seeds = torch.randint(2**32, (M,), device=x.device, dtype=torch.int64) + else: + seeds = None + if return_dropout_mask and dropout_p > 0.0: + dropout_mask = torch.empty_like(x, dtype=torch.bool) + else: + dropout_mask = None # Less than 64KB per feature: enqueue fused kernel MAX_FUSED_SIZE = 65536 // x.element_size() BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) @@ -159,6 +224,8 @@ def _layer_norm_fwd( bias, residual, residual_out, + seeds, + dropout_mask, mean, rstd, x.stride(0), @@ -167,14 +234,17 @@ def _layer_norm_fwd( residual_out.stride(0) if residual_out is not None else 0, N, eps, + dropout_p, is_rms_norm, BLOCK_N, residual is not None, residual_out is not None, bias is not None, + dropout_p > 0.0, + dropout_mask is not None, ) - # residual_out is None if residual is None and residual_dtype == input_dtype - return y, mean, rstd, residual_out if residual_out is not None else x + # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0 + return y, mean, rstd, residual_out if residual_out is not None else x, seeds, dropout_mask @triton.autotune( @@ -186,7 +256,7 @@ def _layer_norm_fwd( triton.Config({}, num_warps=16), triton.Config({}, num_warps=32), ], - key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS"], + key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"], ) # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) # @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None}) @@ -204,6 +274,7 @@ def _layer_norm_bwd_kernel( DB, # pointer to the partial sum of biases gradient DRESIDUAL, DRESIDUAL_IN, + SEEDS, Mean, # pointer to the mean Rstd, # pointer to the 1/std stride_x_row, # how much to increase the pointer when moving by 1 row @@ -215,12 +286,14 @@ def _layer_norm_bwd_kernel( M, # number of rows in X N, # number of columns in X eps, # epsilon to avoid division by zero + dropout_p, rows_per_program, IS_RMS_NORM: tl.constexpr, BLOCK_N: tl.constexpr, HAS_DRESIDUAL: tl.constexpr, STORE_DRESIDUAL: tl.constexpr, HAS_BIAS: tl.constexpr, + HAS_DROPOUT: tl.constexpr, RECOMPUTE_OUTPUT: tl.constexpr, ): # Map the program id to the elements of X, DX, and DY it should compute. @@ -274,6 +347,9 @@ def _layer_norm_bwd_kernel( # Write dx if STORE_DRESIDUAL: tl.store(DRESIDUAL_IN + cols, dx, mask=mask) + if HAS_DROPOUT: + keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p + dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0) tl.store(DX + cols, dx, mask=mask) X += stride_x_row @@ -299,6 +375,8 @@ def _layer_norm_bwd( mean, rstd, dresidual=None, + seeds=None, + dropout_p=0.0, has_residual=False, is_rms_norm=False, x_dtype=None, @@ -316,13 +394,18 @@ def _layer_norm_bwd( if bias is not None: assert bias.stride(-1) == 1 assert bias.shape == (N,) + if seeds is not None: + assert seeds.is_contiguous() + assert seeds.shape == (M,) # allocate output dx = ( torch.empty_like(x) if x_dtype is None else torch.empty(M, N, dtype=x_dtype, device=x.device) ) - dresidual_in = torch.empty_like(x) if has_residual and dx.dtype != x.dtype else None + dresidual_in = ( + torch.empty_like(x) if has_residual and (dx.dtype != x.dtype or dropout_p > 0.0) else None + ) y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None # Less than 64KB per feature: enqueue fused kernel @@ -351,6 +434,7 @@ def _layer_norm_bwd( _db, dresidual, dresidual_in, + seeds, mean, rstd, x.stride(0), @@ -362,17 +446,19 @@ def _layer_norm_bwd( M, N, eps, + dropout_p, rows_per_program, is_rms_norm, BLOCK_N, dresidual is not None, dresidual_in is not None, bias is not None, + dropout_p > 0.0, ) dw = _dw.sum(0).to(weight.dtype) db = _db.sum(0).to(bias.dtype) if bias is not None else None # Don't need to compute dresidual_in separately in this case - if has_residual and dx.dtype == x.dtype: + if has_residual and dx.dtype == x.dtype and dropout_p == 0.0: dresidual_in = dx return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y) @@ -386,9 +472,11 @@ class LayerNormFn(torch.autograd.Function): bias, residual=None, eps=1e-6, + dropout_p=0.0, prenorm=False, residual_in_fp32=False, is_rms_norm=False, + return_dropout_mask=False, ): x_shape_og = x.shape # reshape input data into 2D tensor @@ -408,22 +496,36 @@ class LayerNormFn(torch.autograd.Function): if residual is not None else (torch.float32 if residual_in_fp32 else None) ) - y, mean, rstd, residual_out = _layer_norm_fwd( - x, weight, bias, eps, residual, residual_dtype=residual_dtype, is_rms_norm=is_rms_norm + y, mean, rstd, residual_out, seeds, dropout_mask = _layer_norm_fwd( + x, + weight, + bias, + eps, + residual, + dropout_p=dropout_p, + residual_dtype=residual_dtype, + is_rms_norm=is_rms_norm, + return_dropout_mask=return_dropout_mask, ) - ctx.save_for_backward(residual_out, weight, bias, mean, rstd) + ctx.save_for_backward(residual_out, weight, bias, seeds, mean, rstd) ctx.x_shape_og = x_shape_og ctx.eps = eps + ctx.dropout_p = dropout_p ctx.is_rms_norm = is_rms_norm ctx.has_residual = residual is not None ctx.prenorm = prenorm ctx.x_dtype = x.dtype y = y.reshape(x_shape_og) - return y if not prenorm else (y, residual_out.reshape(x_shape_og)) + residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None + dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None + if not return_dropout_mask: + return y if not prenorm else (y, residual_out) + else: + return (y, dropout_mask) if not prenorm else (y, residual_out, dropout_mask) @staticmethod def backward(ctx, dy, *args): - x, weight, bias, mean, rstd = ctx.saved_tensors + x, weight, bias, seeds, mean, rstd = ctx.saved_tensors dy = dy.reshape(-1, dy.shape[-1]) if dy.stride(-1) != 1: dy = dy.contiguous() @@ -445,6 +547,8 @@ class LayerNormFn(torch.autograd.Function): mean, rstd, dresidual, + seeds, + ctx.dropout_p, ctx.has_residual, ctx.is_rms_norm, x_dtype=ctx.x_dtype, @@ -458,6 +562,8 @@ class LayerNormFn(torch.autograd.Function): None, None, None, + None, + None, ) @@ -467,22 +573,57 @@ def layer_norm_fn( bias, residual=None, eps=1e-6, + dropout_p=0.0, prenorm=False, residual_in_fp32=False, is_rms_norm=False, + return_dropout_mask=False, ): - return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, is_rms_norm) + return LayerNormFn.apply( + x, + weight, + bias, + residual, + eps, + dropout_p, + prenorm, + residual_in_fp32, + is_rms_norm, + return_dropout_mask, + ) -def rms_norm_fn(x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False, eps=1e-6): - return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True) +def rms_norm_fn( + x, + weight, + bias, + residual=None, + eps=1e-6, + dropout_p=0.0, + prenorm=False, + residual_in_fp32=False, + return_dropout_mask=False, +): + return LayerNormFn.apply( + x, + weight, + bias, + residual, + eps, + dropout_p, + prenorm, + residual_in_fp32, + True, + return_dropout_mask, + ) class RMSNorm(torch.nn.Module): - def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None): + def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, device=None, dtype=None): factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.eps = eps + self.dropout_p = dropout_p self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) self.register_parameter("bias", None) self.reset_parameters() @@ -497,9 +638,9 @@ class RMSNorm(torch.nn.Module): self.bias, residual=residual, eps=self.eps, + dropout_p=self.dropout_p if self.training else 0.0, prenorm=prenorm, residual_in_fp32=residual_in_fp32, - is_rms_norm=True, ) diff --git a/tests/ops/triton/test_layer_norm.py b/tests/ops/triton/test_layer_norm.py index 0e375d5..839425b 100644 --- a/tests/ops/triton/test_layer_norm.py +++ b/tests/ops/triton/test_layer_norm.py @@ -16,12 +16,14 @@ from flash_attn.ops.triton.layernorm import ( is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8 +@pytest.mark.parametrize("dropout_p", [0.0, 0.27]) +# @pytest.mark.parametrize("dropout_p", [0.27]) @pytest.mark.parametrize("prenorm", [True, False]) -# @pytest.mark.parametrize("prenorm", [True]) +# @pytest.mark.parametrize("prenorm", [False]) @pytest.mark.parametrize("is_rms_norm", [False, True]) # @pytest.mark.parametrize("is_rms_norm", [True]) @pytest.mark.parametrize("has_residual", [True, False]) -# @pytest.mark.parametrize("has_residual", [False]) +# @pytest.mark.parametrize("has_residual", [True]) @pytest.mark.parametrize( "weight_dtype", [torch.float32, torch.float16] + ([torch.bfloat16] if is_sm8x else []) ) @@ -31,11 +33,18 @@ is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8 [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)] + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []), ) -# @pytest.mark.parametrize("input_dtype,residual_dtype", [(torch.bfloat16, torch.float32)]) -@pytest.mark.parametrize("hidden_size", [192, 2048, 2560, 3000, 8192]) +# @pytest.mark.parametrize("input_dtype,residual_dtype", [(torch.float16, torch.float16)]) +@pytest.mark.parametrize("hidden_size", [192, 2048, 2560, 3000, 4096]) # @pytest.mark.parametrize("hidden_size", [256]) def test_layer_norm( - hidden_size, input_dtype, residual_dtype, weight_dtype, has_residual, is_rms_norm, prenorm + hidden_size, + input_dtype, + residual_dtype, + weight_dtype, + has_residual, + is_rms_norm, + prenorm, + dropout_p, ): device = "cuda" if any(x == torch.bfloat16 for x in [input_dtype, residual_dtype, weight_dtype]): @@ -48,8 +57,6 @@ def test_layer_norm( torch.random.manual_seed(0) batch_size = 8 seqlen = 512 - # batch_size = 1 - # seqlen = 1 layer_norm_ref_fn = layer_norm_ref if not is_rms_norm else rms_norm_ref allclose = ( lambda x, x_pt, x_ref, atol=atol: (x - x_ref).abs().max() @@ -83,25 +90,46 @@ def test_layer_norm( bias, residual=res, eps=1e-6, + dropout_p=dropout_p, prenorm=prenorm, residual_in_fp32=residual_in_fp32, is_rms_norm=is_rms_norm, + return_dropout_mask=True, ) - out_pt, *rest_pt = layer_norm_ref_fn( - x0_pt, weight_pt, bias_pt, residual=res_pt, eps=1e-6, prenorm=prenorm + dropout_mask = rest[-1] if dropout_p > 0.0 else None + out_pt = layer_norm_ref_fn( + x0_pt, + weight_pt, + bias_pt, + residual=res_pt, + eps=1e-6, + dropout_p=dropout_p, + prenorm=prenorm, + dropout_mask=dropout_mask, ) - out_ref, *rest_ref = layer_norm_ref_fn( - x0_ref, weight_ref, bias_ref, residual=res_ref, eps=1e-6, prenorm=prenorm, upcast=True + out_ref = layer_norm_ref_fn( + x0_ref, + weight_ref, + bias_ref, + residual=res_ref, + eps=1e-6, + dropout_p=dropout_p, + prenorm=prenorm, + dropout_mask=dropout_mask, + upcast=True, ) if prenorm: residual = rest[0] - residual_pt = rest_pt[0] - residual_ref = rest_ref[0] + out_pt, residual_pt = out_pt + out_ref, residual_ref = out_ref assert out.dtype == input_dtype if prenorm: assert residual.dtype == residual_dtype assert allclose(residual, residual_pt, residual_ref) assert allclose(out, out_pt, out_ref) + if dropout_mask is not None: + dropout_fraction = 1.0 - dropout_mask.float().mean() + assert abs(dropout_fraction - dropout_p) < 0.01 g = torch.randn_like(out) / batch_size if not prenorm: @@ -128,9 +156,9 @@ def test_layer_norm( # @pytest.mark.parametrize("has_residual", [False]) @pytest.mark.parametrize("weight_dtype", [torch.float32]) @pytest.mark.parametrize( -"input_dtype,residual_dtype", -[(torch.float16, torch.float16), (torch.float16, torch.float32)] -+ ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []), + "input_dtype,residual_dtype", + [(torch.float16, torch.float16), (torch.float16, torch.float32)] + + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []), ) # @pytest.mark.parametrize("input_dtype,residual_dtype", [(torch.bfloat16, torch.float32)]) @pytest.mark.parametrize("hidden_size", [192, 2048, 2560, 3000])