[CrossEntropy] Fix the case where the labels address is not aligned to 16 bytes

This commit is contained in:
Tri Dao 2024-10-05 02:02:24 -07:00
parent 53a4f34163
commit bedf877467

View File

@@ -3,6 +3,7 @@
from typing import Tuple, Optional, Union
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
@@ -160,6 +161,11 @@ class CrossEntropyLoss(torch.autograd.Function):
inplace_backward=False,
process_group=None,
):
# For some reason Triton generates wrong code when labels has dtype long and its address
# is not aligned to 16 bytes. The ld.global.b64 seems to load the wrong label index.
if labels.dtype == torch.long and labels.data_ptr() % 16 != 0:
labels = F.pad(labels, (0, 1))[..., :-1]
assert labels.data_ptr() % 16 == 0
n_rows, n_cols = logits.shape
assert labels.shape == (n_rows,)
world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)