diff --git a/training/configs/experiment/owt/base.yaml b/training/configs/experiment/owt/base.yaml index 801f7e4..48de7a2 100644 --- a/training/configs/experiment/owt/base.yaml +++ b/training/configs/experiment/owt/base.yaml @@ -54,7 +54,7 @@ train: loss_fn: # This is faster and uses less memory than torch.nn.CrossEntropyLoss. # It's also more numerically stable if we're using DeepSpeed 16 bits. - _target_: src.losses.cross_entropy.CrossEntropyLoss + _target_: flash_attn.losses.cross_entropy.CrossEntropyLoss inplace_backward: True # to save memory eval: diff --git a/training/configs/experiment/pile/base.yaml b/training/configs/experiment/pile/base.yaml index a509c13..293aa91 100644 --- a/training/configs/experiment/pile/base.yaml +++ b/training/configs/experiment/pile/base.yaml @@ -54,7 +54,7 @@ train: loss_fn: # This is faster and uses less memory than torch.nn.CrossEntropyLoss. # It's also more numerically stable if we're using DeepSpeed 16 bits. - _target_: src.losses.cross_entropy.CrossEntropyLoss + _target_: flash_attn.losses.cross_entropy.CrossEntropyLoss inplace_backward: True # to save memory eval: diff --git a/training/src/losses/cross_entropy.py b/training/src/losses/cross_entropy.py deleted file mode 100644 index bc2df84..0000000 --- a/training/src/losses/cross_entropy.py +++ /dev/null @@ -1,129 +0,0 @@ -# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py -# But we make it much faster: we compute the local loss and the LSE, and by exchanging the LSE and -# the losses we can get the global loss. There's no need to do it step by step -# (compute local max, exchange, compute exp, compute local sum, exchange, etc.) -# The original xentropy interface is here: https://github.com/NVIDIA/apex/blob/master/apex/contrib/xentropy/softmax_xentropy.py -import torch -import torch.nn as nn - -import xentropy_cuda_lib - -# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for -# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent -# version of PyTorch. The following 2 lines are for backward compatibility with -# older PyTorch. -if "all_gather_into_tensor" not in dir(torch.distributed): - torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base - - -class SoftmaxCrossEntropyLossFn(torch.autograd.Function): - - @staticmethod - def forward(ctx, logits, labels, smoothing=0.0, ignored_index=-100, inplace_backward=False, - process_group=None): - """ - logits: (batch, vocab_size) - labels: (batch,) - If process_group is not None, we're doing Tensor Parallel: each process is responsible for - one part of the vocab. The loss needs to be aggregated across processes. - """ - batch, vocab_size = logits.shape - assert labels.shape == (batch,) - world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group) - ctx.total_classes = world_size * vocab_size - - if world_size == 1: - losses, lse = xentropy_cuda_lib.forward(logits, labels, smoothing) - losses.masked_fill_(labels==ignored_index, 0) - labels_local = labels - else: - rank = torch.distributed.get_rank(process_group) - vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size - - # Create a mask of valid vocab ids (1 means it needs to be masked). - labels_mask = (labels < vocab_start_index) | (labels >= vocab_end_index) - ignored_mask = labels == ignored_index - labels_local = torch.where(ignored_mask, labels, labels - vocab_start_index) - - # For tensor parallel cross entropy with smoothing, we want to pass in the total number - # of classes so that smoothing can be applied correctly. If total_classes=-1, use the - # last dimension of the input tensor. - losses, lse_local = xentropy_cuda_lib.forward(logits, labels_local, smoothing, - world_size * vocab_size) - assert lse_local.shape == (batch,) - assert losses.shape == (batch,) - losses.masked_fill_(ignored_mask, 0) - # For labels == ignored_index, the loss is always 0. - # If there's no smoothing, if labels are in the vocab of this partition, losses contains - # lse_local - predicted logit, and 0 otherwise. - # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains - # 0.9 * (lse_local - predicted logit) + 0.1 * (lse_local - sum logit / total_classes) - # For labels not in the vocab of this partition, losses contains - # 0.1 * (lse_local - sum logit / total_classes). - - lse_allgather = torch.empty(world_size, batch, dtype=lse_local.dtype, - device=lse_local.device) - torch.distributed.all_gather_into_tensor(lse_allgather, lse_local.contiguous(), - group=process_group) - handle_losses = torch.distributed.all_reduce( - losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True - ) - lse = torch.logsumexp(lse_allgather, dim=0) - # If there's no smoothing, the total losses are lse_local - predicted_logit, - # we just have to subtract the lse_local and add the lse (global). - # If there's smoothing=0.1, the total losses are - # 0.9 * (lse_local - predicted_logit) + 0.1 * (sum of all lse_local - sum logit / total_classes) - # We want 0.9 * (lse - predicted_logit) + 0.1 * (lse - sum logit / total_classes). - rank_per_sample = torch.div(labels, vocab_size, rounding_mode='floor') - lse_local = lse_allgather[rank_per_sample, - torch.arange(batch, device=lse_allgather.device)] - - handle_losses.wait() - if smoothing == 0.0: - losses += lse - lse_local - else: - losses += ((1 - smoothing) * (lse - lse_local) - + smoothing * (lse - lse_allgather.sum(dim=0))) - losses.masked_fill_(ignored_mask, 0) - - ctx.save_for_backward(logits, lse, labels_local) - ctx.smoothing = smoothing - ctx.ignored_index = ignored_index - ctx.inplace_backward = inplace_backward - return losses - - @staticmethod - def backward(ctx, grad_loss): - logits, lse, labels = ctx.saved_tensors - grad_loss = grad_loss.contiguous() - grad_loss.masked_fill_(labels==ctx.ignored_index, 0) - grad_logits = xentropy_cuda_lib.backward(grad_loss, logits, lse, labels, - ctx.smoothing, ctx.inplace_backward, - ctx.total_classes) - return grad_logits, None, None, None, None, None, None - - -class CrossEntropyLoss(nn.Module): - - def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0, - inplace_backward=False, process_group=None): - super().__init__() - if reduction not in ['mean', 'none']: - raise NotImplementedError("Only support reduction = 'mean' or 'none'") - self.ignore_index = ignore_index - self.reduction = reduction - self.label_smoothing = label_smoothing - self.inplace_backward = inplace_backward - self.process_group = process_group - - def forward(self, input, target): - assert input.is_cuda and target.is_cuda - # SoftmaxCrossEntropyLoss implicitly casts to float - loss = SoftmaxCrossEntropyLossFn.apply( - input, target, self.label_smoothing, self.ignore_index, self.inplace_backward, - self.process_group - ) - if self.reduction == 'mean': - return loss.sum() / (target != self.ignore_index).sum() - else: - return loss diff --git a/training/src/metrics/perplexity.py b/training/src/metrics/perplexity.py index 9570f07..b90cf21 100644 --- a/training/src/metrics/perplexity.py +++ b/training/src/metrics/perplexity.py @@ -11,7 +11,7 @@ from torch import Tensor from torchmetrics import Metric try: - from src.losses.cross_entropy import CrossEntropyLoss + from flash_attn.losses.cross_entropy import CrossEntropyLoss except ImportError: CrossEntropyLoss = torch.nn.CrossEntropyLoss