From d8aacc510c20cc89ecf1afba1df234d239c0e37a Mon Sep 17 00:00:00 2001
From: "Curtis \"Fjord\" Hawthorne"
Date: Sun, 21 Jan 2024 15:23:41 -0800
Subject: [PATCH] return z_loss (#768)

---
 flash_attn/losses/cross_entropy.py     | 26 +++++++++++++++++----
 flash_attn/ops/triton/cross_entropy.py | 31 ++++++++++++++++++++------
 tests/losses/test_cross_entropy.py     | 14 +++++++++---
 3 files changed, 57 insertions(+), 14 deletions(-)

diff --git a/flash_attn/losses/cross_entropy.py b/flash_attn/losses/cross_entropy.py
index e244047..2a1b77a 100644
--- a/flash_attn/losses/cross_entropy.py
+++ b/flash_attn/losses/cross_entropy.py
@@ -16,6 +16,7 @@ class CrossEntropyLoss(nn.Module):
         lse_square_scale=0.0,
         inplace_backward=False,
         process_group=None,
+        return_z_loss=False,
     ):
         """
         Arguments:
@@ -26,7 +27,10 @@ class CrossEntropyLoss(nn.Module):
             inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
                 This saves memory.
             process_group: if not None, we're doing Tensor Parallel: each process is responsible for
-                one part of the vocab. The loss will be aggregated across processes.
+                one part of the vocab. The loss will be aggregated across processes.
+            return_z_loss: bool. If True, we return the component of the loss contributed by
+                the lse_square_scale value. This value is only for logging and does not support
+                backprop.
         """
         super().__init__()
         if reduction not in ["mean", "none", "sum"]:
@@ -38,6 +42,7 @@ class CrossEntropyLoss(nn.Module):
         self.lse_square_scale = lse_square_scale
         self.inplace_backward = inplace_backward
         self.process_group = process_group
+        self.return_z_loss = return_z_loss
 
     def forward(self, input, target):
         """
@@ -46,9 +51,10 @@ class CrossEntropyLoss(nn.Module):
             target: (batch,)
         Returns:
             losses: (batch,) if reduction is 'none', else (1,), dtype float
+            z_loss: (batch,) if reduction is 'none', else (1,), dtype float (if self.return_z_loss)
         """
         assert input.is_cuda and target.is_cuda, "Only support CUDA tensors"
-        loss = cross_entropy_loss(
+        loss, z_loss = cross_entropy_loss(
             input,
             target,
             label_smoothing=self.label_smoothing,
@@ -59,8 +65,20 @@ class CrossEntropyLoss(nn.Module):
             process_group=self.process_group,
         )
         if self.reduction == "mean":
-            return loss.sum() / (target != self.ignore_index).sum()
+            loss = loss.sum() / (target != self.ignore_index).sum()
         elif self.reduction == "sum":
-            return loss.sum()
+            loss = loss.sum()
         else:
+            loss = loss
+
+        if not self.return_z_loss:
             return loss
+
+        if self.reduction == "mean":
+            z_loss = z_loss.sum() / (target != self.ignore_index).sum()
+        elif self.reduction == "sum":
+            z_loss = z_loss.sum()
+        else:
+            z_loss = z_loss
+
+        return loss, z_loss
diff --git a/flash_attn/ops/triton/cross_entropy.py b/flash_attn/ops/triton/cross_entropy.py
index afa02a2..c8111ca 100644
--- a/flash_attn/ops/triton/cross_entropy.py
+++ b/flash_attn/ops/triton/cross_entropy.py
@@ -26,6 +26,7 @@ if "all_gather_into_tensor" not in dir(torch.distributed):
 def cross_entropy_fwd_kernel(
     loss_ptr,  # data ptrs
     lse_ptr,
+    z_loss_ptr,
     logits_ptr,
     labels_ptr,
     smoothing,
@@ -57,6 +58,7 @@ def cross_entropy_fwd_kernel(
     tl.store(lse_ptr + col_block_idx * n_rows + row_idx, lse)
     if label_idx == ignored_index:
         loss = 0.0
+        z_loss = 0.0
     else:
         label_idx -= class_start_idx
         if label_idx >= col_block_idx * BLOCK_SIZE and label_idx < min(
@@ -78,8 +80,13 @@ def cross_entropy_fwd_kernel(
             else:
                 loss = 0.0
         if not SPLIT:
-            loss += lse_square_scale * lse * lse
+            z_loss = lse_square_scale * lse * lse
+            loss += z_loss
+        else:
+            z_loss = 0.0
     tl.store(loss_ptr + col_block_idx * n_rows + row_idx, loss)
+    if not SPLIT:
+        tl.store(z_loss_ptr + col_block_idx * n_rows + row_idx, z_loss)
 
 
 @triton.heuristics(
@@ -172,12 +179,14 @@ class CrossEntropyLoss(torch.autograd.Function):
         loss_shape = (n_splits, n_rows) if n_splits > 1 else (n_rows,)
         losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
         lse = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
+        z_losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
         # Need this, otherwise Triton tries to launch from cuda:0 and we get
         # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
         with torch.cuda.device(logits.device.index):
             cross_entropy_fwd_kernel[(n_rows, n_splits)](
                 losses,  # data ptrs
                 lse,
+                z_losses,
                 logits,
                 labels,
                 smoothing,
@@ -219,10 +228,15 @@ class CrossEntropyLoss(torch.autograd.Function):
             # Again, we just have to add the (global) lse.
             losses += lse
             if lse_square_scale != 0.0:
-                losses += lse_square_scale * lse.square()
+                z_losses = lse_square_scale * lse.square()
+                z_losses.masked_fill_(labels == ignored_index, 0.0)
+                losses += z_losses
+            else:
+                z_losses = torch.zeros_like(losses)
             losses.masked_fill_(labels == ignored_index, 0.0)
 
         ctx.save_for_backward(logits, lse, labels)
+        ctx.mark_non_differentiable(z_losses)
         ctx.smoothing = smoothing
         ctx.logit_scale = logit_scale
         ctx.lse_square_scale = lse_square_scale
@@ -230,10 +244,13 @@ class CrossEntropyLoss(torch.autograd.Function):
         ctx.total_classes = total_classes
         ctx.class_start_idx = class_start_idx
         ctx.inplace_backward = inplace_backward
-        return losses
+
+        return losses, z_losses
 
     @staticmethod
-    def backward(ctx, grad_losses):
+    def backward(ctx, grad_losses, grad_z_losses):
+        del grad_z_losses  # z_losses are only for logging.
+
         logits, lse, labels = ctx.saved_tensors
         dlogits = logits if ctx.inplace_backward else torch.empty_like(logits)
         n_rows, n_cols = logits.shape
@@ -262,8 +279,7 @@ class CrossEntropyLoss(torch.autograd.Function):
                 BLOCK_SIZE=BLOCK_SIZE,  # constants
                 num_warps=num_warps,
             )
-        return dlogits, None, None, None, None, None, None, None
-
+        return dlogits, None, None, None, None, None, None, None, None
 
 def cross_entropy_loss(
     logits: torch.Tensor,
@@ -287,9 +303,10 @@ def cross_entropy_loss(
         inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
             This saves memory.
         process_group: if not None, we're doing Tensor Parallel: each process is responsible for
-            one part of the vocab. The loss will be aggregated across processes.
+            one part of the vocab. The loss will be aggregated across processes.
     Returns:
         losses: (batch,), float
+        z_losses: (batch,), float
     """
     return CrossEntropyLoss.apply(
         logits,
diff --git a/tests/losses/test_cross_entropy.py b/tests/losses/test_cross_entropy.py
index edab0ff..9d67f59 100644
--- a/tests/losses/test_cross_entropy.py
+++ b/tests/losses/test_cross_entropy.py
@@ -16,6 +16,7 @@ is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
 @pytest.mark.parametrize("inplace_backward", [False, True])
 # @pytest.mark.parametrize("inplace_backward", [False])
 @pytest.mark.parametrize("lse_square_scale", [0.0, 1e-2])
+@pytest.mark.parametrize("return_z_loss", [False, True])
 # @pytest.mark.parametrize("lse_square_scale", [1e-2])
 @pytest.mark.parametrize("logit_scale", [1.0, 0.7])
 # @pytest.mark.parametrize("logit_scale", [1.0])
@@ -24,7 +25,7 @@ is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
 @pytest.mark.parametrize("vocab_size", [50257, 128 * 1024])  # test vocab larger than 64k for split
 # @pytest.mark.parametrize("vocab_size", [12])
 def test_cross_entropy_loss(
-    vocab_size, smoothing, logit_scale, lse_square_scale, inplace_backward, dtype
+    vocab_size, smoothing, logit_scale, lse_square_scale, return_z_loss, inplace_backward, dtype
 ):
     device = "cuda"
     rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)
@@ -44,14 +45,21 @@ def test_cross_entropy_loss(
         label_smoothing=smoothing,
         logit_scale=logit_scale,
         lse_square_scale=lse_square_scale,
+        return_z_loss=return_z_loss,
         inplace_backward=inplace_backward,
     )
-    out = model(x, y)
+    if return_z_loss:
+        out, out_z_loss = model(x, y)
+    else:
+        out = model(x, y)
     x_pt_scaled = (x_pt.float() * logit_scale) if logit_scale != 1.0 else x_pt.float()
     out_pt = model_pt(x_pt_scaled, y)
     if lse_square_scale > 0.0:
         lse_pt = torch.logsumexp(x_pt_scaled, dim=-1)
-        out_pt += lse_square_scale * (lse_pt[y != -100] ** 2).mean()
+        z_loss_pt = lse_square_scale * (lse_pt[y != -100] ** 2).mean()
+        if return_z_loss:
+            assert torch.allclose(out_z_loss, z_loss_pt, rtol=rtol, atol=atol)
+        out_pt += z_loss_pt
     assert torch.allclose(out, out_pt, rtol=1e-5, atol=1e-6)
 
     g = torch.randn_like(out)
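
Below is a minimal usage sketch of the new return_z_loss flag (not part of the patch). It assumes a CUDA device and a flash-attn build that includes this change; the batch size, vocab size, and the 1e-4 scale are illustrative values, not anything prescribed by the patch.

    import torch
    from flash_attn.losses.cross_entropy import CrossEntropyLoss

    # Illustrative shapes: 8 token positions over a 32k vocab (made up for this sketch).
    vocab_size = 32 * 1024
    logits = torch.randn(8, vocab_size, device="cuda", dtype=torch.float16, requires_grad=True)
    labels = torch.randint(0, vocab_size, (8,), device="cuda")

    # lse_square_scale > 0 adds the z-loss term lse_square_scale * logsumexp(logits) ** 2;
    # return_z_loss=True additionally returns that term on its own (logging only, no backprop).
    loss_fn = CrossEntropyLoss(reduction="mean", lse_square_scale=1e-4, return_z_loss=True)
    loss, z_loss = loss_fn(logits, labels)

    loss.backward()  # gradients still flow through the combined loss
    print(loss.item(), z_loss.item())  # z_loss is marked non-differentiable

With return_z_loss=False (the default) the module keeps returning a single loss tensor, so existing callers are unaffected.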