[LayerNorm] Don't store x + residual if we don't need gradients

Tri Dao 2024-08-15 11:07:46 -07:00
parent 16025d8cc9
commit 800401847e


@@ -307,7 +307,14 @@ def _layer_norm_fwd(
         or dropout_p > 0.0
         or rowscale is not None
         or x1 is not None
-    ):
+    ) and (
+        x.requires_grad
+        or weight.requires_grad
+        or (bias is not None and bias.requires_grad)
+        or (residual is not None and residual.requires_grad)
+        or (x1 is not None and x1.requires_grad)
+        or (weight1 is not None and weight1.requires_grad)
+        or (bias1 is not None and bias1.requires_grad)):
         residual_out = torch.empty(
             M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
         )
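The net effect is that the pre-norm residual buffer (x + residual) is only materialized when at least one input to the layer norm actually requires gradients; otherwise the allocation and the extra global-memory write are skipped, since the backward pass that would read residual_out never runs. Below is a minimal standalone sketch of the same pattern, outside the real _layer_norm_fwd signature; the helper name maybe_alloc_residual_out and the reduced argument list are illustrative, not part of the library.

import torch

def maybe_alloc_residual_out(x, residual=None, weight=None, bias=None,
                             residual_dtype=None, dropout_p=0.0):
    # The saved buffer is only needed by the backward pass, so skip the
    # allocation (and the extra write it implies) when no input tensor
    # requires gradients -- e.g. when the model's parameters have been
    # frozen with requires_grad_(False) for inference.
    needs_buffer = (
        residual is not None
        or (residual_dtype is not None and residual_dtype != x.dtype)
        or dropout_p > 0.0
    )
    any_grad = (
        x.requires_grad
        or (weight is not None and weight.requires_grad)
        or (bias is not None and bias.requires_grad)
        or (residual is not None and residual.requires_grad)
    )
    if needs_buffer and any_grad:
        return torch.empty_like(
            x, dtype=residual_dtype if residual_dtype is not None else x.dtype
        )
    return None  # caller can then skip writing residual_out entirely

Note that the check is on the tensors' requires_grad flags, not on whether grad mode is enabled, so the savings apply whenever every input to the norm is a non-trainable tensor.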