diff --git a/flash_attn/ops/triton/layer_norm.py b/flash_attn/ops/triton/layer_norm.py
index c96fd3b..fcc3e20 100644
--- a/flash_attn/ops/triton/layer_norm.py
+++ b/flash_attn/ops/triton/layer_norm.py
@@ -314,8 +314,8 @@ def _layer_norm_fwd(
         assert residual_out.stride(-1) == 1
     else:
         residual_out = None
-    mean = torch.empty((M,), dtype=torch.float32, device="cuda") if not is_rms_norm else None
-    rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
+    mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
+    rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
     if dropout_p > 0.0:
         seeds = torch.randint(
             2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64