diff --git a/flash_attn/ops/triton/layer_norm.py b/flash_attn/ops/triton/layer_norm.py
index 6fcf50e..addffe1 100644
--- a/flash_attn/ops/triton/layer_norm.py
+++ b/flash_attn/ops/triton/layer_norm.py
@@ -267,6 +267,8 @@ def _layer_norm_fwd(
     residual_dtype=None,
     is_rms_norm=False,
     return_dropout_mask=False,
+    out=None,
+    residual_out=None
 ):
     if residual is not None:
         residual_dtype = residual.dtype
@@ -294,10 +296,13 @@ def _layer_norm_fwd(
         assert rowscale.is_contiguous()
         assert rowscale.shape == (M,)
     # allocate output
-    y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
-    assert y.stride(-1) == 1
+    if out is None:
+        out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
+    else:
+        assert out.shape == x.shape
+    assert out.stride(-1) == 1
     if weight1 is not None:
-        y1 = torch.empty_like(y)
+        y1 = torch.empty_like(out)
         assert y1.stride(-1) == 1
     else:
         y1 = None
@@ -308,9 +313,12 @@
         or rowscale is not None
         or x1 is not None
     ):
-        residual_out = torch.empty(
-            M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
-        )
+        if residual_out is None:
+            residual_out = torch.empty(
+                M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
+            )
+        else:
+            assert residual_out.shape == x.shape
         assert residual_out.stride(-1) == 1
     else:
         residual_out = None
@@ -334,7 +342,7 @@
     with torch.cuda.device(x.device.index):
         _layer_norm_fwd_1pass_kernel[(M,)](
             x,
-            y,
+            out,
             weight,
             bias,
             residual,
@@ -349,7 +357,7 @@
             mean,
             rstd,
             x.stride(0),
-            y.stride(0),
+            out.stride(0),
             residual.stride(0) if residual is not None else 0,
             residual_out.stride(0) if residual_out is not None else 0,
             x1.stride(0) if x1 is not None else 0,
@@ -373,7 +381,7 @@
     else:
         dropout_mask1 = None
     return (
-        y,
+        out,
         y1,
         mean,
         rstd,
@@ -714,6 +722,8 @@ class LayerNormFn(torch.autograd.Function):
         residual_in_fp32=False,
         is_rms_norm=False,
         return_dropout_mask=False,
+        out=None,
+        residual_out=None
     ):
         x_shape_og = x.shape
         # reshape input data into 2D tensor
@@ -745,6 +755,10 @@
             if residual is not None
             else (torch.float32 if residual_in_fp32 else None)
         )
+        if out is not None:
+            out = out.reshape(-1, out.shape[-1])
+        if residual_out is not None:
+            residual_out = residual_out.reshape(-1, residual_out.shape[-1])
         y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd(
             x,
             weight,
@@ -759,6 +773,8 @@
             residual_dtype=residual_dtype,
             is_rms_norm=is_rms_norm,
             return_dropout_mask=return_dropout_mask,
+            out=out,
+            residual_out=residual_out
         )
         ctx.save_for_backward(
             residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
         )
@@ -853,6 +869,8 @@
             None,
             None,
             None,
+            None,
+            None,
         )
 
 
@@ -871,6 +889,8 @@ def layer_norm_fn(
     residual_in_fp32=False,
     is_rms_norm=False,
     return_dropout_mask=False,
+    out=None,
+    residual_out=None
 ):
     return LayerNormFn.apply(
         x,
@@ -887,6 +907,8 @@ def layer_norm_fn(
         residual_in_fp32,
         is_rms_norm,
         return_dropout_mask,
+        out,
+        residual_out
     )
 
 
@@ -904,6 +926,8 @@ def rms_norm_fn(
     prenorm=False,
     residual_in_fp32=False,
     return_dropout_mask=False,
+    out=None,
+    residual_out=None
 ):
     return LayerNormFn.apply(
         x,
@@ -920,6 +944,8 @@ def rms_norm_fn(
         residual_in_fp32,
         True,
         return_dropout_mask,
+        out,
+        residual_out
     )
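
Usage note (illustrative, not part of the patch): a minimal sketch of how a caller might pass the new out= and residual_out= arguments to rms_norm_fn so the normalized output and the updated residual are written into preallocated, caller-owned buffers instead of being allocated by the kernel. The shapes, dtypes, and the prenorm residual setup below are assumptions chosen for the example; it requires a CUDA device with Triton available.

    import torch
    from flash_attn.ops.triton.layer_norm import rms_norm_fn

    batch, seqlen, hidden = 2, 1024, 2048
    x = torch.randn(batch, seqlen, hidden, device="cuda", dtype=torch.float16)
    residual = torch.randn(batch, seqlen, hidden, device="cuda", dtype=torch.float32)
    weight = torch.ones(hidden, device="cuda", dtype=torch.float16)

    # Caller-owned buffers; shapes must match x and the last dimension must be
    # contiguous (stride(-1) == 1), as asserted in _layer_norm_fwd above.
    out = torch.empty_like(x)
    residual_out = torch.empty_like(residual)

    y, new_residual = rms_norm_fn(
        x,
        weight,
        None,                        # no bias in this RMSNorm example
        residual=residual,
        prenorm=True,                # also return the updated residual (x + residual)
        out=out,                     # normalized output written into this buffer
        residual_out=residual_out,   # x + residual written into this buffer
    )
    # y and new_residual should share storage with out and residual_out,
    # so the kernel allocates no extra output tensors.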