From ec6d22143b5d375e253b2ebfc563b26a43f43684 Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Fri, 26 Apr 2024 10:50:41 -0700
Subject: [PATCH] [CrossEntropy] Change ignored_index -> ignore_index

---
 flash_attn/losses/cross_entropy.py     |  4 ++--
 flash_attn/ops/triton/cross_entropy.py | 26 +++++++++++++-------------
 2 files changed, 15 insertions(+), 15 deletions(-)

diff --git a/flash_attn/losses/cross_entropy.py b/flash_attn/losses/cross_entropy.py
index 2a1b77a..2c5032c 100644
--- a/flash_attn/losses/cross_entropy.py
+++ b/flash_attn/losses/cross_entropy.py
@@ -20,7 +20,7 @@ class CrossEntropyLoss(nn.Module):
     ):
         """
         Arguments:
-            ignored_index: int. If labels == ignored_index, the loss is set to 0.0.
+            ignore_index: int. If labels == ignore_index, the loss is set to 0.0.
             label_smoothing: float
             lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
                 This is also referred to as "z-loss".
@@ -60,7 +60,7 @@ class CrossEntropyLoss(nn.Module):
             label_smoothing=self.label_smoothing,
             logit_scale=self.logit_scale,
             lse_square_scale=self.lse_square_scale,
-            ignored_index=self.ignore_index,
+            ignore_index=self.ignore_index,
             inplace_backward=self.inplace_backward,
             process_group=self.process_group,
         )
diff --git a/flash_attn/ops/triton/cross_entropy.py b/flash_attn/ops/triton/cross_entropy.py
index c8111ca..1f895d7 100644
--- a/flash_attn/ops/triton/cross_entropy.py
+++ b/flash_attn/ops/triton/cross_entropy.py
@@ -32,7 +32,7 @@ def cross_entropy_fwd_kernel(
     smoothing,
     logit_scale,
     lse_square_scale,
-    ignored_index,
+    ignore_index,
     total_classes,
     class_start_idx,  # Useful for tensor parallel when each rank only has a subset of classes
     n_cols,  # shapes
@@ -56,7 +56,7 @@ def cross_entropy_fwd_kernel(
         sum_logits = tl.sum(tl.where(col_offsets < n_cols, logits, 0.0), 0)
     lse = tl.log(tl.sum(tl.exp(logits - max_logits), 0)) + max_logits
     tl.store(lse_ptr + col_block_idx * n_rows + row_idx, lse)
-    if label_idx == ignored_index:
+    if label_idx == ignore_index:
         loss = 0.0
         z_loss = 0.0
     else:
@@ -104,7 +104,7 @@ def cross_entropy_bwd_kernel(
     smoothing,
     logit_scale,
     lse_square_scale,
-    ignored_index,
+    ignore_index,
     total_classes,
     class_start_idx,  # Useful for tensor parallel when each rank only has a subset of classes
     n_cols,  # shapes
@@ -120,7 +120,7 @@ def cross_entropy_bwd_kernel(
     dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride.to(tl.int64)
     col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
     label_idx = tl.load(labels_ptr + row_idx)
-    if label_idx != ignored_index:
+    if label_idx != ignore_index:
         dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride)
     else:
         dloss = 0.0
@@ -150,7 +150,7 @@ class CrossEntropyLoss(torch.autograd.Function):
         smoothing=0.0,
         logit_scale=1.0,
         lse_square_scale=0.0,
-        ignored_index=-100,
+        ignore_index=-100,
         inplace_backward=False,
         process_group=None,
     ):
@@ -192,7 +192,7 @@ class CrossEntropyLoss(torch.autograd.Function):
                 smoothing,
                 logit_scale,
                 lse_square_scale,
-                ignored_index,
+                ignore_index,
                 total_classes,
                 class_start_idx,
                 n_cols,  # shapes
@@ -229,18 +229,18 @@ class CrossEntropyLoss(torch.autograd.Function):
             losses += lse
             if lse_square_scale != 0.0:
                 z_losses = lse_square_scale * lse.square()
-                z_losses.masked_fill_(labels == ignored_index, 0.0)
+                z_losses.masked_fill_(labels == ignore_index, 0.0)
                 losses += z_losses
             else:
                 z_losses = torch.zeros_like(losses)
-            losses.masked_fill_(labels == ignored_index, 0.0)
+            losses.masked_fill_(labels == ignore_index, 0.0)
 
         ctx.save_for_backward(logits, lse, labels)
         ctx.mark_non_differentiable(z_losses)
         ctx.smoothing = smoothing
         ctx.logit_scale = logit_scale
         ctx.lse_square_scale = lse_square_scale
-        ctx.ignored_index = ignored_index
+        ctx.ignore_index = ignore_index
         ctx.total_classes = total_classes
         ctx.class_start_idx = class_start_idx
         ctx.inplace_backward = inplace_backward
@@ -269,7 +269,7 @@ class CrossEntropyLoss(torch.autograd.Function):
                 ctx.smoothing,
                 ctx.logit_scale,
                 ctx.lse_square_scale,
-                ctx.ignored_index,
+                ctx.ignore_index,
                 ctx.total_classes,
                 ctx.class_start_idx,
                 n_cols,  # shapes
@@ -287,7 +287,7 @@ def cross_entropy_loss(
     label_smoothing: float = 0.0,
     logit_scale: float = 1.0,
     lse_square_scale: float = 0.0,
-    ignored_index=-100,
+    ignore_index=-100,
     inplace_backward: bool = False,
     process_group=None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -299,7 +299,7 @@ def cross_entropy_loss(
         logit_scale: float. Multiply logits by this scale before calculating the loss.
         lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
             This is also referred to as "z-loss".
-        ignored_index: int. If labels == ignored_index, the loss is set to 0.0.
+        ignore_index: int. If labels == ignore_index, the loss is set to 0.0.
         inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
             This saves memory.
         process_group: if not None, we're doing Tensor Parallel: each process is responsible for
@@ -314,7 +314,7 @@ def cross_entropy_loss(
         label_smoothing,
         logit_scale,
         lse_square_scale,
-        ignored_index,
+        ignore_index,
         inplace_backward,
         process_group,
     )
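
For reference, a minimal usage sketch of the renamed keyword (not part of the patch). The import path, tensor shapes, and hyperparameter values below are illustrative assumptions; it presumes flash-attn with the Triton kernels installed and a CUDA device available.

# Minimal sketch (assumptions noted above) exercising ignore_index,
# formerly ignored_index, on the Triton cross-entropy loss.
import torch

from flash_attn.ops.triton.cross_entropy import cross_entropy_loss

batch, vocab_size = 4, 32000
logits = torch.randn(batch, vocab_size, device="cuda", requires_grad=True)
labels = torch.randint(0, vocab_size, (batch,), device="cuda")
labels[0] = -100  # rows whose label equals ignore_index contribute 0.0 to the loss

losses, z_losses = cross_entropy_loss(
    logits,
    labels,
    label_smoothing=0.0,
    lse_square_scale=1e-4,  # adds the "z-loss" term lse_square_scale * lse(logits) ^ 2
    ignore_index=-100,      # renamed argument; default value is unchanged
)
losses.sum().backward()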