From 08124c8f9cf88ba327e2c455ed2a302979f06c91 Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Sat, 16 Dec 2023 18:39:37 -0800
Subject: [PATCH] [CrossEntropy] Implement logit_scale option

---
 flash_attn/losses/cross_entropy.py          |  3 +++
 flash_attn/ops/triton/cross_entropy.py      | 20 +++++++++++++++-----
 tests/losses/test_cross_entropy.py          | 11 ++++++++---
 tests/losses/test_cross_entropy_parallel.py |  9 ++++++---
 4 files changed, 32 insertions(+), 11 deletions(-)

diff --git a/flash_attn/losses/cross_entropy.py b/flash_attn/losses/cross_entropy.py
index 93c8e96..e244047 100644
--- a/flash_attn/losses/cross_entropy.py
+++ b/flash_attn/losses/cross_entropy.py
@@ -12,6 +12,7 @@ class CrossEntropyLoss(nn.Module):
         ignore_index=-100,
         reduction="mean",
         label_smoothing=0.0,
+        logit_scale=1.0,
         lse_square_scale=0.0,
         inplace_backward=False,
         process_group=None,
@@ -33,6 +34,7 @@ class CrossEntropyLoss(nn.Module):
         self.ignore_index = ignore_index
         self.reduction = reduction
         self.label_smoothing = label_smoothing
+        self.logit_scale = logit_scale
         self.lse_square_scale = lse_square_scale
         self.inplace_backward = inplace_backward
         self.process_group = process_group
@@ -50,6 +52,7 @@ class CrossEntropyLoss(nn.Module):
             input,
             target,
             label_smoothing=self.label_smoothing,
+            logit_scale=self.logit_scale,
             lse_square_scale=self.lse_square_scale,
             ignored_index=self.ignore_index,
             inplace_backward=self.inplace_backward,
diff --git a/flash_attn/ops/triton/cross_entropy.py b/flash_attn/ops/triton/cross_entropy.py
index 21b099e..afa02a2 100644
--- a/flash_attn/ops/triton/cross_entropy.py
+++ b/flash_attn/ops/triton/cross_entropy.py
@@ -29,6 +29,7 @@ def cross_entropy_fwd_kernel(
     logits_ptr,
     labels_ptr,
     smoothing,
+    logit_scale,
     lse_square_scale,
     ignored_index,
     total_classes,
@@ -48,7 +49,7 @@
     label_idx = tl.load(labels_ptr + row_idx)
     logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
         tl.float32
-    )
+    ) * logit_scale
     max_logits = tl.max(logits, 0)
     if HAS_SMOOTHING:
         sum_logits = tl.sum(tl.where(col_offsets < n_cols, logits, 0.0), 0)
@@ -61,7 +62,7 @@
        if label_idx >= col_block_idx * BLOCK_SIZE and label_idx < min(
            n_cols, (col_block_idx + 1) * BLOCK_SIZE
        ):
-            logits_label = tl.load(logits_ptr + label_idx)
+            logits_label = tl.load(logits_ptr + label_idx) * logit_scale
             if HAS_SMOOTHING:
                 loss = (
                     (lse if not SPLIT else 0.0)
@@ -94,6 +95,7 @@ def cross_entropy_bwd_kernel(
     lse_ptr,
     labels_ptr,
     smoothing,
+    logit_scale,
     lse_square_scale,
     ignored_index,
     total_classes,
@@ -117,7 +119,7 @@
         dloss = 0.0
     logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
         tl.float32
-    )
+    ) * logit_scale
     lse = tl.load(lse_ptr + row_idx)
     probs = tl.exp(logits - lse)
     probs += 2.0 * lse_square_scale * lse * probs
@@ -128,16 +130,18 @@
         probs = tl.where(col_offsets == label_idx, probs - (1 - smoothing), probs) - smooth_negative
     else:
         probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)
-    tl.store(dlogits_ptr + col_offsets, dloss * probs, mask=col_offsets < n_cols)
+    tl.store(dlogits_ptr + col_offsets, (dloss * logit_scale) * probs, mask=col_offsets < n_cols)
 
 
 class CrossEntropyLoss(torch.autograd.Function):
+
     @staticmethod
     def forward(
         ctx,
         logits,
         labels,
-        smoothing,
+        smoothing=0.0,
+        logit_scale=1.0,
         lse_square_scale=0.0,
         ignored_index=-100,
         inplace_backward=False,
@@ -177,6 +181,7 @@ class CrossEntropyLoss(torch.autograd.Function):
                 logits,
                 labels,
                 smoothing,
+                logit_scale,
                 lse_square_scale,
                 ignored_index,
                 total_classes,
@@ -219,6 +224,7 @@ class CrossEntropyLoss(torch.autograd.Function):
 
         ctx.save_for_backward(logits, lse, labels)
         ctx.smoothing = smoothing
+        ctx.logit_scale = logit_scale
         ctx.lse_square_scale = lse_square_scale
         ctx.ignored_index = ignored_index
         ctx.total_classes = total_classes
@@ -244,6 +250,7 @@ class CrossEntropyLoss(torch.autograd.Function):
             lse,
             labels,
             ctx.smoothing,
+            ctx.logit_scale,
             ctx.lse_square_scale,
             ctx.ignored_index,
             ctx.total_classes,
@@ -262,6 +269,7 @@ def cross_entropy_loss(
     logits: torch.Tensor,
     labels: torch.Tensor,
     label_smoothing: float = 0.0,
+    logit_scale: float = 1.0,
     lse_square_scale: float = 0.0,
     ignored_index=-100,
     inplace_backward: bool = False,
@@ -272,6 +280,7 @@ def cross_entropy_loss(
         logits: (batch, vocab_size)
         labels: (batch,)
         label_smoothing: float
+        logit_scale: float. Multiply logits by this scale before calculating the loss.
         lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
             This is also referred to as "z-loss".
         ignored_index: int. If labels == ignored_index, the loss is set to 0.0.
@@ -286,6 +295,7 @@ def cross_entropy_loss(
         logits,
         labels,
         label_smoothing,
+        logit_scale,
         lse_square_scale,
         ignored_index,
         inplace_backward,
diff --git a/tests/losses/test_cross_entropy.py b/tests/losses/test_cross_entropy.py
index 21c43b0..4e154d9 100644
--- a/tests/losses/test_cross_entropy.py
+++ b/tests/losses/test_cross_entropy.py
@@ -17,11 +17,15 @@ is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
 # @pytest.mark.parametrize("inplace_backward", [False])
 @pytest.mark.parametrize("lse_square_scale", [0.0, 1e-2])
 # @pytest.mark.parametrize("lse_square_scale", [1e-2])
+@pytest.mark.parametrize("logit_scale", [1.0, 0.7])
+# @pytest.mark.parametrize("logit_scale", [1.0])
 @pytest.mark.parametrize("smoothing", [0.0, 0.9])
 # @pytest.mark.parametrize("smoothing", [0.0])
 @pytest.mark.parametrize("vocab_size", [50257, 128 * 1024])  # test vocab larger than 64k for split
 # @pytest.mark.parametrize("vocab_size", [12])
-def test_cross_entropy_loss(vocab_size, smoothing, lse_square_scale, inplace_backward, dtype):
+def test_cross_entropy_loss(
+    vocab_size, smoothing, logit_scale, lse_square_scale, inplace_backward, dtype
+):
     device = "cuda"
     rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)
     # set seed
@@ -38,13 +42,14 @@ def test_cross_entropy_loss(vocab_size, smoothing, lse_square_scale, inplace_bac
     model_pt = torch.nn.CrossEntropyLoss(label_smoothing=smoothing)
     model = CrossEntropyLoss(
         label_smoothing=smoothing,
+        logit_scale=logit_scale,
         lse_square_scale=lse_square_scale,
         inplace_backward=inplace_backward,
     )
     out = model(x, y)
-    out_pt = model_pt(x_pt.float(), y)
+    out_pt = model_pt(x_pt.float() * logit_scale, y)
     if lse_square_scale > 0.0:
-        lse_pt = torch.logsumexp(x_pt.float(), dim=-1)
+        lse_pt = torch.logsumexp(x_pt.float() * logit_scale, dim=-1)
         out_pt += lse_square_scale * (lse_pt[y != -100] ** 2).mean()
     assert torch.allclose(out, out_pt, rtol=1e-5, atol=1e-6)
diff --git a/tests/losses/test_cross_entropy_parallel.py b/tests/losses/test_cross_entropy_parallel.py
index d26f53a..c8b97fc 100644
--- a/tests/losses/test_cross_entropy_parallel.py
+++ b/tests/losses/test_cross_entropy_parallel.py
@@ -19,6 +19,8 @@ is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
 # @pytest.mark.parametrize("inplace_backward", [False])
 @pytest.mark.parametrize("lse_square_scale", [0.0, 1e-2])
 # @pytest.mark.parametrize("lse_square_scale", [0.0])
+@pytest.mark.parametrize("logit_scale", [0.7])
+# @pytest.mark.parametrize("logit_scale", [1.0])
 @pytest.mark.parametrize("smoothing", [0.0, 0.9])
 # @pytest.mark.parametrize("smoothing", [0.0])
 @pytest.mark.parametrize("vocab_size", [50264, 256 * 1024])  # test vocab larger than 64k for split
@@ -26,7 +28,7 @@ is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
 # @pytest.mark.parametrize("world_size", [1, 2])
 @pytest.mark.parametrize("world_size", [2])
 def test_cross_entropy_loss_parallel(
-    vocab_size, world_size, smoothing, lse_square_scale, inplace_backward, dtype
+    vocab_size, world_size, smoothing, logit_scale, lse_square_scale, inplace_backward, dtype
 ):
     assert vocab_size % world_size == 0
     rtol, atol = (
@@ -59,15 +61,16 @@ def test_cross_entropy_loss_parallel(
     model_pt = torch.nn.CrossEntropyLoss(label_smoothing=smoothing, reduction="none")
     model = CrossEntropyLoss(
         label_smoothing=smoothing,
+        logit_scale=logit_scale,
         reduction="none",
         lse_square_scale=lse_square_scale,
         inplace_backward=inplace_backward,
         process_group=parallel_state.get_tensor_model_parallel_group(),
     )
     out = model(x, y)
-    out_pt = model_pt(x_pt.float(), y)
+    out_pt = model_pt(x_pt.float() * logit_scale, y)
     if lse_square_scale > 0.0:
-        lse_pt = torch.logsumexp(x_pt.float(), dim=-1)
+        lse_pt = torch.logsumexp(x_pt.float() * logit_scale, dim=-1)
         out_pt += lse_square_scale * lse_pt.square()
     out_pt.masked_fill_(y == -100, 0.0)
     assert torch.allclose(out, out_pt, rtol=1e-5, atol=1e-6)
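
A note on the semantics (not part of the patch): logit_scale multiplies the logits before the softmax, so the fused kernel should agree with plain PyTorch cross entropy applied to pre-scaled logits, and by the chain rule the backward pass scales the gradient by logit_scale as well (the `(dloss * logit_scale) * probs` store above). A minimal unfused reference mirroring what the updated tests check; the helper name reference_cross_entropy is illustrative, not from the patch:

    # Unfused PyTorch reference for the logit_scale semantics exercised by
    # tests/losses/test_cross_entropy.py (hypothetical helper, not in the patch).
    import torch
    import torch.nn.functional as F

    def reference_cross_entropy(logits, labels, logit_scale=1.0, lse_square_scale=0.0,
                                label_smoothing=0.0, ignore_index=-100):
        # Scale the logits before computing the loss, as the Triton kernels do.
        scaled = logits.float() * logit_scale
        loss = F.cross_entropy(
            scaled, labels, label_smoothing=label_smoothing, ignore_index=ignore_index
        )
        if lse_square_scale > 0.0:
            # z-loss term: lse_square_scale * logsumexp(scaled)^2, averaged over
            # the non-ignored rows, matching the test's reference computation.
            lse = torch.logsumexp(scaled, dim=-1)
            loss = loss + lse_square_scale * (lse[labels != ignore_index] ** 2).mean()
        return loss

With logit_scale=1.0 this reduces to the previous behavior; the tests additionally exercise logit_scale=0.7 against this reference.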