[LayerNorm] Initialize the mean and rstd tensors on x.device instead of the hard-coded "cuda" device
This commit is contained in:
parent
99ea4baa1d
commit
c9861a032d
@ -314,8 +314,8 @@ def _layer_norm_fwd(
|
||||
assert residual_out.stride(-1) == 1
|
||||
else:
|
||||
residual_out = None
|
||||
mean = torch.empty((M,), dtype=torch.float32, device="cuda") if not is_rms_norm else None
|
||||
rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
|
||||
mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
|
||||
rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
|
||||
if dropout_p > 0.0:
|
||||
seeds = torch.randint(
|
||||
2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64
|
||||
|
||||
Loading…
Reference in New Issue
Block a user