diff --git a/flash_attn/losses/cross_entropy.py b/flash_attn/losses/cross_entropy.py
index 2c5032c..d4dcf66 100644
--- a/flash_attn/losses/cross_entropy.py
+++ b/flash_attn/losses/cross_entropy.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2023, Tri Dao.
+# Copyright (c) 2024, Tri Dao.
 
 import torch
 import torch.nn as nn
@@ -44,7 +44,7 @@ class CrossEntropyLoss(nn.Module):
         self.process_group = process_group
         self.return_z_loss = return_z_loss
 
-    def forward(self, input, target):
+    def forward(self, input, target, precomputed_lse=None):
         """
         Arguments:
             input: (batch, vocab_size)
@@ -57,6 +57,7 @@ class CrossEntropyLoss(nn.Module):
         loss, z_loss = cross_entropy_loss(
             input,
             target,
+            precomputed_lse=precomputed_lse,
             label_smoothing=self.label_smoothing,
             logit_scale=self.logit_scale,
             lse_square_scale=self.lse_square_scale,
diff --git a/flash_attn/ops/triton/cross_entropy.py b/flash_attn/ops/triton/cross_entropy.py
index e7bb686..8f7e9a2 100644
--- a/flash_attn/ops/triton/cross_entropy.py
+++ b/flash_attn/ops/triton/cross_entropy.py
@@ -39,25 +39,29 @@ def cross_entropy_fwd_kernel(
     HAS_SMOOTHING: tl.constexpr,
     # if SPLIT (e.g. tensor parallel), don't include the LSE in the loss since it's not the final LSE
     SPLIT: tl.constexpr,
+    PRECOMPUTED_LSE: tl.constexpr,  # If LSE is already computed (also no smoothing and logit_scale == 1.0)
 ):
     row_idx = tl.program_id(0)
     logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
     sum_logits = 0.0  # For smoothing
-    # Statistics for online softmax
-    m_i = -float("inf")
-    l_i = 0.0
-    for col_offset in range(0, n_cols, BLOCK_SIZE):
-        cols = col_offset + tl.arange(0, BLOCK_SIZE)
-        logits = tl.load(logits_ptr + cols, mask=cols < n_cols, other=-float("inf")).to(
-            tl.float32
-        ) * logit_scale
-        if HAS_SMOOTHING:
-            sum_logits += tl.sum(tl.where(cols < n_cols, logits, 0.0))
-        m_i_new = tl.maximum(m_i, tl.max(logits))
-        l_i = tl.exp(m_i - m_i_new) * l_i + tl.sum(tl.exp(logits - m_i_new))
-        m_i = m_i_new
-    lse = tl.log(l_i) + m_i
-    tl.store(lse_ptr + row_idx, lse)
+    if not PRECOMPUTED_LSE:
+        # Statistics for online softmax
+        m_i = -float("inf")
+        l_i = 0.0
+        for col_offset in range(0, n_cols, BLOCK_SIZE):
+            cols = col_offset + tl.arange(0, BLOCK_SIZE)
+            logits = tl.load(logits_ptr + cols, mask=cols < n_cols, other=-float("inf")).to(
+                tl.float32
+            ) * logit_scale
+            if HAS_SMOOTHING:
+                sum_logits += tl.sum(tl.where(cols < n_cols, logits, 0.0))
+            m_i_new = tl.maximum(m_i, tl.max(logits))
+            l_i = tl.exp(m_i - m_i_new) * l_i + tl.sum(tl.exp(logits - m_i_new))
+            m_i = m_i_new
+        lse = tl.log(l_i) + m_i
+        tl.store(lse_ptr + row_idx, lse)
+    else:
+        lse = tl.load(lse_ptr + row_idx)
     label_idx = tl.load(labels_ptr + row_idx)
     if label_idx == ignore_index:
         loss = 0.0
@@ -135,7 +139,7 @@ def cross_entropy_bwd_kernel(
     if HAS_SMOOTHING:
         smooth_positive = 1.0 - smoothing
         smooth_negative = smoothing / total_classes
-        probs = tl.where(col_offsets == label_idx, probs - (1 - smoothing), probs) - smooth_negative
+        probs = tl.where(col_offsets == label_idx, probs - smooth_positive, probs) - smooth_negative
     else:
         probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)
     tl.store(dlogits_ptr + col_offsets, (dloss * logit_scale) * probs, mask=col_offsets < n_cols)
@@ -148,6 +152,7 @@ class CrossEntropyLoss(torch.autograd.Function):
         ctx,
         logits,
         labels,
+        precomputed_lse=None,
        smoothing=0.0,
        logit_scale=1.0,
        lse_square_scale=0.0,
@@ -161,6 +166,7 @@ class CrossEntropyLoss(torch.autograd.Function):
         total_classes = world_size * n_cols
         rank = 0 if process_group is None else torch.distributed.get_rank(process_group)
         class_start_idx = rank * n_cols
+        use_precomputed_lse = precomputed_lse is not None and logit_scale == 1.0 and smoothing == 0.0
 
         if logits.stride(-1) != 1:
             logits = logits.contiguous()
@@ -172,7 +178,11 @@ class CrossEntropyLoss(torch.autograd.Function):
             else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32))
         )
         losses = torch.empty(n_rows, dtype=torch.float, device=logits.device)
-        lse = torch.empty(n_rows, dtype=torch.float, device=logits.device)
+        if use_precomputed_lse:
+            assert precomputed_lse.shape == (n_rows,)
+            lse = precomputed_lse.contiguous()
+        else:
+            lse = torch.empty(n_rows, dtype=torch.float, device=logits.device)
         z_losses = torch.empty(n_rows, dtype=torch.float, device=logits.device)
         # Need this, otherwise Triton tries to launch from cuda:0 and we get
         # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
@@ -192,8 +202,9 @@ class CrossEntropyLoss(torch.autograd.Function):
                 n_cols,  # shapes
                 logits.stride(0),  # strides
                 BLOCK_SIZE=BLOCK_SIZE,  # constants
-                num_warps=num_warps,
                 SPLIT=world_size > 1,
+                PRECOMPUTED_LSE=use_precomputed_lse,
+                num_warps=num_warps,
             )
 
         if world_size > 1:
@@ -270,11 +281,13 @@ class CrossEntropyLoss(torch.autograd.Function):
                 BLOCK_SIZE=BLOCK_SIZE,  # constants
                 num_warps=num_warps,
             )
-        return dlogits, None, None, None, None, None, None, None, None
+        return dlogits, None, None, None, None, None, None, None, None, None
+
 
 def cross_entropy_loss(
     logits: torch.Tensor,
     labels: torch.Tensor,
+    precomputed_lse: Optional[torch.Tensor] = None,
     label_smoothing: float = 0.0,
     logit_scale: float = 1.0,
     lse_square_scale: float = 0.0,
@@ -302,6 +315,7 @@ def cross_entropy_loss(
     return CrossEntropyLoss.apply(
         logits,
         labels,
+        precomputed_lse,
         label_smoothing,
         logit_scale,
         lse_square_scale,
diff --git a/tests/losses/test_cross_entropy.py b/tests/losses/test_cross_entropy.py
index 9d67f59..7ccf5f4 100644
--- a/tests/losses/test_cross_entropy.py
+++ b/tests/losses/test_cross_entropy.py
@@ -1,9 +1,8 @@
-import math
+# Copyright (c) 2024, Tri Dao.
 import pytest
 import torch
 import torch.nn.functional as F
-from einops import rearrange
 
 from flash_attn.losses.cross_entropy import CrossEntropyLoss
 
 is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
@@ -13,6 +12,8 @@ is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
     "dtype", [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else [])
 )
 # @pytest.mark.parametrize("dtype", [torch.float16])
+@pytest.mark.parametrize("precompute_lse", [False, True])
+# @pytest.mark.parametrize("precompute_lse", [False])
 @pytest.mark.parametrize("inplace_backward", [False, True])
 # @pytest.mark.parametrize("inplace_backward", [False])
 @pytest.mark.parametrize("lse_square_scale", [0.0, 1e-2])
@@ -22,11 +23,20 @@ is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
 # @pytest.mark.parametrize("logit_scale", [1.0])
 @pytest.mark.parametrize("smoothing", [0.0, 0.9])
 # @pytest.mark.parametrize("smoothing", [0.0])
-@pytest.mark.parametrize("vocab_size", [50257, 128 * 1024])  # test vocab larger than 64k for split
+@pytest.mark.parametrize("vocab_size", [50257, 128256])  # test vocab larger than 64k for split
 # @pytest.mark.parametrize("vocab_size", [12])
 def test_cross_entropy_loss(
-    vocab_size, smoothing, logit_scale, lse_square_scale, return_z_loss, inplace_backward, dtype
+    vocab_size,
+    smoothing,
+    logit_scale,
+    lse_square_scale,
+    return_z_loss,
+    inplace_backward,
+    precompute_lse,
+    dtype,
 ):
+    if precompute_lse and (logit_scale != 1.0 or smoothing != 0.0):
+        pytest.skip("precompute_lse only works with logit_scale=1.0 and smoothing=0.0")
     device = "cuda"
     rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)
     # set seed
@@ -48,10 +58,15 @@ def test_cross_entropy_loss(
         return_z_loss=return_z_loss,
         inplace_backward=inplace_backward,
     )
-    if return_z_loss:
-        out, out_z_loss = model(x, y)
+    if precompute_lse:
+        with torch.no_grad():
+            lse = torch.logsumexp(x.float(), dim=-1)
     else:
-        out = model(x, y)
+        lse = None
+    if return_z_loss:
+        out, out_z_loss = model(x, y, precomputed_lse=lse)
+    else:
+        out = model(x, y, precomputed_lse=lse)
     x_pt_scaled = (x_pt.float() * logit_scale) if logit_scale != 1.0 else x_pt.float()
     out_pt = model_pt(x_pt_scaled, y)
     if lse_square_scale > 0.0:
diff --git a/tests/losses/test_cross_entropy_parallel.py b/tests/losses/test_cross_entropy_parallel.py
index c8b97fc..83362df 100644
--- a/tests/losses/test_cross_entropy_parallel.py
+++ b/tests/losses/test_cross_entropy_parallel.py
@@ -15,11 +15,13 @@ is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
     "dtype", [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else [])
 )
 # @pytest.mark.parametrize("dtype", [torch.float16])
+@pytest.mark.parametrize("precompute_lse", [False, True])
+# @pytest.mark.parametrize("precompute_lse", [False])
 @pytest.mark.parametrize("inplace_backward", [False, True])
 # @pytest.mark.parametrize("inplace_backward", [False])
-@pytest.mark.parametrize("lse_square_scale", [0.0, 1e-2])
-# @pytest.mark.parametrize("lse_square_scale", [0.0])
-@pytest.mark.parametrize("logit_scale", [0.7])
+# @pytest.mark.parametrize("lse_square_scale", [0.0, 1e-2])
+@pytest.mark.parametrize("lse_square_scale", [1e-2])
+@pytest.mark.parametrize("logit_scale", [1.0, 0.7])
 # @pytest.mark.parametrize("logit_scale", [1.0])
 @pytest.mark.parametrize("smoothing", [0.0, 0.9])
 # @pytest.mark.parametrize("smoothing", [0.0])
@@ -28,8 +30,17 @@ is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
 # @pytest.mark.parametrize("world_size", [1, 2])
 @pytest.mark.parametrize("world_size", [2])
 def test_cross_entropy_loss_parallel(
-    vocab_size, world_size, smoothing, logit_scale, lse_square_scale, inplace_backward, dtype
+    vocab_size,
+    world_size,
+    smoothing,
+    logit_scale,
+    lse_square_scale,
+    inplace_backward,
+    precompute_lse,
+    dtype,
 ):
+    if precompute_lse and (logit_scale != 1.0 or smoothing != 0.0):
+        pytest.skip("precompute_lse only works with logit_scale=1.0 and smoothing=0.0")
     assert vocab_size % world_size == 0
     rtol, atol = (
         (1e-5, 2e-5)
@@ -67,7 +78,12 @@ def test_cross_entropy_loss_parallel(
         inplace_backward=inplace_backward,
         process_group=parallel_state.get_tensor_model_parallel_group(),
     )
-    out = model(x, y)
+    if precompute_lse:
+        with torch.no_grad():
+            lse = torch.logsumexp(x.float(), dim=-1)
+    else:
+        lse = None
+    out = model(x, y, precomputed_lse=lse)
     out_pt = model_pt(x_pt.float() * logit_scale, y)
     if lse_square_scale > 0.0:
         lse_pt = torch.logsumexp(x_pt.float() * logit_scale, dim=-1)
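
For reference, a minimal usage sketch of the new precomputed_lse path (not part of the diff). It assumes flash-attn with this change installed and a CUDA device; the tensor shapes and the default CrossEntropyLoss() settings are illustrative. As the kernel change above requires, the precomputed LSE is only consumed when label_smoothing == 0.0 and logit_scale == 1.0, and must be a float tensor of shape (batch,):

import torch
from flash_attn.losses.cross_entropy import CrossEntropyLoss

batch, vocab_size = 4, 50257
x = torch.randn(batch, vocab_size, device="cuda", dtype=torch.float16, requires_grad=True)
y = torch.randint(0, vocab_size, (batch,), device="cuda")

# Reuse an LSE computed elsewhere (mirrors what the updated tests do with torch.logsumexp).
with torch.no_grad():
    lse = torch.logsumexp(x.float(), dim=-1)

loss_fn = CrossEntropyLoss()  # defaults: label_smoothing=0.0, logit_scale=1.0
loss = loss_fn(x, y, precomputed_lse=lse)  # skips the kernel's online-softmax LSE pass
loss.backward()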