[LayerNorm] Don't exit early in the backward pass (fix #781)
This commit is contained in:
parent
36bc29edf7
commit
bdcae547c7
@ -452,8 +452,7 @@ def _layer_norm_bwd_kernel(
|
||||
# Map the program id to the elements of X, DX, and DY it should compute.
|
||||
row_block_id = tl.program_id(0)
|
||||
row_start = row_block_id * rows_per_program
|
||||
if row_start >= M:
|
||||
return
|
||||
# Do not early exit if row_start >= M, because we need to write DW and DB
|
||||
cols = tl.arange(0, BLOCK_N)
|
||||
mask = cols < N
|
||||
X += row_start * stride_x_row
|
||||
|
||||
Loading…
Reference in New Issue
Block a user