Revert "[LayerNorm] Don't store x + residual if we don't need gradients"
This reverts commit 800401847e.
commit bd82d6c6eb (parent 800401847e)
@@ -307,14 +307,7 @@ def _layer_norm_fwd(
         or dropout_p > 0.0
         or rowscale is not None
         or x1 is not None
-    ) and (
-        x.requires_grad
-        or weight.requires_grad
-        or (bias is not None and bias.requires_grad)
-        or (residual is not None and residual.requires_grad)
-        or (x1 is not None and x1.requires_grad)
-        or (weight1 is not None and weight1.requires_grad)
-        or (bias1 is not None and bias1.requires_grad)):
+    ):
         residual_out = torch.empty(
             M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
         )
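For context, after this revert the allocation of residual_out in _layer_norm_fwd once again depends only on the input configuration (residual handling, dropout, rowscale, second input), not on whether any tensor requires gradients. Below is a minimal sketch of the resulting logic; the first two conditions inside the if and the helper name _sketch_residual_out_allocation are assumptions based on the surrounding code, not part of this diff.

import torch

def _sketch_residual_out_allocation(x, residual, residual_dtype, dropout_p, rowscale, x1, M, N):
    # Sketch only: after the revert, residual_out is materialized whenever the
    # kernel has to write out the pre-norm sum (residual present, dtype
    # conversion, dropout, rowscale, or a second input x1), regardless of any
    # requires_grad flags. The residual / residual_dtype checks below are
    # assumed from the surrounding code, since they lie above the hunk.
    if (
        residual is not None
        or (residual_dtype is not None and residual_dtype != x.dtype)
        or dropout_p > 0.0
        or rowscale is not None
        or x1 is not None
    ):
        residual_out = torch.empty(
            M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
        )
    else:
        residual_out = None
    return residual_out

With the reverted check in place, this allocation was skipped when no input required gradients; the revert restores unconditional allocation in those cases.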