From aaa14741296e74d17cb3f27c6d8d72a29c66b60f Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Sun, 19 Nov 2023 23:19:36 -0800
Subject: [PATCH] [CrossEntropy] Simplify the case of large vocab with Tensor
 Parallel

---
 flash_attn/ops/triton/cross_entropy.py        |  9 +-
 tests/losses/test_cross_entropy_parallel.py   | 13 ++-
 ...test_cross_entropy_parallel_large_vocab.py | 86 --------------------
 3 files changed, 10 insertions(+), 98 deletions(-)
 delete mode 100644 tests/losses/test_cross_entropy_parallel_large_vocab.py

diff --git a/flash_attn/ops/triton/cross_entropy.py b/flash_attn/ops/triton/cross_entropy.py
index b0e982b..21b099e 100644
--- a/flash_attn/ops/triton/cross_entropy.py
+++ b/flash_attn/ops/triton/cross_entropy.py
@@ -196,18 +196,17 @@ class CrossEntropyLoss(torch.autograd.Function):
         #   -0.9 * predicted logit - 0.1 * sum logit / total_classes.
         # For labels not in the vocab of this partition, losses contains
         #   -0.1 * sum logit / total_classes.
+        if n_splits > 1:
+            lse = torch.logsumexp(lse, dim=0)
+            losses = losses.sum(dim=0)
         if world_size > 1:
-            lse_allgather = torch.empty(world_size * n_splits, n_rows, dtype=lse.dtype, device=lse.device)
+            lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device)
             torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group)
-            if n_splits > 1: losses = losses.sum(dim=0)
             handle_losses = torch.distributed.all_reduce(
                 losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
             )
             lse = torch.logsumexp(lse_allgather, dim=0)
             handle_losses.wait()
-        else:
-            lse = torch.logsumexp(lse, dim=0)
-            losses = losses.sum(dim=0)
         # After the allreduce, if there's no smoothing, the total losses are - predicted_logit,
         # we just have to add the (global) lse.
         # If there's smoothing=0.1, the total losses are
diff --git a/tests/losses/test_cross_entropy_parallel.py b/tests/losses/test_cross_entropy_parallel.py
index 2588a11..d26f53a 100644
--- a/tests/losses/test_cross_entropy_parallel.py
+++ b/tests/losses/test_cross_entropy_parallel.py
@@ -1,11 +1,10 @@
 # Run test with:
-# torchrun --no_python --nproc_per_node=4 pytest -q -s tests/losses/test_cross_entropy_parallel.py
+# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/losses/test_cross_entropy_parallel.py
 
 import math
 
 import pytest
 import torch
-import torch.nn.functional as F
 from apex.transformer import parallel_state, tensor_parallel
 from flash_attn.losses.cross_entropy import CrossEntropyLoss
 
@@ -19,19 +18,19 @@ is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
 @pytest.mark.parametrize("inplace_backward", [False, True])
 # @pytest.mark.parametrize("inplace_backward", [False])
 @pytest.mark.parametrize("lse_square_scale", [0.0, 1e-2])
-# @pytest.mark.parametrize("lse_square_scale", [1e-2])
+# @pytest.mark.parametrize("lse_square_scale", [0.0])
 @pytest.mark.parametrize("smoothing", [0.0, 0.9])
 # @pytest.mark.parametrize("smoothing", [0.0])
-@pytest.mark.parametrize("vocab_size", [50264, 128 * 1024])  # test vocab larger than 64k for split
+@pytest.mark.parametrize("vocab_size", [50264, 256 * 1024])  # test vocab larger than 64k for split
 # @pytest.mark.parametrize("vocab_size", [50264])  # test vocab larger than 64k for split
-@pytest.mark.parametrize("world_size", [1, 2, 4])
-# @pytest.mark.parametrize("world_size", [2])
+# @pytest.mark.parametrize("world_size", [1, 2])
+@pytest.mark.parametrize("world_size", [2])
 def test_cross_entropy_loss_parallel(
     vocab_size, world_size, smoothing, lse_square_scale, inplace_backward, dtype
 ):
     assert vocab_size % world_size == 0
     rtol, atol = (
-        (1e-5, 1e-6)
+        (1e-5, 2e-5)
         if dtype == torch.float32
         else ((1e-3, 1e-4) if dtype == torch.float16 else (1e-2, 3e-3))
     )
diff --git a/tests/losses/test_cross_entropy_parallel_large_vocab.py b/tests/losses/test_cross_entropy_parallel_large_vocab.py
deleted file mode 100644
index 340927b..0000000
--- a/tests/losses/test_cross_entropy_parallel_large_vocab.py
+++ /dev/null
@@ -1,86 +0,0 @@
-# Run test with:
-# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/losses/test_cross_entropy_parallel_large_vocab.py
-
-import math
-
-import pytest
-import torch
-import torch.nn.functional as F
-from apex.transformer import parallel_state, tensor_parallel
-from flash_attn.losses.cross_entropy import CrossEntropyLoss
-
-is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
-
-
-@pytest.mark.parametrize(
-    "dtype", [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else [])
-)
-# @pytest.mark.parametrize("dtype", [torch.float16])
-@pytest.mark.parametrize("inplace_backward", [False, True])
-# @pytest.mark.parametrize("inplace_backward", [False])
-@pytest.mark.parametrize("lse_square_scale", [0.0, 1e-2])
-# @pytest.mark.parametrize("lse_square_scale", [1e-2])
-@pytest.mark.parametrize("smoothing", [0.0, 0.9])
-# @pytest.mark.parametrize("smoothing", [0.0])
-@pytest.mark.parametrize("vocab_size", [256 * 1024])  # test vocab larger than 64k for split
-# @pytest.mark.parametrize("vocab_size", [50264])  # test vocab larger than 64k for split
-@pytest.mark.parametrize("world_size", [2])
-# @pytest.mark.parametrize("world_size", [2])
-def test_cross_entropy_loss_parallel(
-    vocab_size, world_size, smoothing, lse_square_scale, inplace_backward, dtype
-):
-    assert vocab_size % world_size == 0
-    rtol, atol = (
-        (1e-5, 1e-6)
-        if dtype == torch.float32
-        else ((1e-3, 1e-4) if dtype == torch.float16 else (1e-2, 3e-3))
-    )
-    if not torch.distributed.is_initialized():
-        torch.distributed.init_process_group(backend="nccl", init_method="env://")
-    partition_vocab_size = vocab_size // world_size
-    device = f"cuda:{torch.distributed.get_rank()}"
-    assert world_size <= torch.distributed.get_world_size()
-    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
-    rank = parallel_state.get_tensor_model_parallel_rank()
-    # set seed
-    torch.random.manual_seed(0)
-    batch_size = 8
-    seqlen = 128
-    x_pt = (
-        torch.randn(batch_size * seqlen, vocab_size, device=device, dtype=dtype) * 10
-    ).requires_grad_()
-    x = (
-        tensor_parallel.scatter_to_tensor_model_parallel_region(x_pt)
-        .detach()
-        .clone()
-        .requires_grad_()
-    )
-    y = torch.randint(0, vocab_size, (batch_size * seqlen,), dtype=torch.long, device=device)
-    y[torch.randperm(batch_size * seqlen)[:10]] = -100
-    model_pt = torch.nn.CrossEntropyLoss(label_smoothing=smoothing, reduction="none")
-    model = CrossEntropyLoss(
-        label_smoothing=smoothing,
-        reduction="none",
-        lse_square_scale=lse_square_scale,
-        inplace_backward=inplace_backward,
-        process_group=parallel_state.get_tensor_model_parallel_group(),
-    )
-    out = model(x, y)
-    out_pt = model_pt(x_pt.float(), y)
-    if lse_square_scale > 0.0:
-        lse_pt = torch.logsumexp(x_pt.float(), dim=-1)
-        out_pt += lse_square_scale * lse_pt.square()
-    out_pt.masked_fill_(y == -100, 0.0)
-    assert torch.allclose(out, out_pt, rtol=1e-5, atol=1e-6)
-
-    g = torch.randn_like(out)
-    out_pt.backward(g)
-    out.backward(g)
-    assert torch.allclose(
-        x.grad,
-        x_pt.grad[:, (rank * partition_vocab_size) : (rank + 1) * partition_vocab_size],
-        rtol=rtol,
-        atol=atol,
-    )
-
-    parallel_state.destroy_model_parallel()
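
A minimal standalone sketch of the reduction order this patch establishes for tensor-parallel cross entropy when a rank's vocab shard is itself processed in several splits: the local splits are always folded first (logsumexp for lse, sum for losses), and only then are the per-rank results combined across the process group. The helper name combine_partial_losses, the (n_splits, n_rows) shapes, and the single-process fallback are illustrative assumptions, not the library's API; the authoritative code is the patched block in flash_attn/ops/triton/cross_entropy.py.

import torch
import torch.distributed as dist


def combine_partial_losses(lse, losses, process_group=None):
    """Combine partial results of a vocab-parallel cross entropy (illustrative sketch).

    lse, losses: (n_splits, n_rows) tensors holding, for this rank's vocab shard,
    the log-sum-exp and the partial loss of each row, one slice per vocab split.
    Returns (lse, losses), each of shape (n_rows,), reduced over splits and ranks.
    """
    n_splits, n_rows = lse.shape
    world_size = dist.get_world_size(process_group) if process_group is not None else 1
    # Step 1: fold the local vocab splits first, whether or not tensor parallelism
    # is in use. This is the simplification the patch makes: the distributed path
    # below only ever sees one row of lse per rank.
    if n_splits > 1:
        lse = torch.logsumexp(lse, dim=0)
        losses = losses.sum(dim=0)
    else:
        lse, losses = lse[0], losses[0]
    # Step 2: combine across tensor-parallel ranks. lse is a log-space quantity, so
    # the gathered per-rank values need a logsumexp; losses only need a sum, issued
    # as an async all_reduce so it can overlap with the logsumexp.
    if world_size > 1:
        lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device)
        dist.all_gather_into_tensor(lse_allgather, lse, group=process_group)
        handle = dist.all_reduce(losses, op=dist.ReduceOp.SUM, group=process_group, async_op=True)
        lse = torch.logsumexp(lse_allgather, dim=0)
        handle.wait()
    return lse, losses

As in the patched code, the losses all_reduce is overlapped with the logsumexp over the gathered lse; with process_group=None the sketch reduces only over the local splits, mirroring the world_size == 1 case.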