[LayerNorm] Initialize mean and rstd tensor using x.device
This commit is contained in:
parent
99ea4baa1d
commit
c9861a032d
@@ -314,8 +314,8 @@ def _layer_norm_fwd(
|
|||||||
assert residual_out.stride(-1) == 1
|
assert residual_out.stride(-1) == 1
|
||||||
else:
|
else:
|
||||||
residual_out = None
|
residual_out = None
|
||||||
mean = torch.empty((M,), dtype=torch.float32, device="cuda") if not is_rms_norm else None
|
mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
|
||||||
rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
|
rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
|
||||||
if dropout_p > 0.0:
|
if dropout_p > 0.0:
|
||||||
seeds = torch.randint(
|
seeds = torch.randint(
|
||||||
2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64
|
2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user