From 343492ec305d474bcf6e45bc05893bbc040fcc30 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 13 Nov 2022 17:27:26 -0800 Subject: [PATCH] Make nccl operations async in CrossEntropyLossParallel --- flash_attn/losses/cross_entropy_parallel.py | 58 ++++++++++++--------- tests/losses/test_cross_entropy_apex.py | 2 +- tests/losses/test_cross_entropy_parallel.py | 2 +- 3 files changed, 36 insertions(+), 26 deletions(-) diff --git a/flash_attn/losses/cross_entropy_parallel.py b/flash_attn/losses/cross_entropy_parallel.py index 84fe82d..ebe4d1b 100644 --- a/flash_attn/losses/cross_entropy_parallel.py +++ b/flash_attn/losses/cross_entropy_parallel.py @@ -36,40 +36,50 @@ class SoftmaxCrossEntropyLossParallelFn(torch.autograd.Function): assert labels.shape == (batch,) rank = get_tensor_model_parallel_rank() world_size = get_tensor_model_parallel_world_size() - vocab_start_index, vocab_end_index = VocabUtility.vocab_range_from_per_partition_vocab_size( - partition_vocab_size, get_tensor_model_parallel_rank(), - get_tensor_model_parallel_world_size() - ) - # Create a mask of valid vocab ids (1 means it needs to be masked). - labels_mask = (labels < vocab_start_index) | (labels >= vocab_end_index) - ignored_mask = labels == ignored_index - labels_local = torch.where(ignored_mask, labels, labels - vocab_start_index) - masked_labels = labels_local.clone() - masked_labels[labels_mask] = ignored_index + if world_size == 1: + losses, lse = xentropy_cuda_lib.forward(logits_parallel, labels, smoothing) + losses.masked_fill_(labels==ignored_index, 0) + labels_local = labels + else: + vocab_start_index, vocab_end_index = VocabUtility.vocab_range_from_per_partition_vocab_size( + partition_vocab_size, get_tensor_model_parallel_rank(), + get_tensor_model_parallel_world_size() + ) - losses, lse_local = xentropy_cuda_lib.forward(logits_parallel, masked_labels, smoothing) - assert lse_local.shape == (batch,) - assert losses.shape == (batch,) - losses.masked_fill_(masked_labels==ignored_index, 0) + # Create a mask of valid vocab ids (1 means it needs to be masked). + labels_mask = (labels < vocab_start_index) | (labels >= vocab_end_index) + ignored_mask = labels == ignored_index + labels_local = torch.where(ignored_mask, labels, labels - vocab_start_index) + masked_labels = labels_local.clone() + masked_labels[labels_mask] = ignored_index + + losses, lse_local = xentropy_cuda_lib.forward(logits_parallel, masked_labels, smoothing) + assert lse_local.shape == (batch,) + assert losses.shape == (batch,) + losses.masked_fill_(masked_labels==ignored_index, 0) - if world_size > 1: lse_allgather = torch.empty(world_size, batch, dtype=lse_local.dtype, device=lse_local.device) - torch.distributed.all_gather_into_tensor(lse_allgather, lse_local.contiguous(), - group=get_tensor_model_parallel_group()) + handle_lse = torch.distributed.all_gather_into_tensor( + lse_allgather, lse_local.contiguous(), + group=get_tensor_model_parallel_group(), async_op=True + ) + handle_losses = torch.distributed.all_reduce( + losses, op=torch.distributed.ReduceOp.SUM, + group=get_tensor_model_parallel_group(), async_op=True + ) + handle_lse.wait() lse = torch.logsumexp(lse_allgather, dim=0) - torch.distributed.all_reduce(losses, op=torch.distributed.ReduceOp.SUM, - group=get_tensor_model_parallel_group()) - # The losses are currently lse_local - predicted_logit, we just have to subtract the - # lse_local and add the lse (global). - rank_per_sample = labels // partition_vocab_size + # The losses are going to be lse_local - predicted_logit, we just have to subtract + # the lse_local and add the lse (global). + rank_per_sample = torch.div(labels, partition_vocab_size, rounding_mode='floor') lse_local = lse_allgather[rank_per_sample, torch.arange(batch, device=lse_allgather.device)] + + handle_losses.wait() losses += lse - lse_local losses.masked_fill_(ignored_mask, 0) - else: - lse = lse_local ctx.save_for_backward(logits_parallel, lse, labels_local) ctx.smoothing = smoothing diff --git a/tests/losses/test_cross_entropy_apex.py b/tests/losses/test_cross_entropy_apex.py index e5e170f..646b539 100644 --- a/tests/losses/test_cross_entropy_apex.py +++ b/tests/losses/test_cross_entropy_apex.py @@ -6,7 +6,7 @@ import pytest from einops import rearrange -from src.losses.cross_entropy_apex import CrossEntropyLossApex +from flass_attn.losses.cross_entropy_apex import CrossEntropyLossApex is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8 diff --git a/tests/losses/test_cross_entropy_parallel.py b/tests/losses/test_cross_entropy_parallel.py index 71fe89f..1f9ef45 100644 --- a/tests/losses/test_cross_entropy_parallel.py +++ b/tests/losses/test_cross_entropy_parallel.py @@ -10,7 +10,7 @@ import pytest from apex.transformer import parallel_state from apex.transformer import tensor_parallel -from src.losses.cross_entropy_parallel import CrossEntropyLossParallel +from flash_attn.losses.cross_entropy_parallel import CrossEntropyLossParallel is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8