From dff68c2b228234e34714a6cb1b966cb3a09496b9 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Fri, 23 Dec 2022 14:51:08 -0800 Subject: [PATCH] Add smoothing for CrossEntropyParallel, rename to CrossEntropyLoss --- csrc/xentropy/interface.cpp | 30 ++-- csrc/xentropy/xentropy_kernel.cu | 50 ++++--- flash_attn/losses/cross_entropy.py | 128 ++++++++++++++++++ flash_attn/losses/cross_entropy_apex.py | 51 ------- flash_attn/losses/cross_entropy_parallel.py | 122 ----------------- flash_attn/models/bert.py | 8 +- ..._entropy_apex.py => test_cross_entropy.py} | 9 +- tests/losses/test_cross_entropy_parallel.py | 15 +- training/configs/experiment/owt/base.yaml | 2 +- training/configs/experiment/pile/base.yaml | 2 +- training/src/losses/cross_entropy.py | 128 ++++++++++++++++++ training/src/losses/cross_entropy_apex.py | 51 ------- training/src/losses/cross_entropy_parallel.py | 112 --------------- training/src/metrics/perplexity.py | 2 +- 14 files changed, 324 insertions(+), 386 deletions(-) create mode 100644 flash_attn/losses/cross_entropy.py delete mode 100644 flash_attn/losses/cross_entropy_apex.py delete mode 100644 flash_attn/losses/cross_entropy_parallel.py rename tests/losses/{test_cross_entropy_apex.py => test_cross_entropy.py} (77%) create mode 100644 training/src/losses/cross_entropy.py delete mode 100644 training/src/losses/cross_entropy_apex.py delete mode 100644 training/src/losses/cross_entropy_parallel.py diff --git a/csrc/xentropy/interface.cpp b/csrc/xentropy/interface.cpp index 715790d..41a783f 100644 --- a/csrc/xentropy/interface.cpp +++ b/csrc/xentropy/interface.cpp @@ -4,7 +4,8 @@ std::vector softmax_xentropy_cuda( const at::Tensor &input, const at::Tensor &labels, - const float smoothing); + const float smoothing, + const int total_classes); at::Tensor softmax_xentropy_backward_cuda( const at::Tensor &grad_loss, @@ -12,7 +13,8 @@ at::Tensor softmax_xentropy_backward_cuda( const at::Tensor &max_log_sum_exp, const at::Tensor &labels, const float smoothing, - const bool inplace); + const bool inplace, + const int total_classes); // C++ interface @@ -23,11 +25,15 @@ at::Tensor softmax_xentropy_backward_cuda( std::vector softmax_xentropy_forward( const at::Tensor &input, const at::Tensor &labels, - const float smoothing) { - CHECK_CUDA(input); + const float smoothing, + const int total_classes=-1) { + // For tensor parallel cross entropy with smoothing, we want to pass in the total number + // of classes so that smoothing can be applied correctly. If total_classes=-1, use the + // last dimension of the input tensor. + CHECK_INPUT(input); CHECK_INPUT(labels); - return softmax_xentropy_cuda(input, labels, smoothing); + return softmax_xentropy_cuda(input, labels, smoothing, total_classes); } at::Tensor softmax_xentropy_backward( @@ -36,16 +42,18 @@ at::Tensor softmax_xentropy_backward( const at::Tensor &max_log_sum_exp, const at::Tensor &labels, const float smoothing, - const bool inplace) { - CHECK_CUDA(grad_loss); - CHECK_CUDA(logits); + const bool inplace, + const int total_classes=-1) { + CHECK_INPUT(grad_loss); + CHECK_INPUT(logits); CHECK_INPUT(max_log_sum_exp); CHECK_INPUT(labels); - return softmax_xentropy_backward_cuda(grad_loss, logits, max_log_sum_exp, labels, smoothing, inplace); + return softmax_xentropy_backward_cuda(grad_loss, logits, max_log_sum_exp, labels, + smoothing, inplace, total_classes); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &softmax_xentropy_forward, "Softmax cross entropy loss with label smoothing forward (CUDA)"); - m.def("backward", &softmax_xentropy_backward, "Softmax cross entropy loss with label smoothing backward (CUDA)"); + m.def("forward", &softmax_xentropy_forward, "Softmax cross entropy loss with label smoothing forward (CUDA)", py::arg("input"), py::arg("labels"), py::arg("smoothing"), py::arg("total_classes")=-1); + m.def("backward", &softmax_xentropy_backward, "Softmax cross entropy loss with label smoothing backward (CUDA)", py::arg("grad_loss"), py::arg("logits"), py::arg("max_log_sum_exp"), py::arg("labels"), py::arg("smoothing"), py::arg("inplace"), py::arg("total_classes")=-1); } diff --git a/csrc/xentropy/xentropy_kernel.cu b/csrc/xentropy/xentropy_kernel.cu index b1ebf70..8d8836e 100644 --- a/csrc/xentropy/xentropy_kernel.cu +++ b/csrc/xentropy/xentropy_kernel.cu @@ -434,7 +434,8 @@ cunn_SoftMaxXEntropyForward( scalar_t *input, int64_t *labels, int64_t classes, - const float smoothing) + const float smoothing, + const int total_classes) { extern __shared__ unsigned char smem[]; auto sdata = reinterpret_cast(smem); @@ -472,12 +473,8 @@ cunn_SoftMaxXEntropyForward( // reserve max + log_sum_exp for bprop if (threadIdx.x == 0) { accscalar_t lse = max_k + std::log(sumAll); - if ((label >= 0) && (label < classes)) { - accscalar_t log_prob = epilogue(static_cast(input[label])); - losses[blockIdx.x] = (lse - sum_k / classes) * smoothing - log_prob * (1 - smoothing); - } else { - losses[blockIdx.x] = outscalar_t(0.f); - } + accscalar_t log_prob = (label >= 0 && label < classes) ? epilogue(static_cast(input[label])) : 0.f; + losses[blockIdx.x] = (lse - sum_k / total_classes) * smoothing - log_prob * (1 - smoothing); max_log_sum_exp[blockIdx.x] = lse; } } @@ -490,10 +487,11 @@ apply(scalar_t *gradInput, outscalar_t *gradOutput, int64_t *labels, const float smoothing, - int classes) + int classes, + const int total_classes) { accscalar_t smooth_positives = 1.0 - smoothing; - accscalar_t smooth_negatives = smoothing / classes; + accscalar_t smooth_negatives = smoothing / total_classes; accscalar_t tmpGradOutput = gradOutput[blockIdx.x]; int64_t label = labels[blockIdx.x]; accscalar_t coeff = max_log_sum_exp[blockIdx.x]; @@ -534,10 +532,11 @@ aligned_apply(int shift, outscalar_t *gradOutput, int64_t *labels, const float smoothing, - int classes) + int classes, + const int total_classes) { accscalar_t smooth_positives = 1.0 - smoothing; - accscalar_t smooth_negatives = smoothing / classes; + accscalar_t smooth_negatives = smoothing / total_classes; accscalar_t tmpGradOutput = gradOutput[blockIdx.x]; int64_t label = labels[blockIdx.x]; accscalar_t coeff = max_log_sum_exp[blockIdx.x]; @@ -602,7 +601,8 @@ cunn_SoftMaxXEntropyBackward( outscalar_t *gradOutput, int64_t *labels, const float smoothing, - int classes) + int classes, + const int total_classes) { gradInput += blockIdx.x * classes; logits += blockIdx.x * classes; @@ -611,10 +611,10 @@ cunn_SoftMaxXEntropyBackward( const int shift = ((uint64_t)logits) % ALIGN_BYTES / sizeof(scalar_t); const int shift_ = ((uint64_t)gradInput) % ALIGN_BYTES / sizeof(scalar_t); if (shift == shift_){ - aligned_apply(shift, gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes); + aligned_apply(shift, gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes, total_classes <= 0 ? classes : total_classes); } else { - apply(gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes); + apply(gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes, total_classes <= 0 ? classes : total_classes); } } @@ -623,7 +623,11 @@ template class Epilogue> std::vector host_softmax_xentropy( const Tensor & input_, const Tensor & labels_, - const float smoothing){ + const float smoothing, + const int total_classes) { + // For tensor parallel cross entropy with smoothing, we want to pass in the total number + // of classes so that smoothing can be applied correctly. If total_classes=-1, use the + // last dimension of the input tensor. AT_ASSERTM(labels_.scalar_type() == ScalarType::Long,"Label type should be CUDA Long"); // Otherwise the kernel will be launched from cuda:0 device @@ -666,7 +670,7 @@ std::vector host_softmax_xentropy( <<>>( losses.data_ptr(), max_log_sum_exp.data_ptr(), input.data_ptr(), labels_.data_ptr(), - dim_size, smoothing + dim_size, smoothing, total_classes <= 0 ? dim_size : total_classes ); ); @@ -683,7 +687,8 @@ Tensor host_softmax_xentropy_backward( const at::Tensor &max_log_sum_exp, const at::Tensor &labels, const float smoothing, - bool inplace) { + bool inplace, + const int total_classes) { // Otherwise the kernel will be launched from cuda:0 device // Cast to char to avoid compiler warning about narrowing at::cuda::CUDAGuard device_guard{(char)grad_loss.get_device()}; @@ -730,7 +735,7 @@ Tensor host_softmax_xentropy_backward( gI.data_ptr(), logits.data_ptr(), max_log_sum_exp.data_ptr(), grad.data_ptr(), labels.data_ptr(), - smoothing, dim_size + smoothing, dim_size, total_classes ); ); @@ -738,8 +743,8 @@ Tensor host_softmax_xentropy_backward( return gI; } -std::vector softmax_xentropy_cuda(const Tensor &input, const Tensor &labels, const float smoothing){ - return host_softmax_xentropy(input, labels, smoothing); +std::vector softmax_xentropy_cuda(const Tensor &input, const Tensor &labels, const float smoothing, const int total_classes){ + return host_softmax_xentropy(input, labels, smoothing, total_classes); } at::Tensor softmax_xentropy_backward_cuda( @@ -748,7 +753,8 @@ at::Tensor softmax_xentropy_backward_cuda( const at::Tensor &max_log_sum_exp, const at::Tensor &labels, const float smoothing, - const bool inplace) { + const bool inplace, + const int total_classes) { AT_ASSERTM((grad_loss.scalar_type() == ScalarType::Float), "expected grad types to be at::Float"); - return host_softmax_xentropy_backward(grad_loss, logits, max_log_sum_exp, labels, smoothing, inplace); + return host_softmax_xentropy_backward(grad_loss, logits, max_log_sum_exp, labels, smoothing, inplace, total_classes); } diff --git a/flash_attn/losses/cross_entropy.py b/flash_attn/losses/cross_entropy.py new file mode 100644 index 0000000..48e2f2f --- /dev/null +++ b/flash_attn/losses/cross_entropy.py @@ -0,0 +1,128 @@ +# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py +# But we make it much faster: we compute the local loss and the LSE, and by exchanging the LSE and +# the losses we can get the global loss. There's no need to do it step by step +# (compute local max, exchange, compute exp, compute local sum, exchange, etc.) +# The original xentropy interface is here: https://github.com/NVIDIA/apex/blob/master/apex/contrib/xentropy/softmax_xentropy.py +import torch +import torch.nn as nn + +import xentropy_cuda_lib + +# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for +# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent +# version of PyTorch. The following 2 lines are for backward compatibility with +# older PyTorch. +if "all_gather_into_tensor" not in dir(torch.distributed): + torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base + + +class SoftmaxCrossEntropyLossFn(torch.autograd.Function): + + @staticmethod + def forward(ctx, logits, labels, smoothing=0.0, ignored_index=-100, inplace_backward=False, + process_group=None): + """ + logits: (batch, vocab_size) + labels: (batch,) + If process_group is not None, we're doing Tensor Parallel: each process is responsible for + one part of the vocab. The loss needs to be aggregated across processes. + """ + batch, vocab_size = logits.shape + assert labels.shape == (batch,) + world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group) + ctx.total_classes = world_size * vocab_size + + if world_size == 1: + losses, lse = xentropy_cuda_lib.forward(logits, labels, smoothing) + losses.masked_fill_(labels==ignored_index, 0) + labels_local = labels + else: + rank = torch.distributed.get_rank(process_group) + vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size + + # Create a mask of valid vocab ids (1 means it needs to be masked). + labels_mask = (labels < vocab_start_index) | (labels >= vocab_end_index) + ignored_mask = labels == ignored_index + labels_local = torch.where(ignored_mask, labels, labels - vocab_start_index) + + # For tensor parallel cross entropy with smoothing, we want to pass in the total number + # of classes so that smoothing can be applied correctly. If total_classes=-1, use the + # last dimension of the input tensor. + losses, lse_local = xentropy_cuda_lib.forward(logits, labels_local, smoothing, + world_size * vocab_size) + assert lse_local.shape == (batch,) + assert losses.shape == (batch,) + losses.masked_fill_(ignored_mask, 0) + # For labels == ignored_index, the loss is always 0. + # If there's no smoothing, if labels are in the vocab of this partition, losses contains + # lse_local - predicted logit, and 0 otherwise. + # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains + # 0.9 * (lse_local - predicted logit) + 0.1 * (lse_local - sum logit / total_classes) + # For labels not in the vocab of this partition, losses contains + # 0.1 * (lse_local - sum logit / total_classes). + + lse_allgather = torch.empty(world_size, batch, dtype=lse_local.dtype, + device=lse_local.device) + torch.distributed.all_gather_into_tensor(lse_allgather, lse_local.contiguous(), + group=process_group) + handle_losses = torch.distributed.all_reduce( + losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True + ) + lse = torch.logsumexp(lse_allgather, dim=0) + # If there's no smoothing, the total losses are lse_local - predicted_logit, + # we just have to subtract the lse_local and add the lse (global). + # If there's smoothing=0.1, the total losses are + # 0.9 * (lse_local - predicted_logit) + 0.1 * (sum of all lse_local - sum logit / total_classes) + # We want 0.9 * (lse - predicted_logit) + 0.1 * (lse - sum logit / total_classes). + rank_per_sample = torch.div(labels, vocab_size, rounding_mode='floor') + lse_local = lse_allgather[rank_per_sample, + torch.arange(batch, device=lse_allgather.device)] + + handle_losses.wait() + if smoothing == 0.0: + losses += lse - lse_local + else: + losses += ((1 - smoothing) * (lse - lse_local) + + smoothing * (lse - lse_allgather.sum(dim=0))) + losses.masked_fill_(ignored_mask, 0) + + ctx.save_for_backward(logits, lse, labels_local) + ctx.smoothing = smoothing + ctx.ignored_index = ignored_index + ctx.inplace_backward = inplace_backward + return losses + + @staticmethod + def backward(ctx, grad_loss): + logits, lse, labels = ctx.saved_tensors + grad_loss = grad_loss.contiguous() + grad_loss.masked_fill_(labels==ctx.ignored_index, 0) + grad_logits = xentropy_cuda_lib.backward(grad_loss, logits, lse, labels, + ctx.smoothing, ctx.inplace_backward, + ctx.total_classes) + return grad_logits, None, None, None, None, None, None + + +class CrossEntropyLoss(nn.Module): + + def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0, + inplace_backward=False): + super().__init__() + if reduction not in ['mean', 'none']: + raise NotImplementedError("Only support reduction = 'mean' or 'none'") + self.ignore_index = ignore_index + self.reduction = reduction + self.label_smoothing = label_smoothing + self.inplace_backward = inplace_backward + + def forward(self, input, target, process_group=None): + assert input.is_cuda and target.is_cuda + # SoftmaxCrossEntropyLoss implicitly casts to float + loss = SoftmaxCrossEntropyLossFn.apply( + input, target, self.label_smoothing, self.ignore_index, self.inplace_backward, + process_group + ) + if self.reduction == 'mean': + return loss.sum() / (target != self.ignore_index).sum() + else: + return loss diff --git a/flash_attn/losses/cross_entropy_apex.py b/flash_attn/losses/cross_entropy_apex.py deleted file mode 100644 index ef70946..0000000 --- a/flash_attn/losses/cross_entropy_apex.py +++ /dev/null @@ -1,51 +0,0 @@ -import torch -import torch.nn as nn - -import xentropy_cuda_lib - - -# https://github.com/NVIDIA/apex/blob/master/apex/contrib/xentropy/softmax_xentropy.py -class SoftmaxCrossEntropyLossFn(torch.autograd.Function): - @staticmethod - def forward(ctx, logits, labels, smoothing=0.0, padding_idx=0, inplace_backward=False): - losses, max_log_sum_exp = xentropy_cuda_lib.forward( - logits, labels, smoothing) - losses.masked_fill_(labels==padding_idx, 0) - ctx.save_for_backward(logits, max_log_sum_exp, labels) - ctx.smoothing = smoothing - ctx.padding_idx = padding_idx - ctx.inplace_backward = inplace_backward - return losses - - @staticmethod - def backward(ctx, grad_loss): - logits, max_log_sum_exp, labels = ctx.saved_tensors - if not grad_loss.is_contiguous(): - grad_loss = grad_loss.contiguous() - grad_loss.masked_fill_(labels==ctx.padding_idx, 0) - grad_logits = xentropy_cuda_lib.backward(grad_loss, logits, max_log_sum_exp, labels, - ctx.smoothing, ctx.inplace_backward) - return grad_logits, None, None, None, None - - -class CrossEntropyLossApex(nn.Module): - - def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0, - inplace_backward=False): - super().__init__() - if reduction not in ['mean', 'none']: - raise NotImplementedError("Only support reduction = 'mean' or 'none'") - self.ignore_index = ignore_index - self.reduction = reduction - self.label_smoothing = label_smoothing - self.inplace_backward = inplace_backward - - def forward(self, input, target): - assert input.is_cuda and target.is_cuda - # SoftmaxCrossEntropyLoss implicitly casts to float - loss = SoftmaxCrossEntropyLossFn.apply(input, target, self.label_smoothing, - self.ignore_index, self.inplace_backward) - if self.reduction == 'mean': - return loss.sum() / (target != self.ignore_index).sum() - else: - return loss diff --git a/flash_attn/losses/cross_entropy_parallel.py b/flash_attn/losses/cross_entropy_parallel.py deleted file mode 100644 index b9f5c59..0000000 --- a/flash_attn/losses/cross_entropy_parallel.py +++ /dev/null @@ -1,122 +0,0 @@ -# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py -# But we make it much faster: we compute the local loss and the LSE, and by exchanging the LSE and -# the losses we can get the global loss. There's no need to do it step by step -# (compute local max, exchange, compute exp, compute local sum, exchange, etc.) -import torch -import torch.nn as nn - -import xentropy_cuda_lib - -from apex.transformer.parallel_state import get_tensor_model_parallel_group -from apex.transformer.parallel_state import get_tensor_model_parallel_rank -from apex.transformer.parallel_state import get_tensor_model_parallel_world_size -from apex.transformer.tensor_parallel.utils import VocabUtility - -# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for -# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent -# version of PyTorch. The following 4 lines are for backward compatibility with -# older PyTorch. -if "all_gather_into_tensor" not in dir(torch.distributed): - torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base -if "reduce_scatter_tensor" not in dir(torch.distributed): - torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base - - -class SoftmaxCrossEntropyLossParallelFn(torch.autograd.Function): - - @staticmethod - def forward(ctx, logits_parallel, labels, smoothing=0.0, ignored_index=-100, - inplace_backward=False): - """ - logits_parallel: (batch, vocab_size / world_size) - labels: (batch,) - """ - assert smoothing == 0.0, 'smoothing != 0.0 is not yet implemented, file an issue if you need it' - batch, partition_vocab_size = logits_parallel.shape - assert labels.shape == (batch,) - rank = get_tensor_model_parallel_rank() - world_size = get_tensor_model_parallel_world_size() - - if world_size == 1: - losses, lse = xentropy_cuda_lib.forward(logits_parallel, labels, smoothing) - losses.masked_fill_(labels==ignored_index, 0) - labels_local = labels - else: - vocab_start_index, vocab_end_index = VocabUtility.vocab_range_from_per_partition_vocab_size( - partition_vocab_size, get_tensor_model_parallel_rank(), - get_tensor_model_parallel_world_size() - ) - - # Create a mask of valid vocab ids (1 means it needs to be masked). - labels_mask = (labels < vocab_start_index) | (labels >= vocab_end_index) - ignored_mask = labels == ignored_index - labels_local = torch.where(ignored_mask, labels, labels - vocab_start_index) - masked_labels = labels_local.clone() - masked_labels[labels_mask] = ignored_index - - losses, lse_local = xentropy_cuda_lib.forward(logits_parallel, masked_labels, smoothing) - assert lse_local.shape == (batch,) - assert losses.shape == (batch,) - losses.masked_fill_(masked_labels==ignored_index, 0) - - lse_allgather = torch.empty(world_size, batch, dtype=lse_local.dtype, - device=lse_local.device) - handle_lse = torch.distributed.all_gather_into_tensor( - lse_allgather, lse_local.contiguous(), - group=get_tensor_model_parallel_group(), async_op=True - ) - handle_losses = torch.distributed.all_reduce( - losses, op=torch.distributed.ReduceOp.SUM, - group=get_tensor_model_parallel_group(), async_op=True - ) - handle_lse.wait() - lse = torch.logsumexp(lse_allgather, dim=0) - # The losses are going to be lse_local - predicted_logit, we just have to subtract - # the lse_local and add the lse (global). - rank_per_sample = torch.div(labels, partition_vocab_size, rounding_mode='floor') - lse_local = lse_allgather[rank_per_sample, - torch.arange(batch, device=lse_allgather.device)] - - handle_losses.wait() - losses += lse - lse_local - losses.masked_fill_(ignored_mask, 0) - - ctx.save_for_backward(logits_parallel, lse, labels_local) - ctx.smoothing = smoothing - ctx.ignored_index = ignored_index - ctx.inplace_backward = inplace_backward - return losses - - @staticmethod - def backward(ctx, grad_loss): - logits_parallel, lse, labels = ctx.saved_tensors - if not grad_loss.is_contiguous(): - grad_loss = grad_loss.contiguous() - grad_loss.masked_fill_(labels==ctx.ignored_index, 0) - grad_logits = xentropy_cuda_lib.backward(grad_loss, logits_parallel, lse, labels, - ctx.smoothing, ctx.inplace_backward) - return grad_logits, None, None, None, None, None - - -class CrossEntropyLossParallel(nn.Module): - - def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0, - inplace_backward=False): - super().__init__() - if reduction not in ['mean', 'none']: - raise NotImplementedError("Only support reduction = 'mean' or 'none'") - self.ignore_index = ignore_index - self.reduction = reduction - self.label_smoothing = label_smoothing - self.inplace_backward = inplace_backward - - def forward(self, input, target): - assert input.is_cuda and target.is_cuda - # SoftmaxCrossEntropyLoss implicitly casts to float - loss = SoftmaxCrossEntropyLossParallelFn.apply( - input, target, self.label_smoothing, self.ignore_index, self.inplace_backward - ) - if self.reduction == 'mean': - return loss.sum() / (target != self.ignore_index).sum() - else: - return loss diff --git a/flash_attn/models/bert.py b/flash_attn/models/bert.py index d262f9f..b08769e 100644 --- a/flash_attn/models/bert.py +++ b/flash_attn/models/bert.py @@ -40,9 +40,9 @@ except ImportError: dropout_add_layer_norm, layer_norm = None, None try: - from flash_attn.losses.cross_entropy_apex import CrossEntropyLossApex + from flash_attn.losses.cross_entropy import CrossEntropyLoss except ImportError: - CrossEntropyLossApex = None + CrossEntropyLoss = None logger = logging.getLogger(__name__) @@ -374,10 +374,10 @@ class BertForPreTraining(BertPreTrainedModel): if self.last_layer_subset: assert self.dense_seq_output, 'last_layer_subset requires dense_seq_output' use_xentropy = getattr(config, 'use_xentropy', False) - if use_xentropy and CrossEntropyLossApex is None: + if use_xentropy and CrossEntropyLoss is None: raise ImportError('xentropy_cuda is not installed') loss_cls = (nn.CrossEntropyLoss if not use_xentropy - else partial(CrossEntropyLossApex, inplace_backward=True)) + else partial(CrossEntropyLoss, inplace_backward=True)) self.bert = BertModel(config) self.cls = BertPreTrainingHeads(config) diff --git a/tests/losses/test_cross_entropy_apex.py b/tests/losses/test_cross_entropy.py similarity index 77% rename from tests/losses/test_cross_entropy_apex.py rename to tests/losses/test_cross_entropy.py index 646b539..76a1eec 100644 --- a/tests/losses/test_cross_entropy_apex.py +++ b/tests/losses/test_cross_entropy.py @@ -6,7 +6,7 @@ import pytest from einops import rearrange -from flass_attn.losses.cross_entropy_apex import CrossEntropyLossApex +from flash_attn.losses.cross_entropy import CrossEntropyLossApex is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8 @@ -15,8 +15,9 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8 # @pytest.mark.parametrize('dtype', [torch.float16]) @pytest.mark.parametrize('inplace_backward', [False, True]) # @pytest.mark.parametrize('inplace_backward', [False]) +@pytest.mark.parametrize('smoothing', [0.0, 0.9]) @pytest.mark.parametrize('vocab_size', [50257]) -def test_cross_entropy_loss_apex(vocab_size, inplace_backward, dtype): +def test_cross_entropy_loss_apex(vocab_size, smoothing, inplace_backward, dtype): device = 'cuda' rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4) # set seed @@ -27,8 +28,8 @@ def test_cross_entropy_loss_apex(vocab_size, inplace_backward, dtype): x = x_pt.detach().clone().requires_grad_() y = torch.randint(0, vocab_size, (batch_size * seqlen,), dtype=torch.long, device=device) y[torch.randperm(batch_size * seqlen)[:10]] = -100 - model_pt = torch.nn.CrossEntropyLoss() - model = CrossEntropyLossApex(inplace_backward=inplace_backward) + model_pt = torch.nn.CrossEntropyLoss(label_smoothing=smoothing) + model = CrossEntropyLossApex(label_smoothing=smoothing, inplace_backward=inplace_backward) out = model(x, y) out_pt = model_pt(x_pt.float(), y) assert torch.allclose(out, out_pt, rtol=rtol, atol=atol) diff --git a/tests/losses/test_cross_entropy_parallel.py b/tests/losses/test_cross_entropy_parallel.py index 1f9ef45..a01284e 100644 --- a/tests/losses/test_cross_entropy_parallel.py +++ b/tests/losses/test_cross_entropy_parallel.py @@ -10,19 +10,21 @@ import pytest from apex.transformer import parallel_state from apex.transformer import tensor_parallel -from flash_attn.losses.cross_entropy_parallel import CrossEntropyLossParallel +from flash_attn.losses.cross_entropy import CrossEntropyLoss is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8 @pytest.mark.parametrize('dtype', [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else [])) -# @pytest.mark.parametrize('dtype', [torch.bfloat16]) +# @pytest.mark.parametrize('dtype', [torch.float16]) @pytest.mark.parametrize('inplace_backward', [False, True]) # @pytest.mark.parametrize('inplace_backward', [False]) +@pytest.mark.parametrize('smoothing', [0.0, 0.9]) +# @pytest.mark.parametrize('smoothing', [0.9]) @pytest.mark.parametrize('vocab_size', [50264]) @pytest.mark.parametrize('world_size', [1, 2, 4, 8]) # @pytest.mark.parametrize('world_size', [2]) -def test_cross_entropy_loss_apex(vocab_size, world_size, inplace_backward, dtype): +def test_cross_entropy_loss_apex(vocab_size, world_size, smoothing, inplace_backward, dtype): assert vocab_size % world_size == 0 rtol, atol = ((1e-5, 1e-6) if dtype == torch.float32 else ((1e-3, 1e-4) if dtype == torch.float16 else (1e-2, 3e-3))) @@ -42,9 +44,10 @@ def test_cross_entropy_loss_apex(vocab_size, world_size, inplace_backward, dtype x = tensor_parallel.scatter_to_tensor_model_parallel_region(x_pt).detach().clone().requires_grad_() y = torch.randint(0, vocab_size, (batch_size * seqlen,), dtype=torch.long, device=device) y[torch.randperm(batch_size * seqlen)[:10]] = -100 - model_pt = torch.nn.CrossEntropyLoss(reduction='none') - model = CrossEntropyLossParallel(reduction='none', inplace_backward=inplace_backward) - out = model(x, y) + model_pt = torch.nn.CrossEntropyLoss(label_smoothing=smoothing, reduction='none') + model = CrossEntropyLoss(label_smoothing=smoothing, reduction='none', + inplace_backward=inplace_backward) + out = model(x, y, process_group=parallel_state.get_tensor_model_parallel_group()) out_pt = model_pt(x_pt.float(), y) assert torch.allclose(out, out_pt, rtol=1e-5, atol=1e-6) diff --git a/training/configs/experiment/owt/base.yaml b/training/configs/experiment/owt/base.yaml index 988e186..801f7e4 100644 --- a/training/configs/experiment/owt/base.yaml +++ b/training/configs/experiment/owt/base.yaml @@ -54,7 +54,7 @@ train: loss_fn: # This is faster and uses less memory than torch.nn.CrossEntropyLoss. # It's also more numerically stable if we're using DeepSpeed 16 bits. - _target_: src.losses.cross_entropy_apex.CrossEntropyLossApex + _target_: src.losses.cross_entropy.CrossEntropyLoss inplace_backward: True # to save memory eval: diff --git a/training/configs/experiment/pile/base.yaml b/training/configs/experiment/pile/base.yaml index ce46efd..a509c13 100644 --- a/training/configs/experiment/pile/base.yaml +++ b/training/configs/experiment/pile/base.yaml @@ -54,7 +54,7 @@ train: loss_fn: # This is faster and uses less memory than torch.nn.CrossEntropyLoss. # It's also more numerically stable if we're using DeepSpeed 16 bits. - _target_: src.losses.cross_entropy_apex.CrossEntropyLossApex + _target_: src.losses.cross_entropy.CrossEntropyLoss inplace_backward: True # to save memory eval: diff --git a/training/src/losses/cross_entropy.py b/training/src/losses/cross_entropy.py new file mode 100644 index 0000000..48e2f2f --- /dev/null +++ b/training/src/losses/cross_entropy.py @@ -0,0 +1,128 @@ +# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py +# But we make it much faster: we compute the local loss and the LSE, and by exchanging the LSE and +# the losses we can get the global loss. There's no need to do it step by step +# (compute local max, exchange, compute exp, compute local sum, exchange, etc.) +# The original xentropy interface is here: https://github.com/NVIDIA/apex/blob/master/apex/contrib/xentropy/softmax_xentropy.py +import torch +import torch.nn as nn + +import xentropy_cuda_lib + +# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for +# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent +# version of PyTorch. The following 2 lines are for backward compatibility with +# older PyTorch. +if "all_gather_into_tensor" not in dir(torch.distributed): + torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base + + +class SoftmaxCrossEntropyLossFn(torch.autograd.Function): + + @staticmethod + def forward(ctx, logits, labels, smoothing=0.0, ignored_index=-100, inplace_backward=False, + process_group=None): + """ + logits: (batch, vocab_size) + labels: (batch,) + If process_group is not None, we're doing Tensor Parallel: each process is responsible for + one part of the vocab. The loss needs to be aggregated across processes. + """ + batch, vocab_size = logits.shape + assert labels.shape == (batch,) + world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group) + ctx.total_classes = world_size * vocab_size + + if world_size == 1: + losses, lse = xentropy_cuda_lib.forward(logits, labels, smoothing) + losses.masked_fill_(labels==ignored_index, 0) + labels_local = labels + else: + rank = torch.distributed.get_rank(process_group) + vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size + + # Create a mask of valid vocab ids (1 means it needs to be masked). + labels_mask = (labels < vocab_start_index) | (labels >= vocab_end_index) + ignored_mask = labels == ignored_index + labels_local = torch.where(ignored_mask, labels, labels - vocab_start_index) + + # For tensor parallel cross entropy with smoothing, we want to pass in the total number + # of classes so that smoothing can be applied correctly. If total_classes=-1, use the + # last dimension of the input tensor. + losses, lse_local = xentropy_cuda_lib.forward(logits, labels_local, smoothing, + world_size * vocab_size) + assert lse_local.shape == (batch,) + assert losses.shape == (batch,) + losses.masked_fill_(ignored_mask, 0) + # For labels == ignored_index, the loss is always 0. + # If there's no smoothing, if labels are in the vocab of this partition, losses contains + # lse_local - predicted logit, and 0 otherwise. + # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains + # 0.9 * (lse_local - predicted logit) + 0.1 * (lse_local - sum logit / total_classes) + # For labels not in the vocab of this partition, losses contains + # 0.1 * (lse_local - sum logit / total_classes). + + lse_allgather = torch.empty(world_size, batch, dtype=lse_local.dtype, + device=lse_local.device) + torch.distributed.all_gather_into_tensor(lse_allgather, lse_local.contiguous(), + group=process_group) + handle_losses = torch.distributed.all_reduce( + losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True + ) + lse = torch.logsumexp(lse_allgather, dim=0) + # If there's no smoothing, the total losses are lse_local - predicted_logit, + # we just have to subtract the lse_local and add the lse (global). + # If there's smoothing=0.1, the total losses are + # 0.9 * (lse_local - predicted_logit) + 0.1 * (sum of all lse_local - sum logit / total_classes) + # We want 0.9 * (lse - predicted_logit) + 0.1 * (lse - sum logit / total_classes). + rank_per_sample = torch.div(labels, vocab_size, rounding_mode='floor') + lse_local = lse_allgather[rank_per_sample, + torch.arange(batch, device=lse_allgather.device)] + + handle_losses.wait() + if smoothing == 0.0: + losses += lse - lse_local + else: + losses += ((1 - smoothing) * (lse - lse_local) + + smoothing * (lse - lse_allgather.sum(dim=0))) + losses.masked_fill_(ignored_mask, 0) + + ctx.save_for_backward(logits, lse, labels_local) + ctx.smoothing = smoothing + ctx.ignored_index = ignored_index + ctx.inplace_backward = inplace_backward + return losses + + @staticmethod + def backward(ctx, grad_loss): + logits, lse, labels = ctx.saved_tensors + grad_loss = grad_loss.contiguous() + grad_loss.masked_fill_(labels==ctx.ignored_index, 0) + grad_logits = xentropy_cuda_lib.backward(grad_loss, logits, lse, labels, + ctx.smoothing, ctx.inplace_backward, + ctx.total_classes) + return grad_logits, None, None, None, None, None, None + + +class CrossEntropyLoss(nn.Module): + + def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0, + inplace_backward=False): + super().__init__() + if reduction not in ['mean', 'none']: + raise NotImplementedError("Only support reduction = 'mean' or 'none'") + self.ignore_index = ignore_index + self.reduction = reduction + self.label_smoothing = label_smoothing + self.inplace_backward = inplace_backward + + def forward(self, input, target, process_group=None): + assert input.is_cuda and target.is_cuda + # SoftmaxCrossEntropyLoss implicitly casts to float + loss = SoftmaxCrossEntropyLossFn.apply( + input, target, self.label_smoothing, self.ignore_index, self.inplace_backward, + process_group + ) + if self.reduction == 'mean': + return loss.sum() / (target != self.ignore_index).sum() + else: + return loss diff --git a/training/src/losses/cross_entropy_apex.py b/training/src/losses/cross_entropy_apex.py deleted file mode 100644 index ef70946..0000000 --- a/training/src/losses/cross_entropy_apex.py +++ /dev/null @@ -1,51 +0,0 @@ -import torch -import torch.nn as nn - -import xentropy_cuda_lib - - -# https://github.com/NVIDIA/apex/blob/master/apex/contrib/xentropy/softmax_xentropy.py -class SoftmaxCrossEntropyLossFn(torch.autograd.Function): - @staticmethod - def forward(ctx, logits, labels, smoothing=0.0, padding_idx=0, inplace_backward=False): - losses, max_log_sum_exp = xentropy_cuda_lib.forward( - logits, labels, smoothing) - losses.masked_fill_(labels==padding_idx, 0) - ctx.save_for_backward(logits, max_log_sum_exp, labels) - ctx.smoothing = smoothing - ctx.padding_idx = padding_idx - ctx.inplace_backward = inplace_backward - return losses - - @staticmethod - def backward(ctx, grad_loss): - logits, max_log_sum_exp, labels = ctx.saved_tensors - if not grad_loss.is_contiguous(): - grad_loss = grad_loss.contiguous() - grad_loss.masked_fill_(labels==ctx.padding_idx, 0) - grad_logits = xentropy_cuda_lib.backward(grad_loss, logits, max_log_sum_exp, labels, - ctx.smoothing, ctx.inplace_backward) - return grad_logits, None, None, None, None - - -class CrossEntropyLossApex(nn.Module): - - def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0, - inplace_backward=False): - super().__init__() - if reduction not in ['mean', 'none']: - raise NotImplementedError("Only support reduction = 'mean' or 'none'") - self.ignore_index = ignore_index - self.reduction = reduction - self.label_smoothing = label_smoothing - self.inplace_backward = inplace_backward - - def forward(self, input, target): - assert input.is_cuda and target.is_cuda - # SoftmaxCrossEntropyLoss implicitly casts to float - loss = SoftmaxCrossEntropyLossFn.apply(input, target, self.label_smoothing, - self.ignore_index, self.inplace_backward) - if self.reduction == 'mean': - return loss.sum() / (target != self.ignore_index).sum() - else: - return loss diff --git a/training/src/losses/cross_entropy_parallel.py b/training/src/losses/cross_entropy_parallel.py deleted file mode 100644 index 84fe82d..0000000 --- a/training/src/losses/cross_entropy_parallel.py +++ /dev/null @@ -1,112 +0,0 @@ -# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py -# But we make it much faster: we compute the local loss and the LSE, and by exchanging the LSE and -# the losses we can get the global loss. There's no need to do it step by step -# (compute local max, exchange, compute exp, compute local sum, exchange, etc.) -import torch -import torch.nn as nn - -import xentropy_cuda_lib - -from apex.transformer.parallel_state import get_tensor_model_parallel_group -from apex.transformer.parallel_state import get_tensor_model_parallel_rank -from apex.transformer.parallel_state import get_tensor_model_parallel_world_size -from apex.transformer.tensor_parallel.utils import VocabUtility - -# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for -# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent -# version of PyTorch. The following 4 lines are for backward comparability with -# older PyTorch. -if "all_gather_into_tensor" not in dir(torch.distributed): - torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base -if "reduce_scatter_tensor" not in dir(torch.distributed): - torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base - - -class SoftmaxCrossEntropyLossParallelFn(torch.autograd.Function): - - @staticmethod - def forward(ctx, logits_parallel, labels, smoothing=0.0, ignored_index=-100, - inplace_backward=False): - """ - logits_parallel: (batch, vocab_size / world_size) - labels: (batch,) - """ - assert smoothing == 0.0, 'smoothing != 0.0 is not yet implemented, file an issue if you need it' - batch, partition_vocab_size = logits_parallel.shape - assert labels.shape == (batch,) - rank = get_tensor_model_parallel_rank() - world_size = get_tensor_model_parallel_world_size() - vocab_start_index, vocab_end_index = VocabUtility.vocab_range_from_per_partition_vocab_size( - partition_vocab_size, get_tensor_model_parallel_rank(), - get_tensor_model_parallel_world_size() - ) - - # Create a mask of valid vocab ids (1 means it needs to be masked). - labels_mask = (labels < vocab_start_index) | (labels >= vocab_end_index) - ignored_mask = labels == ignored_index - labels_local = torch.where(ignored_mask, labels, labels - vocab_start_index) - masked_labels = labels_local.clone() - masked_labels[labels_mask] = ignored_index - - losses, lse_local = xentropy_cuda_lib.forward(logits_parallel, masked_labels, smoothing) - assert lse_local.shape == (batch,) - assert losses.shape == (batch,) - losses.masked_fill_(masked_labels==ignored_index, 0) - - if world_size > 1: - lse_allgather = torch.empty(world_size, batch, dtype=lse_local.dtype, - device=lse_local.device) - torch.distributed.all_gather_into_tensor(lse_allgather, lse_local.contiguous(), - group=get_tensor_model_parallel_group()) - lse = torch.logsumexp(lse_allgather, dim=0) - torch.distributed.all_reduce(losses, op=torch.distributed.ReduceOp.SUM, - group=get_tensor_model_parallel_group()) - # The losses are currently lse_local - predicted_logit, we just have to subtract the - # lse_local and add the lse (global). - rank_per_sample = labels // partition_vocab_size - lse_local = lse_allgather[rank_per_sample, - torch.arange(batch, device=lse_allgather.device)] - losses += lse - lse_local - losses.masked_fill_(ignored_mask, 0) - else: - lse = lse_local - - ctx.save_for_backward(logits_parallel, lse, labels_local) - ctx.smoothing = smoothing - ctx.ignored_index = ignored_index - ctx.inplace_backward = inplace_backward - return losses - - @staticmethod - def backward(ctx, grad_loss): - logits_parallel, lse, labels = ctx.saved_tensors - if not grad_loss.is_contiguous(): - grad_loss = grad_loss.contiguous() - grad_loss.masked_fill_(labels==ctx.ignored_index, 0) - grad_logits = xentropy_cuda_lib.backward(grad_loss, logits_parallel, lse, labels, - ctx.smoothing, ctx.inplace_backward) - return grad_logits, None, None, None, None, None - - -class CrossEntropyLossParallel(nn.Module): - - def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0, - inplace_backward=False): - super().__init__() - if reduction not in ['mean', 'none']: - raise NotImplementedError("Only support reduction = 'mean' or 'none'") - self.ignore_index = ignore_index - self.reduction = reduction - self.label_smoothing = label_smoothing - self.inplace_backward = inplace_backward - - def forward(self, input, target): - assert input.is_cuda and target.is_cuda - # SoftmaxCrossEntropyLoss implicitly casts to float - loss = SoftmaxCrossEntropyLossParallelFn.apply( - input, target, self.label_smoothing, self.ignore_index, self.inplace_backward - ) - if self.reduction == 'mean': - return loss.sum() / (target != self.ignore_index).sum() - else: - return loss diff --git a/training/src/metrics/perplexity.py b/training/src/metrics/perplexity.py index 9e79a4b..9570f07 100644 --- a/training/src/metrics/perplexity.py +++ b/training/src/metrics/perplexity.py @@ -11,7 +11,7 @@ from torch import Tensor from torchmetrics import Metric try: - from src.losses.cross_entropy_apex import CrossEntropyLossApex as CrossEntropyLoss + from src.losses.cross_entropy import CrossEntropyLoss except ImportError: CrossEntropyLoss = torch.nn.CrossEntropyLoss