From bdcae547c775f95c4aa890f7d69c49f3c5e51983 Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Mon, 22 Jan 2024 22:40:06 -0800
Subject: [PATCH] [LayerNorm] Don't exit early in the backward pass (fix #781)

---
 flash_attn/ops/triton/layer_norm.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/flash_attn/ops/triton/layer_norm.py b/flash_attn/ops/triton/layer_norm.py
index fcc3e20..c922906 100644
--- a/flash_attn/ops/triton/layer_norm.py
+++ b/flash_attn/ops/triton/layer_norm.py
@@ -452,8 +452,7 @@ def _layer_norm_bwd_kernel(
     # Map the program id to the elements of X, DX, and DY it should compute.
     row_block_id = tl.program_id(0)
     row_start = row_block_id * rows_per_program
-    if row_start >= M:
-        return
+    # Do not early exit if row_start >= M, because we need to write DW and DB
     cols = tl.arange(0, BLOCK_N)
     mask = cols < N
     X += row_start * stride_x_row