[CrossEntropy] Change ignored_index -> ignore_index

Tri Dao 2024-04-26 10:50:41 -07:00
parent 85881f547f
commit ec6d22143b
2 changed files with 15 additions and 15 deletions

View File

@@ -20,7 +20,7 @@ class CrossEntropyLoss(nn.Module):
):
"""
Arguments:
-ignored_index: int. If labels == ignored_index, the loss is set to 0.0.
+ignore_index: int. If labels == ignore_index, the loss is set to 0.0.
label_smoothing: float
lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
This is also referred to as "z-loss".
@@ -60,7 +60,7 @@ class CrossEntropyLoss(nn.Module):
label_smoothing=self.label_smoothing,
logit_scale=self.logit_scale,
lse_square_scale=self.lse_square_scale,
-ignored_index=self.ignore_index,
+ignore_index=self.ignore_index,
inplace_backward=self.inplace_backward,
process_group=self.process_group,
)
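In this first file, the change is purely the keyword rename at the call boundary: the module now stores and forwards ignore_index, the same name torch.nn.CrossEntropyLoss uses, so labels equal to it contribute 0.0 to the loss. A minimal usage sketch (not part of the commit; the import path and CUDA placement are assumptions based on the surrounding code):

import torch
from flash_attn.losses.cross_entropy import CrossEntropyLoss  # assumed import path

# The Triton kernel runs on GPU, so inputs are placed on a CUDA device.
logits = torch.randn(4, 1024, device="cuda", dtype=torch.float16, requires_grad=True)
labels = torch.randint(0, 1024, (4,), device="cuda")
labels[0] = -100  # rows whose label equals ignore_index contribute 0.0 to the loss

loss_fn = CrossEntropyLoss(ignore_index=-100)  # keyword renamed by this commit
loss = loss_fn(logits, labels)  # mean over non-ignored rows by default
loss.backward()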

View File

@@ -32,7 +32,7 @@ def cross_entropy_fwd_kernel(
smoothing,
logit_scale,
lse_square_scale,
-ignored_index,
+ignore_index,
total_classes,
class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes
n_cols, # shapes
@@ -56,7 +56,7 @@ def cross_entropy_fwd_kernel(
sum_logits = tl.sum(tl.where(col_offsets < n_cols, logits, 0.0), 0)
lse = tl.log(tl.sum(tl.exp(logits - max_logits), 0)) + max_logits
tl.store(lse_ptr + col_block_idx * n_rows + row_idx, lse)
-if label_idx == ignored_index:
+if label_idx == ignore_index:
loss = 0.0
z_loss = 0.0
else:
@@ -104,7 +104,7 @@ def cross_entropy_bwd_kernel(
smoothing,
logit_scale,
lse_square_scale,
-ignored_index,
+ignore_index,
total_classes,
class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes
n_cols, # shapes
@@ -120,7 +120,7 @@ def cross_entropy_bwd_kernel(
dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride.to(tl.int64)
col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
label_idx = tl.load(labels_ptr + row_idx)
-if label_idx != ignored_index:
+if label_idx != ignore_index:
dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride)
else:
dloss = 0.0
@@ -150,7 +150,7 @@ class CrossEntropyLoss(torch.autograd.Function):
smoothing=0.0,
logit_scale=1.0,
lse_square_scale=0.0,
-ignored_index=-100,
+ignore_index=-100,
inplace_backward=False,
process_group=None,
):
@@ -192,7 +192,7 @@ class CrossEntropyLoss(torch.autograd.Function):
smoothing,
logit_scale,
lse_square_scale,
-ignored_index,
+ignore_index,
total_classes,
class_start_idx,
n_cols, # shapes
@@ -229,18 +229,18 @@ class CrossEntropyLoss(torch.autograd.Function):
losses += lse
if lse_square_scale != 0.0:
z_losses = lse_square_scale * lse.square()
-z_losses.masked_fill_(labels == ignored_index, 0.0)
+z_losses.masked_fill_(labels == ignore_index, 0.0)
losses += z_losses
else:
z_losses = torch.zeros_like(losses)
-losses.masked_fill_(labels == ignored_index, 0.0)
+losses.masked_fill_(labels == ignore_index, 0.0)
ctx.save_for_backward(logits, lse, labels)
ctx.mark_non_differentiable(z_losses)
ctx.smoothing = smoothing
ctx.logit_scale = logit_scale
ctx.lse_square_scale = lse_square_scale
-ctx.ignored_index = ignored_index
+ctx.ignore_index = ignore_index
ctx.total_classes = total_classes
ctx.class_start_idx = class_start_idx
ctx.inplace_backward = inplace_backward
@@ -269,7 +269,7 @@ class CrossEntropyLoss(torch.autograd.Function):
ctx.smoothing,
ctx.logit_scale,
ctx.lse_square_scale,
-ctx.ignored_index,
+ctx.ignore_index,
ctx.total_classes,
ctx.class_start_idx,
n_cols, # shapes
@@ -287,7 +287,7 @@ def cross_entropy_loss(
label_smoothing: float = 0.0,
logit_scale: float = 1.0,
lse_square_scale: float = 0.0,
-ignored_index=-100,
+ignore_index=-100,
inplace_backward: bool = False,
process_group=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -299,7 +299,7 @@ def cross_entropy_loss(
logit_scale: float. Multiply logits by this scale before calculating the loss.
lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
This is also referred to as "z-loss".
-ignored_index: int. If labels == ignored_index, the loss is set to 0.0.
+ignore_index: int. If labels == ignore_index, the loss is set to 0.0.
inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
This saves memory.
process_group: if not None, we're doing Tensor Parallel: each process is responsible for
@@ -314,7 +314,7 @@ def cross_entropy_loss(
label_smoothing,
logit_scale,
lse_square_scale,
-ignored_index,
+ignore_index,
inplace_backward,
process_group,
)
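The semantics the renamed argument controls are visible across these hunks: the per-row loss is lse - logits[label] (plus an optional z-loss term lse_square_scale * lse^2), and any row whose label equals ignore_index is zeroed in both the forward losses and the backward dloss. A plain-PyTorch reference sketch of that behavior, not the kernel itself (assumptions: no label smoothing, logit_scale=1.0, single process; the function name is illustrative):

import torch

def cross_entropy_reference(logits, labels, ignore_index=-100, lse_square_scale=0.0):
    # Per-row log-sum-exp, computed in fp32 for numerical stability.
    lse = torch.logsumexp(logits.float(), dim=-1)
    # Clamp ignored labels to a valid index so gather() stays in bounds.
    safe = labels.masked_fill(labels == ignore_index, 0)
    picked = logits.float().gather(-1, safe.unsqueeze(-1)).squeeze(-1)
    losses = lse - picked  # standard cross entropy per row
    losses = losses + lse_square_scale * lse.square()  # optional z-loss term
    # Rows whose label == ignore_index are set to 0.0, matching the kernel.
    return losses.masked_fill(labels == ignore_index, 0.0)

Note that, per the Tuple[torch.Tensor, torch.Tensor] annotation above, the Triton-backed cross_entropy_loss returns per-row (losses, z_losses) rather than a reduced scalar; the sketch mirrors only the combined per-row losses.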