[CrossEntropy] Change ignored_index -> ignore_index
This commit is contained in:
parent
85881f547f
commit
ec6d22143b
@@ -20,7 +20,7 @@ class CrossEntropyLoss(nn.Module):
     ):
         """
         Arguments:
-            ignored_index: int. If labels == ignored_index, the loss is set to 0.0.
+            ignore_index: int. If labels == ignore_index, the loss is set to 0.0.
             label_smoothing: float
             lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
                 This is also referred to as "z-loss".
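The renamed argument keeps the semantics of torch.nn.CrossEntropyLoss's ignore_index: positions whose label equals ignore_index contribute 0.0 loss. A minimal plain-PyTorch sketch of that behavior (illustrative only, not this repo's Triton path):

    import torch
    import torch.nn.functional as F

    logits = torch.randn(4, 10)
    labels = torch.tensor([2, -100, 7, -100])  # -100 positions are ignored
    losses = F.cross_entropy(logits, labels, ignore_index=-100, reduction="none")
    assert (losses[labels == -100] == 0).all()  # ignored positions yield 0.0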
@@ -60,7 +60,7 @@ class CrossEntropyLoss(nn.Module):
             label_smoothing=self.label_smoothing,
             logit_scale=self.logit_scale,
             lse_square_scale=self.lse_square_scale,
-            ignored_index=self.ignore_index,
+            ignore_index=self.ignore_index,
             inplace_backward=self.inplace_backward,
             process_group=self.process_group,
         )
@@ -32,7 +32,7 @@ def cross_entropy_fwd_kernel(
     smoothing,
     logit_scale,
     lse_square_scale,
-    ignored_index,
+    ignore_index,
     total_classes,
     class_start_idx,  # Useful for tensor parallel when each rank only has a subset of classes
     n_cols,  # shapes
@@ -56,7 +56,7 @@ def cross_entropy_fwd_kernel(
         sum_logits = tl.sum(tl.where(col_offsets < n_cols, logits, 0.0), 0)
     lse = tl.log(tl.sum(tl.exp(logits - max_logits), 0)) + max_logits
     tl.store(lse_ptr + col_block_idx * n_rows + row_idx, lse)
-    if label_idx == ignored_index:
+    if label_idx == ignore_index:
         loss = 0.0
         z_loss = 0.0
     else:
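The lse line in this hunk is the standard max-shifted logsumexp that the z-loss term is built from. A plain-PyTorch equivalent of the per-row math (a sketch; the 1e-4 scale is an arbitrary example value):

    import torch

    logits = torch.randn(4, 10)
    max_logits = logits.max(dim=-1, keepdim=True).values
    # max-shifted logsumexp, numerically stable for large logits
    lse = (logits - max_logits).exp().sum(dim=-1).log() + max_logits.squeeze(-1)
    assert torch.allclose(lse, torch.logsumexp(logits, dim=-1))
    z_loss = 1e-4 * lse.square()  # the z-loss added when lse_square_scale > 0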
@@ -104,7 +104,7 @@ def cross_entropy_bwd_kernel(
     smoothing,
     logit_scale,
     lse_square_scale,
-    ignored_index,
+    ignore_index,
     total_classes,
     class_start_idx,  # Useful for tensor parallel when each rank only has a subset of classes
     n_cols,  # shapes
@@ -120,7 +120,7 @@ def cross_entropy_bwd_kernel(
     dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride.to(tl.int64)
     col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
     label_idx = tl.load(labels_ptr + row_idx)
-    if label_idx != ignored_index:
+    if label_idx != ignore_index:
         dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride)
     else:
         dloss = 0.0
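Zeroing dloss for ignored rows makes every dlogits entry in that row zero, which is how ignored positions stay out of the gradient. The same behavior can be checked against PyTorch's reference implementation (a sketch under the same -100 convention):

    import torch
    import torch.nn.functional as F

    logits = torch.randn(3, 5, requires_grad=True)
    labels = torch.tensor([1, -100, 4])
    F.cross_entropy(logits, labels, ignore_index=-100, reduction="sum").backward()
    assert logits.grad[1].abs().sum() == 0  # the ignored row gets zero gradient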
@@ -150,7 +150,7 @@ class CrossEntropyLoss(torch.autograd.Function):
         smoothing=0.0,
         logit_scale=1.0,
         lse_square_scale=0.0,
-        ignored_index=-100,
+        ignore_index=-100,
         inplace_backward=False,
         process_group=None,
     ):
@@ -192,7 +192,7 @@ class CrossEntropyLoss(torch.autograd.Function):
             smoothing,
             logit_scale,
             lse_square_scale,
-            ignored_index,
+            ignore_index,
             total_classes,
             class_start_idx,
             n_cols,  # shapes
@@ -229,18 +229,18 @@ class CrossEntropyLoss(torch.autograd.Function):
         losses += lse
         if lse_square_scale != 0.0:
             z_losses = lse_square_scale * lse.square()
-            z_losses.masked_fill_(labels == ignored_index, 0.0)
+            z_losses.masked_fill_(labels == ignore_index, 0.0)
             losses += z_losses
         else:
             z_losses = torch.zeros_like(losses)
-        losses.masked_fill_(labels == ignored_index, 0.0)
+        losses.masked_fill_(labels == ignore_index, 0.0)
 
         ctx.save_for_backward(logits, lse, labels)
         ctx.mark_non_differentiable(z_losses)
         ctx.smoothing = smoothing
         ctx.logit_scale = logit_scale
         ctx.lse_square_scale = lse_square_scale
-        ctx.ignored_index = ignored_index
+        ctx.ignore_index = ignore_index
         ctx.total_classes = total_classes
         ctx.class_start_idx = class_start_idx
         ctx.inplace_backward = inplace_backward
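masked_fill_ zeroes the loss in place wherever labels == ignore_index, e.g.:

    import torch

    losses = torch.tensor([1.3, 0.7, 2.1])
    labels = torch.tensor([5, -100, 2])
    losses.masked_fill_(labels == -100, 0.0)  # losses -> tensor([1.3, 0.0, 2.1])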
@@ -269,7 +269,7 @@ class CrossEntropyLoss(torch.autograd.Function):
             ctx.smoothing,
             ctx.logit_scale,
             ctx.lse_square_scale,
-            ctx.ignored_index,
+            ctx.ignore_index,
             ctx.total_classes,
             ctx.class_start_idx,
             n_cols,  # shapes
@@ -287,7 +287,7 @@ def cross_entropy_loss(
     label_smoothing: float = 0.0,
     logit_scale: float = 1.0,
     lse_square_scale: float = 0.0,
-    ignored_index=-100,
+    ignore_index=-100,
     inplace_backward: bool = False,
     process_group=None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -299,7 +299,7 @@ def cross_entropy_loss(
         logit_scale: float. Multiply logits by this scale before calculating the loss.
         lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
             This is also referred to as "z-loss".
-        ignored_index: int. If labels == ignored_index, the loss is set to 0.0.
+        ignore_index: int. If labels == ignore_index, the loss is set to 0.0.
         inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
             This saves memory.
         process_group: if not None, we're doing Tensor Parallel: each process is responsible for
@@ -314,7 +314,7 @@ def cross_entropy_loss(
         label_smoothing,
         logit_scale,
         lse_square_scale,
-        ignored_index,
+        ignore_index,
         inplace_backward,
         process_group,
     )
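After this change, callers pass ignore_index rather than ignored_index to the functional entry point. A hedged usage sketch; the import path is an assumption, and the return is the (losses, z_losses) pair from the Tuple[torch.Tensor, torch.Tensor] annotation above:

    import torch
    # Import path is an assumption; point it at wherever cross_entropy_loss lives in this repo.
    from flash_attn.ops.triton.cross_entropy import cross_entropy_loss

    logits = torch.randn(8, 1000, device="cuda")
    labels = torch.randint(0, 1000, (8,), device="cuda")
    labels[0] = -100  # this position contributes 0.0 to both returned tensors
    losses, z_losses = cross_entropy_loss(
        logits, labels, lse_square_scale=1e-4, ignore_index=-100
    )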