diff --git a/flash_attn/ops/triton/layer_norm.py b/flash_attn/ops/triton/layer_norm.py index fcc3e20..c922906 100644 --- a/flash_attn/ops/triton/layer_norm.py +++ b/flash_attn/ops/triton/layer_norm.py @@ -452,8 +452,7 @@ def _layer_norm_bwd_kernel( # Map the program id to the elements of X, DX, and DY it should compute. row_block_id = tl.program_id(0) row_start = row_block_id * rows_per_program - if row_start >= M: - return + # Do not early exit if row_start >= M, because we need to write DW and DB cols = tl.arange(0, BLOCK_N) mask = cols < N X += row_start * stride_x_row