Revert "[LayerNorm] Don't store x + residual if we don't need gradients"
This reverts commit 800401847e.
parent 800401847e
commit bd82d6c6eb
@@ -307,14 +307,7 @@ def _layer_norm_fwd(
         or dropout_p > 0.0
         or rowscale is not None
         or x1 is not None
-    ) and (
-        x.requires_grad
-        or weight.requires_grad
-        or (bias is not None and bias.requires_grad)
-        or (residual is not None and residual.requires_grad)
-        or (x1 is not None and x1.requires_grad)
-        or (weight1 is not None and weight1.requires_grad)
-        or (bias1 is not None and bias1.requires_grad)):
+    ):
         residual_out = torch.empty(
             M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
         )
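For context, a minimal sketch of the allocation behaviour this revert restores in _layer_norm_fwd. The helper name needs_residual_out is made up here for illustration and is not part of the library; the point is that, after the revert, residual_out is allocated whenever any residual-producing feature is active, regardless of whether any input requires gradients.

import torch


def needs_residual_out(x, residual, residual_dtype, dropout_p, rowscale, x1):
    # Post-revert condition: allocate the x + residual buffer whenever a
    # residual path exists, independent of any requires_grad flags.
    return (
        residual is not None
        or (residual_dtype is not None and residual_dtype != x.dtype)
        or dropout_p > 0.0
        or rowscale is not None
        or x1 is not None
    )


# Example: in inference nothing requires gradients, but the buffer is still
# allocated because a residual is supplied (this is what the reverted commit
# had skipped).
x = torch.randn(4, 16)
residual = torch.randn(4, 16)
if needs_residual_out(x, residual, None, 0.0, None, None):
    residual_out = torch.empty(x.shape[0], x.shape[1], device=x.device, dtype=x.dtype)
else:
    residual_out = None
print(residual_out is not None)  # True under the restored condition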
||||
Loading…
Reference in New Issue
Block a user