[CrossEntropy] Fix where labels address not aligned to 16 bytes
parent 53a4f34163
commit bedf877467
@@ -3,6 +3,7 @@
 from typing import Tuple, Optional, Union
 
 import torch
+import torch.nn.functional as F
 
 import triton
 import triton.language as tl
@@ -160,6 +161,11 @@ class CrossEntropyLoss(torch.autograd.Function):
         inplace_backward=False,
         process_group=None,
     ):
+        # For some reason Triton generates wrong code when labels has dtype long and its address
+        # is not aligned to 16 bytes. The ld.global.b64 seems to load the wrong label index.
+        if labels.dtype == torch.long and labels.data_ptr() % 16 != 0:
+            labels = F.pad(labels, (0, 1))[..., :-1]
+            assert labels.data_ptr() % 16 == 0
         n_rows, n_cols = logits.shape
         assert labels.shape == (n_rows,)
         world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)
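To make the workaround concrete, here is a minimal, self-contained sketch (not part of the commit) showing how a misaligned int64 tensor can arise from a view, and how the pad-and-slice line restores alignment. The tensor names are illustrative; the alignment behavior assumes PyTorch's default allocators, whose fresh allocations are aligned to at least 16 bytes.

import torch
import torch.nn.functional as F

# Build a deliberately misaligned int64 view: slicing one element off the
# front shifts the data pointer by 8 bytes, so it is no longer a multiple
# of 16 (the fresh allocation's base itself is at least 16-byte aligned).
storage = torch.arange(9, dtype=torch.long)
labels = storage[1:]
assert labels.data_ptr() % 16 != 0

# The commit's workaround: pad one element onto the last dim, which forces
# a new contiguous allocation, then slice the padding back off. The values
# are unchanged, but the data now lives at an aligned base address.
if labels.dtype == torch.long and labels.data_ptr() % 16 != 0:
    labels = F.pad(labels, (0, 1))[..., :-1]
assert labels.data_ptr() % 16 == 0

Note that labels.contiguous() would not help here, since the sliced view is already contiguous and would be returned unchanged; a plain labels.clone() would also yield an aligned copy, and the pad-and-slice form is simply one way to force that copy.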