[CrossEntropy] Change ignored_index -> ignore_index
parent 85881f547f
commit ec6d22143b
@@ -20,7 +20,7 @@ class CrossEntropyLoss(nn.Module):
     ):
         """
         Arguments:
-            ignored_index: int. If labels == ignored_index, the loss is set to 0.0.
+            ignore_index: int. If labels == ignore_index, the loss is set to 0.0.
             label_smoothing: float
             lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
                 This is also referred to as "z-loss".
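Note: to make the rename concrete, here is a minimal usage sketch of the module-level wrapper after this change. The import path is an assumption (this diff does not show it), and the masking behavior follows the docstring above.

import torch
# Assumed import path; not shown in this diff.
from flash_attn.losses.cross_entropy import CrossEntropyLoss

loss_fn = CrossEntropyLoss(ignore_index=-100)  # renamed keyword
logits = torch.randn(4, 10, device="cuda", requires_grad=True)
labels = torch.tensor([1, -100, 3, -100], device="cuda")
# Rows where labels == ignore_index contribute 0.0 to the loss,
# mirroring torch.nn.CrossEntropyLoss(ignore_index=...) semantics.
loss = loss_fn(logits, labels)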
@@ -60,7 +60,7 @@ class CrossEntropyLoss(nn.Module):
             label_smoothing=self.label_smoothing,
             logit_scale=self.logit_scale,
             lse_square_scale=self.lse_square_scale,
-            ignored_index=self.ignore_index,
+            ignore_index=self.ignore_index,
             inplace_backward=self.inplace_backward,
             process_group=self.process_group,
         )
@@ -32,7 +32,7 @@ def cross_entropy_fwd_kernel(
     smoothing,
     logit_scale,
     lse_square_scale,
-    ignored_index,
+    ignore_index,
     total_classes,
     class_start_idx,  # Useful for tensor parallel when each rank only has a subset of classes
     n_cols,  # shapes
@@ -56,7 +56,7 @@ def cross_entropy_fwd_kernel(
         sum_logits = tl.sum(tl.where(col_offsets < n_cols, logits, 0.0), 0)
     lse = tl.log(tl.sum(tl.exp(logits - max_logits), 0)) + max_logits
     tl.store(lse_ptr + col_block_idx * n_rows + row_idx, lse)
-    if label_idx == ignored_index:
+    if label_idx == ignore_index:
         loss = 0.0
         z_loss = 0.0
     else:
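Note: the lse line in this hunk uses the standard max-subtraction trick for a numerically stable log-sum-exp. A plain PyTorch sketch of the same computation, for reference only (not part of the kernel):

import torch

def reference_lse(logits: torch.Tensor) -> torch.Tensor:
    # Same trick as the kernel: lse = log(sum(exp(logits - max))) + max
    max_logits = logits.max(dim=-1, keepdim=True).values
    return torch.log(torch.exp(logits - max_logits).sum(dim=-1)) + max_logits.squeeze(-1)

# Sanity check against the built-in:
x = torch.randn(3, 7)
assert torch.allclose(reference_lse(x), torch.logsumexp(x, dim=-1))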
@@ -104,7 +104,7 @@ def cross_entropy_bwd_kernel(
     smoothing,
     logit_scale,
     lse_square_scale,
-    ignored_index,
+    ignore_index,
     total_classes,
     class_start_idx,  # Useful for tensor parallel when each rank only has a subset of classes
     n_cols,  # shapes
@@ -120,7 +120,7 @@ def cross_entropy_bwd_kernel(
     dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride.to(tl.int64)
     col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
     label_idx = tl.load(labels_ptr + row_idx)
-    if label_idx != ignored_index:
+    if label_idx != ignore_index:
         dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride)
     else:
         dloss = 0.0
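Note: the renamed check above gates the upstream gradient per row: ignored rows get dloss = 0.0, so their dlogits vanish. A vectorized PyTorch equivalent of that branch, with illustrative values:

import torch

dloss = torch.randn(4)                    # upstream gradient per row
labels = torch.tensor([1, -100, 3, -100])  # per-row targets
ignore_index = -100

# Mirror of the kernel's per-row branch: keep dloss only when the
# label is not ignored, otherwise substitute 0.0.
dloss = torch.where(labels != ignore_index, dloss, torch.zeros_like(dloss))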
@@ -150,7 +150,7 @@ class CrossEntropyLoss(torch.autograd.Function):
         smoothing=0.0,
         logit_scale=1.0,
         lse_square_scale=0.0,
-        ignored_index=-100,
+        ignore_index=-100,
         inplace_backward=False,
         process_group=None,
     ):
@@ -192,7 +192,7 @@ class CrossEntropyLoss(torch.autograd.Function):
             smoothing,
             logit_scale,
             lse_square_scale,
-            ignored_index,
+            ignore_index,
             total_classes,
             class_start_idx,
             n_cols,  # shapes
@@ -229,18 +229,18 @@ class CrossEntropyLoss(torch.autograd.Function):
         losses += lse
         if lse_square_scale != 0.0:
             z_losses = lse_square_scale * lse.square()
-            z_losses.masked_fill_(labels == ignored_index, 0.0)
+            z_losses.masked_fill_(labels == ignore_index, 0.0)
             losses += z_losses
         else:
             z_losses = torch.zeros_like(losses)
-        losses.masked_fill_(labels == ignored_index, 0.0)
+        losses.masked_fill_(labels == ignore_index, 0.0)
 
         ctx.save_for_backward(logits, lse, labels)
         ctx.mark_non_differentiable(z_losses)
         ctx.smoothing = smoothing
         ctx.logit_scale = logit_scale
         ctx.lse_square_scale = lse_square_scale
-        ctx.ignored_index = ignored_index
+        ctx.ignore_index = ignore_index
         ctx.total_classes = total_classes
         ctx.class_start_idx = class_start_idx
         ctx.inplace_backward = inplace_backward
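Note: this hunk is the eager-side z-loss path: z_losses = lse_square_scale * lse^2, zeroed wherever labels == ignore_index, then folded into the per-row losses. A small self-contained sketch of the same masking, with made-up values:

import torch

# Illustrative per-row values; in the real code lse comes from the kernel.
lse = torch.tensor([2.0, 3.0, 1.5, 2.5])
losses = torch.tensor([0.7, 0.1, 0.4, 0.9])
labels = torch.tensor([1, -100, 3, 2])
lse_square_scale, ignore_index = 1e-4, -100

# z-loss term, zeroed at ignored positions, then added to the losses;
# finally the losses themselves are zeroed at ignored positions.
z_losses = lse_square_scale * lse.square()
z_losses.masked_fill_(labels == ignore_index, 0.0)
losses = losses + z_losses
losses.masked_fill_(labels == ignore_index, 0.0)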
@@ -269,7 +269,7 @@ class CrossEntropyLoss(torch.autograd.Function):
             ctx.smoothing,
             ctx.logit_scale,
             ctx.lse_square_scale,
-            ctx.ignored_index,
+            ctx.ignore_index,
             ctx.total_classes,
             ctx.class_start_idx,
             n_cols,  # shapes
@@ -287,7 +287,7 @@ def cross_entropy_loss(
     label_smoothing: float = 0.0,
     logit_scale: float = 1.0,
     lse_square_scale: float = 0.0,
-    ignored_index=-100,
+    ignore_index=-100,
     inplace_backward: bool = False,
     process_group=None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -299,7 +299,7 @@ def cross_entropy_loss(
         logit_scale: float. Multiply logits by this scale before calculating the loss.
         lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
             This is also referred to as "z-loss".
-        ignored_index: int. If labels == ignored_index, the loss is set to 0.0.
+        ignore_index: int. If labels == ignore_index, the loss is set to 0.0.
         inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
             This saves memory.
         process_group: if not None, we're doing Tensor Parallel: each process is responsible for
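Note: a hedged sketch of calling the functional entry point with the renamed keyword. The import path is an assumption, and per the signature above the function returns a (losses, z_losses) pair:

import torch
# Assumed import path; not shown in this diff.
from flash_attn.ops.triton.cross_entropy import cross_entropy_loss

logits = torch.randn(8, 32000, device="cuda", requires_grad=True)
labels = torch.randint(0, 32000, (8,), device="cuda")
labels[::2] = -100  # mark some positions as ignored

# Per-row cross-entropy and z-loss; both are 0.0 on ignored rows.
losses, z_losses = cross_entropy_loss(
    logits, labels, lse_square_scale=1e-4, ignore_index=-100
)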
@@ -314,7 +314,7 @@ def cross_entropy_loss(
         label_smoothing,
         logit_scale,
         lse_square_scale,
-        ignored_index,
+        ignore_index,
         inplace_backward,
         process_group,
     )