From c9861a032d6a7d044eea33b09b62f9a5d3ae7266 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 9 Jan 2024 16:28:51 -0800 Subject: [PATCH] [LayerNorm] Initialize mean and rstd tensor using x.device --- flash_attn/ops/triton/layer_norm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flash_attn/ops/triton/layer_norm.py b/flash_attn/ops/triton/layer_norm.py index c96fd3b..fcc3e20 100644 --- a/flash_attn/ops/triton/layer_norm.py +++ b/flash_attn/ops/triton/layer_norm.py @@ -314,8 +314,8 @@ def _layer_norm_fwd( assert residual_out.stride(-1) == 1 else: residual_out = None - mean = torch.empty((M,), dtype=torch.float32, device="cuda") if not is_rms_norm else None - rstd = torch.empty((M,), dtype=torch.float32, device="cuda") + mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None + rstd = torch.empty((M,), dtype=torch.float32, device=x.device) if dropout_p > 0.0: seeds = torch.randint( 2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64