diff --git a/flash_attn/ops/triton/cross_entropy.py b/flash_attn/ops/triton/cross_entropy.py
index 8b01257..7b0315b 100644
--- a/flash_attn/ops/triton/cross_entropy.py
+++ b/flash_attn/ops/triton/cross_entropy.py
@@ -3,6 +3,7 @@
 from typing import Tuple, Optional, Union
 
 import torch
+import torch.nn.functional as F
 
 import triton
 import triton.language as tl
@@ -160,6 +161,11 @@ class CrossEntropyLoss(torch.autograd.Function):
         inplace_backward=False,
         process_group=None,
     ):
+        # For some reason Triton generates wrong code when labels has dtype long and its address
+        # is not aligned to 16 bytes. The ld.global.b64 seems to load the wrong label index.
+        if labels.dtype == torch.long and labels.data_ptr() % 16 != 0:
+            labels = F.pad(labels, (0, 1))[..., :-1]
+            assert labels.data_ptr() % 16 == 0
         n_rows, n_cols = logits.shape
         assert labels.shape == (n_rows,)
         world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)
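
Note: as a sanity check of the workaround in the hunk above, the following minimal sketch illustrates the alignment issue and the fix. The tensor names (base, labels, aligned) are illustrative only, and a CUDA device is assumed; the F.pad(...)[..., :-1] round trip copies the labels into a freshly allocated tensor, which the caching allocator returns at an address that is in practice at least 16-byte aligned, without changing values or shape.

import torch
import torch.nn.functional as F

# Slicing an odd element offset out of a larger int64 tensor misaligns the
# storage pointer: each element is 8 bytes, so the address is 8 mod 16.
base = torch.arange(9, dtype=torch.long, device="cuda")
labels = base[1:]
print(labels.data_ptr() % 16)   # 8 -> not 16-byte aligned

# Padding allocates a new tensor (aligned start address); slicing off the
# padded element afterwards leaves the original values and shape intact.
aligned = F.pad(labels, (0, 1))[..., :-1]
print(aligned.data_ptr() % 16)  # 0 -> 16-byte aligned
assert torch.equal(aligned, labels)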