[LayerNorm] Don't store x + residual if we don't need gradients
parent 16025d8cc9
commit 800401847e
@@ -307,7 +307,14 @@ def _layer_norm_fwd(
         or dropout_p > 0.0
         or rowscale is not None
         or x1 is not None
-    ):
+    ) and (
+        x.requires_grad
+        or weight.requires_grad
+        or (bias is not None and bias.requires_grad)
+        or (residual is not None and residual.requires_grad)
+        or (x1 is not None and x1.requires_grad)
+        or (weight1 is not None and weight1.requires_grad)
+        or (bias1 is not None and bias1.requires_grad)):
         residual_out = torch.empty(
             M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
         )
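Effect of the change: residual_out, the buffer that stores x + residual for the backward pass, is now only allocated when at least one participating tensor requires gradients, so inference-only calls skip an extra M x N allocation. Below is a minimal, non-fused PyTorch sketch of the same idea; the function name layer_norm_residual_fwd and its signature are illustrative and not the library's fused Triton API.

import torch
import torch.nn.functional as F

def layer_norm_residual_fwd(x, residual, weight, bias, eps=1e-6):
    # Sketch only: the real kernel fuses the add, the normalization, and the
    # optional write-out of x + residual into a single Triton kernel.
    needs_grad = (
        x.requires_grad
        or weight.requires_grad
        or (bias is not None and bias.requires_grad)
        or (residual is not None and residual.requires_grad)
    )
    # The sum is always needed as the input to the normalization.
    z = x if residual is None else x + residual
    # Keep the pre-norm sum only if backward will need it; otherwise return
    # None so no extra (M, N) buffer outlives the forward call.
    residual_out = z if (residual is not None and needs_grad) else None
    out = F.layer_norm(z, (z.shape[-1],), weight, bias, eps)
    return out, residual_out

When none of the participating tensors require gradients, the sum is consumed by the normalization and never kept around, which is the saving the commit title describes.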