[LayerNorm] Don't store x + residual if we don't need gradients
commit 800401847e
parent 16025d8cc9
@@ -307,7 +307,14 @@ def _layer_norm_fwd(
         or dropout_p > 0.0
         or rowscale is not None
         or x1 is not None
-    ):
+    ) and (
+        x.requires_grad
+        or weight.requires_grad
+        or (bias is not None and bias.requires_grad)
+        or (residual is not None and residual.requires_grad)
+        or (x1 is not None and x1.requires_grad)
+        or (weight1 is not None and weight1.requires_grad)
+        or (bias1 is not None and bias1.requires_grad)):
         residual_out = torch.empty(
             M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
         )
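In effect, `residual_out` (the M x N buffer that holds `x + residual` in the residual dtype) is now only allocated when some input actually requires gradients. Below is a minimal sketch of the gating logic taken from the diff; the standalone function and its parameter list are illustrative stand-ins for the allocation step inside `_layer_norm_fwd`, not the library's API.

import torch

def maybe_alloc_residual_out(x, weight, bias, residual, x1, weight1, bias1,
                             residual_dtype, dropout_p, rowscale, M, N):
    """Allocate the buffer that stores x + residual only when the backward
    pass will need it, i.e. when some input requires gradients."""
    needs_residual_out = (
        residual is not None
        or (residual_dtype is not None and residual_dtype != x.dtype)
        or dropout_p > 0.0
        or rowscale is not None
        or x1 is not None
    ) and (
        x.requires_grad
        or weight.requires_grad
        or (bias is not None and bias.requires_grad)
        or (residual is not None and residual.requires_grad)
        or (x1 is not None and x1.requires_grad)
        or (weight1 is not None and weight1.requires_grad)
        or (bias1 is not None and bias1.requires_grad)
    )
    if needs_residual_out:
        return torch.empty(
            M, N, device=x.device,
            dtype=residual_dtype if residual_dtype is not None else x.dtype,
        )
    return None  # no gradients needed: skip the M x N allocation

Under `torch.no_grad()` or `torch.inference_mode()`, every `requires_grad` check is False, so the second clause short-circuits and the extra buffer is never allocated.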