diff --git a/flash_attn/ops/triton/cross_entropy.py b/flash_attn/ops/triton/cross_entropy.py
index 4d17f2d..e8b5d8a 100644
--- a/flash_attn/ops/triton/cross_entropy.py
+++ b/flash_attn/ops/triton/cross_entropy.py
@@ -43,7 +43,7 @@ def cross_entropy_fwd_kernel(
 ):
     row_idx = tl.program_id(0)
     col_block_idx = tl.program_id(1)
-    logits_ptr = logits_ptr + row_idx * logits_row_stride
+    logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
     col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
     label_idx = tl.load(labels_ptr + row_idx)
     logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
@@ -107,8 +107,8 @@ def cross_entropy_bwd_kernel(
 ):
     row_idx = tl.program_id(0)
     col_block_idx = tl.program_id(1)
-    logits_ptr = logits_ptr + row_idx * logits_row_stride
-    dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride
+    logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
+    dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride.to(tl.int64)
     col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
     label_idx = tl.load(labels_ptr + row_idx)
     if label_idx != ignored_index: