[Loss] Use flash_attn.losses.cross_entropy.CrossEntropyLoss
This commit is contained in:
parent
cadfa396b8
commit
71befc19e1
@ -54,7 +54,7 @@ train:
|
|||||||
loss_fn:
|
loss_fn:
|
||||||
# This is faster and uses less memory than torch.nn.CrossEntropyLoss.
|
# This is faster and uses less memory than torch.nn.CrossEntropyLoss.
|
||||||
# It's also more numerically stable if we're using DeepSpeed 16 bits.
|
# It's also more numerically stable if we're using DeepSpeed 16 bits.
|
||||||
_target_: src.losses.cross_entropy.CrossEntropyLoss
|
_target_: flash_attn.losses.cross_entropy.CrossEntropyLoss
|
||||||
inplace_backward: True # to save memory
|
inplace_backward: True # to save memory
|
||||||
|
|
||||||
eval:
|
eval:
|
||||||
|
|||||||
@ -54,7 +54,7 @@ train:
|
|||||||
loss_fn:
|
loss_fn:
|
||||||
# This is faster and uses less memory than torch.nn.CrossEntropyLoss.
|
# This is faster and uses less memory than torch.nn.CrossEntropyLoss.
|
||||||
# It's also more numerically stable if we're using DeepSpeed 16 bits.
|
# It's also more numerically stable if we're using DeepSpeed 16 bits.
|
||||||
_target_: src.losses.cross_entropy.CrossEntropyLoss
|
_target_: flash_attn.losses.cross_entropy.CrossEntropyLoss
|
||||||
inplace_backward: True # to save memory
|
inplace_backward: True # to save memory
|
||||||
|
|
||||||
eval:
|
eval:
|
||||||
|
|||||||
@ -1,129 +0,0 @@
|
|||||||
# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py
|
|
||||||
# But we make it much faster: we compute the local loss and the LSE, and by exchanging the LSE and
|
|
||||||
# the losses we can get the global loss. There's no need to do it step by step
|
|
||||||
# (compute local max, exchange, compute exp, compute local sum, exchange, etc.)
|
|
||||||
# The original xentropy interface is here: https://github.com/NVIDIA/apex/blob/master/apex/contrib/xentropy/softmax_xentropy.py
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
import xentropy_cuda_lib
|
|
||||||
|
|
||||||
# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
|
|
||||||
# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
|
|
||||||
# version of PyTorch. The following 2 lines are for backward compatibility with
|
|
||||||
# older PyTorch.
|
|
||||||
if "all_gather_into_tensor" not in dir(torch.distributed):
|
|
||||||
torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
|
|
||||||
|
|
||||||
|
|
||||||
class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def forward(ctx, logits, labels, smoothing=0.0, ignored_index=-100, inplace_backward=False,
|
|
||||||
process_group=None):
|
|
||||||
"""
|
|
||||||
logits: (batch, vocab_size)
|
|
||||||
labels: (batch,)
|
|
||||||
If process_group is not None, we're doing Tensor Parallel: each process is responsible for
|
|
||||||
one part of the vocab. The loss needs to be aggregated across processes.
|
|
||||||
"""
|
|
||||||
batch, vocab_size = logits.shape
|
|
||||||
assert labels.shape == (batch,)
|
|
||||||
world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)
|
|
||||||
ctx.total_classes = world_size * vocab_size
|
|
||||||
|
|
||||||
if world_size == 1:
|
|
||||||
losses, lse = xentropy_cuda_lib.forward(logits, labels, smoothing)
|
|
||||||
losses.masked_fill_(labels==ignored_index, 0)
|
|
||||||
labels_local = labels
|
|
||||||
else:
|
|
||||||
rank = torch.distributed.get_rank(process_group)
|
|
||||||
vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size
|
|
||||||
|
|
||||||
# Create a mask of valid vocab ids (1 means it needs to be masked).
|
|
||||||
labels_mask = (labels < vocab_start_index) | (labels >= vocab_end_index)
|
|
||||||
ignored_mask = labels == ignored_index
|
|
||||||
labels_local = torch.where(ignored_mask, labels, labels - vocab_start_index)
|
|
||||||
|
|
||||||
# For tensor parallel cross entropy with smoothing, we want to pass in the total number
|
|
||||||
# of classes so that smoothing can be applied correctly. If total_classes=-1, use the
|
|
||||||
# last dimension of the input tensor.
|
|
||||||
losses, lse_local = xentropy_cuda_lib.forward(logits, labels_local, smoothing,
|
|
||||||
world_size * vocab_size)
|
|
||||||
assert lse_local.shape == (batch,)
|
|
||||||
assert losses.shape == (batch,)
|
|
||||||
losses.masked_fill_(ignored_mask, 0)
|
|
||||||
# For labels == ignored_index, the loss is always 0.
|
|
||||||
# If there's no smoothing, if labels are in the vocab of this partition, losses contains
|
|
||||||
# lse_local - predicted logit, and 0 otherwise.
|
|
||||||
# If there's smoothing=0.1, for labels in the vocab of this partition, losses contains
|
|
||||||
# 0.9 * (lse_local - predicted logit) + 0.1 * (lse_local - sum logit / total_classes)
|
|
||||||
# For labels not in the vocab of this partition, losses contains
|
|
||||||
# 0.1 * (lse_local - sum logit / total_classes).
|
|
||||||
|
|
||||||
lse_allgather = torch.empty(world_size, batch, dtype=lse_local.dtype,
|
|
||||||
device=lse_local.device)
|
|
||||||
torch.distributed.all_gather_into_tensor(lse_allgather, lse_local.contiguous(),
|
|
||||||
group=process_group)
|
|
||||||
handle_losses = torch.distributed.all_reduce(
|
|
||||||
losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
|
|
||||||
)
|
|
||||||
lse = torch.logsumexp(lse_allgather, dim=0)
|
|
||||||
# If there's no smoothing, the total losses are lse_local - predicted_logit,
|
|
||||||
# we just have to subtract the lse_local and add the lse (global).
|
|
||||||
# If there's smoothing=0.1, the total losses are
|
|
||||||
# 0.9 * (lse_local - predicted_logit) + 0.1 * (sum of all lse_local - sum logit / total_classes)
|
|
||||||
# We want 0.9 * (lse - predicted_logit) + 0.1 * (lse - sum logit / total_classes).
|
|
||||||
rank_per_sample = torch.div(labels, vocab_size, rounding_mode='floor')
|
|
||||||
lse_local = lse_allgather[rank_per_sample,
|
|
||||||
torch.arange(batch, device=lse_allgather.device)]
|
|
||||||
|
|
||||||
handle_losses.wait()
|
|
||||||
if smoothing == 0.0:
|
|
||||||
losses += lse - lse_local
|
|
||||||
else:
|
|
||||||
losses += ((1 - smoothing) * (lse - lse_local)
|
|
||||||
+ smoothing * (lse - lse_allgather.sum(dim=0)))
|
|
||||||
losses.masked_fill_(ignored_mask, 0)
|
|
||||||
|
|
||||||
ctx.save_for_backward(logits, lse, labels_local)
|
|
||||||
ctx.smoothing = smoothing
|
|
||||||
ctx.ignored_index = ignored_index
|
|
||||||
ctx.inplace_backward = inplace_backward
|
|
||||||
return losses
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def backward(ctx, grad_loss):
|
|
||||||
logits, lse, labels = ctx.saved_tensors
|
|
||||||
grad_loss = grad_loss.contiguous()
|
|
||||||
grad_loss.masked_fill_(labels==ctx.ignored_index, 0)
|
|
||||||
grad_logits = xentropy_cuda_lib.backward(grad_loss, logits, lse, labels,
|
|
||||||
ctx.smoothing, ctx.inplace_backward,
|
|
||||||
ctx.total_classes)
|
|
||||||
return grad_logits, None, None, None, None, None, None
|
|
||||||
|
|
||||||
|
|
||||||
class CrossEntropyLoss(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0,
|
|
||||||
inplace_backward=False, process_group=None):
|
|
||||||
super().__init__()
|
|
||||||
if reduction not in ['mean', 'none']:
|
|
||||||
raise NotImplementedError("Only support reduction = 'mean' or 'none'")
|
|
||||||
self.ignore_index = ignore_index
|
|
||||||
self.reduction = reduction
|
|
||||||
self.label_smoothing = label_smoothing
|
|
||||||
self.inplace_backward = inplace_backward
|
|
||||||
self.process_group = process_group
|
|
||||||
|
|
||||||
def forward(self, input, target):
|
|
||||||
assert input.is_cuda and target.is_cuda
|
|
||||||
# SoftmaxCrossEntropyLoss implicitly casts to float
|
|
||||||
loss = SoftmaxCrossEntropyLossFn.apply(
|
|
||||||
input, target, self.label_smoothing, self.ignore_index, self.inplace_backward,
|
|
||||||
self.process_group
|
|
||||||
)
|
|
||||||
if self.reduction == 'mean':
|
|
||||||
return loss.sum() / (target != self.ignore_index).sum()
|
|
||||||
else:
|
|
||||||
return loss
|
|
||||||
@ -11,7 +11,7 @@ from torch import Tensor
|
|||||||
from torchmetrics import Metric
|
from torchmetrics import Metric
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from src.losses.cross_entropy import CrossEntropyLoss
|
from flash_attn.losses.cross_entropy import CrossEntropyLoss
|
||||||
except ImportError:
|
except ImportError:
|
||||||
CrossEntropyLoss = torch.nn.CrossEntropyLoss
|
CrossEntropyLoss = torch.nn.CrossEntropyLoss
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user