Add smoothing for CrossEntropyParallel, rename to CrossEntropyLoss

parent e68ebbe89a
commit dff68c2b22
@@ -4,7 +4,8 @@
std::vector<at::Tensor> softmax_xentropy_cuda(
    const at::Tensor &input,
    const at::Tensor &labels,
    const float smoothing);
    const float smoothing,
    const int total_classes);

at::Tensor softmax_xentropy_backward_cuda(
    const at::Tensor &grad_loss,
@@ -12,7 +13,8 @@ at::Tensor softmax_xentropy_backward_cuda(
    const at::Tensor &max_log_sum_exp,
    const at::Tensor &labels,
    const float smoothing,
    const bool inplace);
    const bool inplace,
    const int total_classes);

// C++ interface

@@ -23,11 +25,15 @@ at::Tensor softmax_xentropy_backward_cuda(
std::vector<at::Tensor> softmax_xentropy_forward(
    const at::Tensor &input,
    const at::Tensor &labels,
    const float smoothing) {
    CHECK_CUDA(input);
    const float smoothing,
    const int total_classes=-1) {
    // For tensor parallel cross entropy with smoothing, we want to pass in the total number
    // of classes so that smoothing can be applied correctly. If total_classes=-1, use the
    // last dimension of the input tensor.
    CHECK_INPUT(input);
    CHECK_INPUT(labels);

    return softmax_xentropy_cuda(input, labels, smoothing);
    return softmax_xentropy_cuda(input, labels, smoothing, total_classes);
}

at::Tensor softmax_xentropy_backward(
@@ -36,16 +42,18 @@ at::Tensor softmax_xentropy_backward(
    const at::Tensor &max_log_sum_exp,
    const at::Tensor &labels,
    const float smoothing,
    const bool inplace) {
    CHECK_CUDA(grad_loss);
    CHECK_CUDA(logits);
    const bool inplace,
    const int total_classes=-1) {
    CHECK_INPUT(grad_loss);
    CHECK_INPUT(logits);
    CHECK_INPUT(max_log_sum_exp);
    CHECK_INPUT(labels);

    return softmax_xentropy_backward_cuda(grad_loss, logits, max_log_sum_exp, labels, smoothing, inplace);
    return softmax_xentropy_backward_cuda(grad_loss, logits, max_log_sum_exp, labels,
                                          smoothing, inplace, total_classes);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &softmax_xentropy_forward, "Softmax cross entropy loss with label smoothing forward (CUDA)");
    m.def("backward", &softmax_xentropy_backward, "Softmax cross entropy loss with label smoothing backward (CUDA)");
    m.def("forward", &softmax_xentropy_forward, "Softmax cross entropy loss with label smoothing forward (CUDA)", py::arg("input"), py::arg("labels"), py::arg("smoothing"), py::arg("total_classes")=-1);
    m.def("backward", &softmax_xentropy_backward, "Softmax cross entropy loss with label smoothing backward (CUDA)", py::arg("grad_loss"), py::arg("logits"), py::arg("max_log_sum_exp"), py::arg("labels"), py::arg("smoothing"), py::arg("inplace"), py::arg("total_classes")=-1);
}
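For reference, a minimal sketch (not part of the commit) of how the extended binding can be driven directly from Python. The call signatures mirror the ones used by flash_attn/losses/cross_entropy.py below; the shapes and values here are illustrative.

# Minimal sketch of calling the extension with the new total_classes argument.
import torch
import xentropy_cuda_lib

batch, local_vocab, world_size = 4, 1024, 8
logits = torch.randn(batch, local_vocab, device='cuda', dtype=torch.float16)
labels = torch.randint(0, local_vocab, (batch,), device='cuda', dtype=torch.long)
smoothing = 0.1

# total_classes defaults to -1 (use logits.shape[-1]); pass the global vocab size
# when logits only hold this rank's partition of the vocabulary.
losses, lse = xentropy_cuda_lib.forward(logits, labels, smoothing, world_size * local_vocab)

grad_loss = torch.ones_like(losses)
grad_logits = xentropy_cuda_lib.backward(grad_loss, logits, lse, labels,
                                         smoothing, False, world_size * local_vocab)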

@@ -434,7 +434,8 @@ cunn_SoftMaxXEntropyForward(
    scalar_t *input,
    int64_t *labels,
    int64_t classes,
    const float smoothing)
    const float smoothing,
    const int total_classes)
{
    extern __shared__ unsigned char smem[];
    auto sdata = reinterpret_cast<accscalar_t*>(smem);
@@ -472,12 +473,8 @@ cunn_SoftMaxXEntropyForward(
    // reserve max + log_sum_exp for bprop
    if (threadIdx.x == 0) {
        accscalar_t lse = max_k + std::log(sumAll);
        if ((label >= 0) && (label < classes)) {
            accscalar_t log_prob = epilogue(static_cast<accscalar_t>(input[label]));
            losses[blockIdx.x] = (lse - sum_k / classes) * smoothing - log_prob * (1 - smoothing);
        } else {
            losses[blockIdx.x] = outscalar_t(0.f);
        }
        accscalar_t log_prob = (label >= 0 && label < classes) ? epilogue(static_cast<accscalar_t>(input[label])) : 0.f;
        losses[blockIdx.x] = (lse - sum_k / total_classes) * smoothing - log_prob * (1 - smoothing);
        max_log_sum_exp[blockIdx.x] = lse;
    }
}
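A hedged reading of the updated forward kernel: per row it computes the local LSE, the label's log-probability (zero when the label falls outside this rank's partition), and a smoothing term whose uniform component now divides by the global class count rather than the local one, so per-rank partial losses can later be summed. A rough PyTorch restatement with illustrative names, not the kernel itself:

import torch

def reference_forward(logits, labels, smoothing, total_classes):
    # logits: (batch, local_classes) for this rank's vocab partition; labels may fall
    # outside [0, local_classes) when the true class lives on another rank.
    lse = torch.logsumexp(logits.float(), dim=-1)            # max_k + log(sumAll) in the kernel
    in_vocab = (labels >= 0) & (labels < logits.shape[-1])
    safe = labels.clamp(min=0, max=logits.shape[-1] - 1)
    picked = logits.float().gather(-1, safe.unsqueeze(-1)).squeeze(-1) - lse
    log_prob = torch.where(in_vocab, picked, torch.zeros_like(lse))
    # The uniform (smoothing) term divides by the *global* number of classes.
    losses = (lse - logits.float().sum(dim=-1) / total_classes) * smoothing \
             - log_prob * (1 - smoothing)
    return losses, lse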
@@ -490,10 +487,11 @@ apply(scalar_t *gradInput,
    outscalar_t *gradOutput,
    int64_t *labels,
    const float smoothing,
    int classes)
    int classes,
    const int total_classes)
{
    accscalar_t smooth_positives = 1.0 - smoothing;
    accscalar_t smooth_negatives = smoothing / classes;
    accscalar_t smooth_negatives = smoothing / total_classes;
    accscalar_t tmpGradOutput = gradOutput[blockIdx.x];
    int64_t label = labels[blockIdx.x];
    accscalar_t coeff = max_log_sum_exp[blockIdx.x];
@@ -534,10 +532,11 @@ aligned_apply(int shift,
    outscalar_t *gradOutput,
    int64_t *labels,
    const float smoothing,
    int classes)
    int classes,
    const int total_classes)
{
    accscalar_t smooth_positives = 1.0 - smoothing;
    accscalar_t smooth_negatives = smoothing / classes;
    accscalar_t smooth_negatives = smoothing / total_classes;
    accscalar_t tmpGradOutput = gradOutput[blockIdx.x];
    int64_t label = labels[blockIdx.x];
    accscalar_t coeff = max_log_sum_exp[blockIdx.x];
@@ -602,7 +601,8 @@ cunn_SoftMaxXEntropyBackward(
    outscalar_t *gradOutput,
    int64_t *labels,
    const float smoothing,
    int classes)
    int classes,
    const int total_classes)
{
    gradInput += blockIdx.x * classes;
    logits += blockIdx.x * classes;
@@ -611,10 +611,10 @@ cunn_SoftMaxXEntropyBackward(
    const int shift = ((uint64_t)logits) % ALIGN_BYTES / sizeof(scalar_t);
    const int shift_ = ((uint64_t)gradInput) % ALIGN_BYTES / sizeof(scalar_t);
    if (shift == shift_){
        aligned_apply<ILP, scalar_t, accscalar_t, outscalar_t>(shift, gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes);
        aligned_apply<ILP, scalar_t, accscalar_t, outscalar_t>(shift, gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes, total_classes <= 0 ? classes : total_classes);
    }
    else {
        apply<ILP, scalar_t, accscalar_t, outscalar_t>(gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes);
        apply<ILP, scalar_t, accscalar_t, outscalar_t>(gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes, total_classes <= 0 ? classes : total_classes);
    }

}
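The backward kernels keep their structure; the only change is that the smoothing "negatives" term divides by total_classes (falling back to the local class count when total_classes <= 0). A hedged PyTorch restatement of the gradient they produce, with illustrative names; lse is the saved max + log-sum-exp, which in the tensor-parallel case is the global one:

import torch
import torch.nn.functional as F

def reference_backward(grad_loss, logits, lse, labels, smoothing, total_classes):
    # probs is the (global) softmax restricted to this rank's slice of the vocab.
    probs = torch.exp(logits.float() - lse.unsqueeze(-1))
    in_vocab = (labels >= 0) & (labels < logits.shape[-1])
    safe = labels.clamp(min=0, max=logits.shape[-1] - 1)
    onehot = F.one_hot(safe, logits.shape[-1]).to(probs.dtype) * in_vocab.unsqueeze(-1)
    smooth_positives = 1.0 - smoothing            # weight on the true class
    smooth_negatives = smoothing / total_classes  # note: global class count
    return grad_loss.unsqueeze(-1) * (probs - smooth_positives * onehot - smooth_negatives)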
@@ -623,7 +623,11 @@ template<template<typename, typename, typename> class Epilogue>
std::vector<Tensor> host_softmax_xentropy(
    const Tensor & input_,
    const Tensor & labels_,
    const float smoothing){
    const float smoothing,
    const int total_classes) {
    // For tensor parallel cross entropy with smoothing, we want to pass in the total number
    // of classes so that smoothing can be applied correctly. If total_classes=-1, use the
    // last dimension of the input tensor.
    AT_ASSERTM(labels_.scalar_type() == ScalarType::Long,"Label type should be CUDA Long");

    // Otherwise the kernel will be launched from cuda:0 device
@@ -666,7 +670,7 @@ std::vector<Tensor> host_softmax_xentropy(
        <<<grid, block, 2 * block.x * sizeof(accscalar_t), stream>>>(
            losses.data_ptr<accscalar_t>(), max_log_sum_exp.data_ptr<accscalar_t>(),
            input.data_ptr<scalar_t_0>(), labels_.data_ptr<int64_t>(),
            dim_size, smoothing
            dim_size, smoothing, total_classes <= 0 ? dim_size : total_classes
        );
    );

@@ -683,7 +687,8 @@ Tensor host_softmax_xentropy_backward(
    const at::Tensor &max_log_sum_exp,
    const at::Tensor &labels,
    const float smoothing,
    bool inplace) {
    bool inplace,
    const int total_classes) {
    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)grad_loss.get_device()};
@@ -730,7 +735,7 @@ Tensor host_softmax_xentropy_backward(
            gI.data_ptr<scalar_t_0>(), logits.data_ptr<scalar_t_0>(),
            max_log_sum_exp.data_ptr<accscalar_t>(),
            grad.data_ptr<accscalar_t>(), labels.data_ptr<int64_t>(),
            smoothing, dim_size
            smoothing, dim_size, total_classes
        );
    );

@@ -738,8 +743,8 @@ Tensor host_softmax_xentropy_backward(
    return gI;
}

std::vector<Tensor> softmax_xentropy_cuda(const Tensor &input, const Tensor &labels, const float smoothing){
    return host_softmax_xentropy<LogSoftMaxForwardEpilogue>(input, labels, smoothing);
std::vector<Tensor> softmax_xentropy_cuda(const Tensor &input, const Tensor &labels, const float smoothing, const int total_classes){
    return host_softmax_xentropy<LogSoftMaxForwardEpilogue>(input, labels, smoothing, total_classes);
}

at::Tensor softmax_xentropy_backward_cuda(
@@ -748,7 +753,8 @@ at::Tensor softmax_xentropy_backward_cuda(
    const at::Tensor &max_log_sum_exp,
    const at::Tensor &labels,
    const float smoothing,
    const bool inplace) {
    const bool inplace,
    const int total_classes) {
    AT_ASSERTM((grad_loss.scalar_type() == ScalarType::Float), "expected grad types to be at::Float");
    return host_softmax_xentropy_backward<LogSoftMaxBackwardEpilogue>(grad_loss, logits, max_log_sum_exp, labels, smoothing, inplace);
    return host_softmax_xentropy_backward<LogSoftMaxBackwardEpilogue>(grad_loss, logits, max_log_sum_exp, labels, smoothing, inplace, total_classes);
}

flash_attn/losses/cross_entropy.py (new file, 128 lines)
@@ -0,0 +1,128 @@
# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py
# But we make it much faster: we compute the local loss and the LSE, and by exchanging the LSE and
# the losses we can get the global loss. There's no need to do it step by step
# (compute local max, exchange, compute exp, compute local sum, exchange, etc.)
# The original xentropy interface is here: https://github.com/NVIDIA/apex/blob/master/apex/contrib/xentropy/softmax_xentropy.py
import torch
import torch.nn as nn

import xentropy_cuda_lib

# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
# version of PyTorch. The following 2 lines are for backward compatibility with
# older PyTorch.
if "all_gather_into_tensor" not in dir(torch.distributed):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base


class SoftmaxCrossEntropyLossFn(torch.autograd.Function):

    @staticmethod
    def forward(ctx, logits, labels, smoothing=0.0, ignored_index=-100, inplace_backward=False,
                process_group=None):
        """
        logits: (batch, vocab_size)
        labels: (batch,)
        If process_group is not None, we're doing Tensor Parallel: each process is responsible for
        one part of the vocab. The loss needs to be aggregated across processes.
        """
        batch, vocab_size = logits.shape
        assert labels.shape == (batch,)
        world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)
        ctx.total_classes = world_size * vocab_size

        if world_size == 1:
            losses, lse = xentropy_cuda_lib.forward(logits, labels, smoothing)
            losses.masked_fill_(labels==ignored_index, 0)
            labels_local = labels
        else:
            rank = torch.distributed.get_rank(process_group)
            vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size

            # Create a mask of valid vocab ids (1 means it needs to be masked).
            labels_mask = (labels < vocab_start_index) | (labels >= vocab_end_index)
            ignored_mask = labels == ignored_index
            labels_local = torch.where(ignored_mask, labels, labels - vocab_start_index)

            # For tensor parallel cross entropy with smoothing, we want to pass in the total number
            # of classes so that smoothing can be applied correctly. If total_classes=-1, use the
            # last dimension of the input tensor.
            losses, lse_local = xentropy_cuda_lib.forward(logits, labels_local, smoothing,
                                                          world_size * vocab_size)
            assert lse_local.shape == (batch,)
            assert losses.shape == (batch,)
            losses.masked_fill_(ignored_mask, 0)
            # For labels == ignored_index, the loss is always 0.
            # If there's no smoothing, if labels are in the vocab of this partition, losses contains
            # lse_local - predicted logit, and 0 otherwise.
            # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains
            # 0.9 * (lse_local - predicted logit) + 0.1 * (lse_local - sum logit / total_classes)
            # For labels not in the vocab of this partition, losses contains
            # 0.1 * (lse_local - sum logit / total_classes).

            lse_allgather = torch.empty(world_size, batch, dtype=lse_local.dtype,
                                        device=lse_local.device)
            torch.distributed.all_gather_into_tensor(lse_allgather, lse_local.contiguous(),
                                                     group=process_group)
            handle_losses = torch.distributed.all_reduce(
                losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
            )
            lse = torch.logsumexp(lse_allgather, dim=0)
            # If there's no smoothing, the total losses are lse_local - predicted_logit,
            # we just have to subtract the lse_local and add the lse (global).
            # If there's smoothing=0.1, the total losses are
            # 0.9 * (lse_local - predicted_logit) + 0.1 * (sum of all lse_local - sum logit / total_classes)
            # We want 0.9 * (lse - predicted_logit) + 0.1 * (lse - sum logit / total_classes).
            rank_per_sample = torch.div(labels, vocab_size, rounding_mode='floor')
            lse_local = lse_allgather[rank_per_sample,
                                      torch.arange(batch, device=lse_allgather.device)]

            handle_losses.wait()
            if smoothing == 0.0:
                losses += lse - lse_local
            else:
                losses += ((1 - smoothing) * (lse - lse_local)
                           + smoothing * (lse - lse_allgather.sum(dim=0)))
            losses.masked_fill_(ignored_mask, 0)

        ctx.save_for_backward(logits, lse, labels_local)
        ctx.smoothing = smoothing
        ctx.ignored_index = ignored_index
        ctx.inplace_backward = inplace_backward
        return losses

    @staticmethod
    def backward(ctx, grad_loss):
        logits, lse, labels = ctx.saved_tensors
        grad_loss = grad_loss.contiguous()
        grad_loss.masked_fill_(labels==ctx.ignored_index, 0)
        grad_logits = xentropy_cuda_lib.backward(grad_loss, logits, lse, labels,
                                                 ctx.smoothing, ctx.inplace_backward,
                                                 ctx.total_classes)
        return grad_logits, None, None, None, None, None, None


class CrossEntropyLoss(nn.Module):

    def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0,
                 inplace_backward=False):
        super().__init__()
        if reduction not in ['mean', 'none']:
            raise NotImplementedError("Only support reduction = 'mean' or 'none'")
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.label_smoothing = label_smoothing
        self.inplace_backward = inplace_backward

    def forward(self, input, target, process_group=None):
        assert input.is_cuda and target.is_cuda
        # SoftmaxCrossEntropyLoss implicitly casts to float
        loss = SoftmaxCrossEntropyLossFn.apply(
            input, target, self.label_smoothing, self.ignore_index, self.inplace_backward,
            process_group
        )
        if self.reduction == 'mean':
            return loss.sum() / (target != self.ignore_index).sum()
        else:
            return loss
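To make the aggregation comments above concrete, here is a small single-process sketch (not part of the commit, no torch.distributed, illustrative sizes) checking that the per-rank partial losses plus the LSE correction reproduce PyTorch's full-vocab smoothed cross entropy:

import torch

torch.manual_seed(0)
batch, vocab, world_size, smoothing = 3, 8, 2, 0.1
shard = vocab // world_size
logits = torch.randn(batch, vocab)
labels = torch.randint(0, vocab, (batch,))

# Per-rank LSEs and the global LSE (logsumexp of logsumexps).
lse_local = torch.stack([torch.logsumexp(logits[:, r*shard:(r+1)*shard], dim=-1)
                         for r in range(world_size)])          # (world_size, batch)
lse = torch.logsumexp(lse_local, dim=0)

# Per-rank partial losses, as the CUDA kernel computes them with total_classes=vocab.
losses = torch.zeros(batch)
for r in range(world_size):
    local = logits[:, r*shard:(r+1)*shard]
    in_vocab = (labels >= r*shard) & (labels < (r+1)*shard)
    log_prob = torch.where(in_vocab,
                           logits.gather(-1, labels.unsqueeze(-1)).squeeze(-1) - lse_local[r],
                           torch.zeros(batch))
    losses += (lse_local[r] - local.sum(dim=-1) / vocab) * smoothing - log_prob * (1 - smoothing)

# Correction: swap the owning rank's local LSE for the global one, and with smoothing
# replace the sum of per-rank LSE terms by a single global LSE.
rank_per_sample = labels // shard
owner_lse = lse_local[rank_per_sample, torch.arange(batch)]
losses += (1 - smoothing) * (lse - owner_lse) + smoothing * (lse - lse_local.sum(dim=0))

expected = torch.nn.functional.cross_entropy(logits, labels, label_smoothing=smoothing,
                                             reduction='none')
assert torch.allclose(losses, expected, atol=1e-5)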
@@ -1,51 +0,0 @@
import torch
import torch.nn as nn

import xentropy_cuda_lib


# https://github.com/NVIDIA/apex/blob/master/apex/contrib/xentropy/softmax_xentropy.py
class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, logits, labels, smoothing=0.0, padding_idx=0, inplace_backward=False):
        losses, max_log_sum_exp = xentropy_cuda_lib.forward(
            logits, labels, smoothing)
        losses.masked_fill_(labels==padding_idx, 0)
        ctx.save_for_backward(logits, max_log_sum_exp, labels)
        ctx.smoothing = smoothing
        ctx.padding_idx = padding_idx
        ctx.inplace_backward = inplace_backward
        return losses

    @staticmethod
    def backward(ctx, grad_loss):
        logits, max_log_sum_exp, labels = ctx.saved_tensors
        if not grad_loss.is_contiguous():
            grad_loss = grad_loss.contiguous()
        grad_loss.masked_fill_(labels==ctx.padding_idx, 0)
        grad_logits = xentropy_cuda_lib.backward(grad_loss, logits, max_log_sum_exp, labels,
                                                 ctx.smoothing, ctx.inplace_backward)
        return grad_logits, None, None, None, None


class CrossEntropyLossApex(nn.Module):

    def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0,
                 inplace_backward=False):
        super().__init__()
        if reduction not in ['mean', 'none']:
            raise NotImplementedError("Only support reduction = 'mean' or 'none'")
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.label_smoothing = label_smoothing
        self.inplace_backward = inplace_backward

    def forward(self, input, target):
        assert input.is_cuda and target.is_cuda
        # SoftmaxCrossEntropyLoss implicitly casts to float
        loss = SoftmaxCrossEntropyLossFn.apply(input, target, self.label_smoothing,
                                               self.ignore_index, self.inplace_backward)
        if self.reduction == 'mean':
            return loss.sum() / (target != self.ignore_index).sum()
        else:
            return loss
@@ -1,122 +0,0 @@
# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py
# But we make it much faster: we compute the local loss and the LSE, and by exchanging the LSE and
# the losses we can get the global loss. There's no need to do it step by step
# (compute local max, exchange, compute exp, compute local sum, exchange, etc.)
import torch
import torch.nn as nn

import xentropy_cuda_lib

from apex.transformer.parallel_state import get_tensor_model_parallel_group
from apex.transformer.parallel_state import get_tensor_model_parallel_rank
from apex.transformer.parallel_state import get_tensor_model_parallel_world_size
from apex.transformer.tensor_parallel.utils import VocabUtility

# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
# version of PyTorch. The following 4 lines are for backward compatibility with
# older PyTorch.
if "all_gather_into_tensor" not in dir(torch.distributed):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
if "reduce_scatter_tensor" not in dir(torch.distributed):
    torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base


class SoftmaxCrossEntropyLossParallelFn(torch.autograd.Function):

    @staticmethod
    def forward(ctx, logits_parallel, labels, smoothing=0.0, ignored_index=-100,
                inplace_backward=False):
        """
        logits_parallel: (batch, vocab_size / world_size)
        labels: (batch,)
        """
        assert smoothing == 0.0, 'smoothing != 0.0 is not yet implemented, file an issue if you need it'
        batch, partition_vocab_size = logits_parallel.shape
        assert labels.shape == (batch,)
        rank = get_tensor_model_parallel_rank()
        world_size = get_tensor_model_parallel_world_size()

        if world_size == 1:
            losses, lse = xentropy_cuda_lib.forward(logits_parallel, labels, smoothing)
            losses.masked_fill_(labels==ignored_index, 0)
            labels_local = labels
        else:
            vocab_start_index, vocab_end_index = VocabUtility.vocab_range_from_per_partition_vocab_size(
                partition_vocab_size, get_tensor_model_parallel_rank(),
                get_tensor_model_parallel_world_size()
            )

            # Create a mask of valid vocab ids (1 means it needs to be masked).
            labels_mask = (labels < vocab_start_index) | (labels >= vocab_end_index)
            ignored_mask = labels == ignored_index
            labels_local = torch.where(ignored_mask, labels, labels - vocab_start_index)
            masked_labels = labels_local.clone()
            masked_labels[labels_mask] = ignored_index

            losses, lse_local = xentropy_cuda_lib.forward(logits_parallel, masked_labels, smoothing)
            assert lse_local.shape == (batch,)
            assert losses.shape == (batch,)
            losses.masked_fill_(masked_labels==ignored_index, 0)

            lse_allgather = torch.empty(world_size, batch, dtype=lse_local.dtype,
                                        device=lse_local.device)
            handle_lse = torch.distributed.all_gather_into_tensor(
                lse_allgather, lse_local.contiguous(),
                group=get_tensor_model_parallel_group(), async_op=True
            )
            handle_losses = torch.distributed.all_reduce(
                losses, op=torch.distributed.ReduceOp.SUM,
                group=get_tensor_model_parallel_group(), async_op=True
            )
            handle_lse.wait()
            lse = torch.logsumexp(lse_allgather, dim=0)
            # The losses are going to be lse_local - predicted_logit, we just have to subtract
            # the lse_local and add the lse (global).
            rank_per_sample = torch.div(labels, partition_vocab_size, rounding_mode='floor')
            lse_local = lse_allgather[rank_per_sample,
                                      torch.arange(batch, device=lse_allgather.device)]

            handle_losses.wait()
            losses += lse - lse_local
            losses.masked_fill_(ignored_mask, 0)

        ctx.save_for_backward(logits_parallel, lse, labels_local)
        ctx.smoothing = smoothing
        ctx.ignored_index = ignored_index
        ctx.inplace_backward = inplace_backward
        return losses

    @staticmethod
    def backward(ctx, grad_loss):
        logits_parallel, lse, labels = ctx.saved_tensors
        if not grad_loss.is_contiguous():
            grad_loss = grad_loss.contiguous()
        grad_loss.masked_fill_(labels==ctx.ignored_index, 0)
        grad_logits = xentropy_cuda_lib.backward(grad_loss, logits_parallel, lse, labels,
                                                 ctx.smoothing, ctx.inplace_backward)
        return grad_logits, None, None, None, None, None


class CrossEntropyLossParallel(nn.Module):

    def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0,
                 inplace_backward=False):
        super().__init__()
        if reduction not in ['mean', 'none']:
            raise NotImplementedError("Only support reduction = 'mean' or 'none'")
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.label_smoothing = label_smoothing
        self.inplace_backward = inplace_backward

    def forward(self, input, target):
        assert input.is_cuda and target.is_cuda
        # SoftmaxCrossEntropyLoss implicitly casts to float
        loss = SoftmaxCrossEntropyLossParallelFn.apply(
            input, target, self.label_smoothing, self.ignore_index, self.inplace_backward
        )
        if self.reduction == 'mean':
            return loss.sum() / (target != self.ignore_index).sum()
        else:
            return loss
@@ -40,9 +40,9 @@ except ImportError:
    dropout_add_layer_norm, layer_norm = None, None

try:
    from flash_attn.losses.cross_entropy_apex import CrossEntropyLossApex
    from flash_attn.losses.cross_entropy import CrossEntropyLoss
except ImportError:
    CrossEntropyLossApex = None
    CrossEntropyLoss = None


logger = logging.getLogger(__name__)
@@ -374,10 +374,10 @@ class BertForPreTraining(BertPreTrainedModel):
        if self.last_layer_subset:
            assert self.dense_seq_output, 'last_layer_subset requires dense_seq_output'
        use_xentropy = getattr(config, 'use_xentropy', False)
        if use_xentropy and CrossEntropyLossApex is None:
        if use_xentropy and CrossEntropyLoss is None:
            raise ImportError('xentropy_cuda is not installed')
        loss_cls = (nn.CrossEntropyLoss if not use_xentropy
                    else partial(CrossEntropyLossApex, inplace_backward=True))
                    else partial(CrossEntropyLoss, inplace_backward=True))

        self.bert = BertModel(config)
        self.cls = BertPreTrainingHeads(config)

@@ -6,7 +6,7 @@ import pytest

from einops import rearrange

from flash_attn.losses.cross_entropy_apex import CrossEntropyLossApex
from flash_attn.losses.cross_entropy import CrossEntropyLossApex

is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8

@@ -15,8 +15,9 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('inplace_backward', [False, True])
# @pytest.mark.parametrize('inplace_backward', [False])
@pytest.mark.parametrize('smoothing', [0.0, 0.9])
@pytest.mark.parametrize('vocab_size', [50257])
def test_cross_entropy_loss_apex(vocab_size, inplace_backward, dtype):
def test_cross_entropy_loss_apex(vocab_size, smoothing, inplace_backward, dtype):
    device = 'cuda'
    rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)
    # set seed
@@ -27,8 +28,8 @@ def test_cross_entropy_loss_apex(vocab_size, inplace_backward, dtype):
    x = x_pt.detach().clone().requires_grad_()
    y = torch.randint(0, vocab_size, (batch_size * seqlen,), dtype=torch.long, device=device)
    y[torch.randperm(batch_size * seqlen)[:10]] = -100
    model_pt = torch.nn.CrossEntropyLoss()
    model = CrossEntropyLossApex(inplace_backward=inplace_backward)
    model_pt = torch.nn.CrossEntropyLoss(label_smoothing=smoothing)
    model = CrossEntropyLossApex(label_smoothing=smoothing, inplace_backward=inplace_backward)
    out = model(x, y)
    out_pt = model_pt(x_pt.float(), y)
    assert torch.allclose(out, out_pt, rtol=rtol, atol=atol)
@@ -10,19 +10,21 @@ import pytest
from apex.transformer import parallel_state
from apex.transformer import tensor_parallel

from flash_attn.losses.cross_entropy_parallel import CrossEntropyLossParallel
from flash_attn.losses.cross_entropy import CrossEntropyLoss

is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8


@pytest.mark.parametrize('dtype', [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('inplace_backward', [False, True])
# @pytest.mark.parametrize('inplace_backward', [False])
@pytest.mark.parametrize('smoothing', [0.0, 0.9])
# @pytest.mark.parametrize('smoothing', [0.9])
@pytest.mark.parametrize('vocab_size', [50264])
@pytest.mark.parametrize('world_size', [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
def test_cross_entropy_loss_apex(vocab_size, world_size, inplace_backward, dtype):
def test_cross_entropy_loss_apex(vocab_size, world_size, smoothing, inplace_backward, dtype):
    assert vocab_size % world_size == 0
    rtol, atol = ((1e-5, 1e-6) if dtype == torch.float32
                  else ((1e-3, 1e-4) if dtype == torch.float16 else (1e-2, 3e-3)))
@@ -42,9 +44,10 @@ def test_cross_entropy_loss_apex(vocab_size, world_size, inplace_backward, dtype
    x = tensor_parallel.scatter_to_tensor_model_parallel_region(x_pt).detach().clone().requires_grad_()
    y = torch.randint(0, vocab_size, (batch_size * seqlen,), dtype=torch.long, device=device)
    y[torch.randperm(batch_size * seqlen)[:10]] = -100
    model_pt = torch.nn.CrossEntropyLoss(reduction='none')
    model = CrossEntropyLossParallel(reduction='none', inplace_backward=inplace_backward)
    out = model(x, y)
    model_pt = torch.nn.CrossEntropyLoss(label_smoothing=smoothing, reduction='none')
    model = CrossEntropyLoss(label_smoothing=smoothing, reduction='none',
                             inplace_backward=inplace_backward)
    out = model(x, y, process_group=parallel_state.get_tensor_model_parallel_group())
    out_pt = model_pt(x_pt.float(), y)
    assert torch.allclose(out, out_pt, rtol=1e-5, atol=1e-6)

@@ -54,7 +54,7 @@ train:
  loss_fn:
    # This is faster and uses less memory than torch.nn.CrossEntropyLoss.
    # It's also more numerically stable if we're using DeepSpeed 16 bits.
    _target_: src.losses.cross_entropy_apex.CrossEntropyLossApex
    _target_: src.losses.cross_entropy.CrossEntropyLoss
    inplace_backward: True # to save memory

  eval:

@@ -54,7 +54,7 @@ train:
  loss_fn:
    # This is faster and uses less memory than torch.nn.CrossEntropyLoss.
    # It's also more numerically stable if we're using DeepSpeed 16 bits.
    _target_: src.losses.cross_entropy_apex.CrossEntropyLossApex
    _target_: src.losses.cross_entropy.CrossEntropyLoss
    inplace_backward: True # to save memory

  eval:

training/src/losses/cross_entropy.py (new file, 128 lines)
@@ -0,0 +1,128 @@
# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py
# But we make it much faster: we compute the local loss and the LSE, and by exchanging the LSE and
# the losses we can get the global loss. There's no need to do it step by step
# (compute local max, exchange, compute exp, compute local sum, exchange, etc.)
# The original xentropy interface is here: https://github.com/NVIDIA/apex/blob/master/apex/contrib/xentropy/softmax_xentropy.py
import torch
import torch.nn as nn

import xentropy_cuda_lib

# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
# version of PyTorch. The following 2 lines are for backward compatibility with
# older PyTorch.
if "all_gather_into_tensor" not in dir(torch.distributed):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base


class SoftmaxCrossEntropyLossFn(torch.autograd.Function):

    @staticmethod
    def forward(ctx, logits, labels, smoothing=0.0, ignored_index=-100, inplace_backward=False,
                process_group=None):
        """
        logits: (batch, vocab_size)
        labels: (batch,)
        If process_group is not None, we're doing Tensor Parallel: each process is responsible for
        one part of the vocab. The loss needs to be aggregated across processes.
        """
        batch, vocab_size = logits.shape
        assert labels.shape == (batch,)
        world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)
        ctx.total_classes = world_size * vocab_size

        if world_size == 1:
            losses, lse = xentropy_cuda_lib.forward(logits, labels, smoothing)
            losses.masked_fill_(labels==ignored_index, 0)
            labels_local = labels
        else:
            rank = torch.distributed.get_rank(process_group)
            vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size

            # Create a mask of valid vocab ids (1 means it needs to be masked).
            labels_mask = (labels < vocab_start_index) | (labels >= vocab_end_index)
            ignored_mask = labels == ignored_index
            labels_local = torch.where(ignored_mask, labels, labels - vocab_start_index)

            # For tensor parallel cross entropy with smoothing, we want to pass in the total number
            # of classes so that smoothing can be applied correctly. If total_classes=-1, use the
            # last dimension of the input tensor.
            losses, lse_local = xentropy_cuda_lib.forward(logits, labels_local, smoothing,
                                                          world_size * vocab_size)
            assert lse_local.shape == (batch,)
            assert losses.shape == (batch,)
            losses.masked_fill_(ignored_mask, 0)
            # For labels == ignored_index, the loss is always 0.
            # If there's no smoothing, if labels are in the vocab of this partition, losses contains
            # lse_local - predicted logit, and 0 otherwise.
            # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains
            # 0.9 * (lse_local - predicted logit) + 0.1 * (lse_local - sum logit / total_classes)
            # For labels not in the vocab of this partition, losses contains
            # 0.1 * (lse_local - sum logit / total_classes).

            lse_allgather = torch.empty(world_size, batch, dtype=lse_local.dtype,
                                        device=lse_local.device)
            torch.distributed.all_gather_into_tensor(lse_allgather, lse_local.contiguous(),
                                                     group=process_group)
            handle_losses = torch.distributed.all_reduce(
                losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
            )
            lse = torch.logsumexp(lse_allgather, dim=0)
            # If there's no smoothing, the total losses are lse_local - predicted_logit,
            # we just have to subtract the lse_local and add the lse (global).
            # If there's smoothing=0.1, the total losses are
            # 0.9 * (lse_local - predicted_logit) + 0.1 * (sum of all lse_local - sum logit / total_classes)
            # We want 0.9 * (lse - predicted_logit) + 0.1 * (lse - sum logit / total_classes).
            rank_per_sample = torch.div(labels, vocab_size, rounding_mode='floor')
            lse_local = lse_allgather[rank_per_sample,
                                      torch.arange(batch, device=lse_allgather.device)]

            handle_losses.wait()
            if smoothing == 0.0:
                losses += lse - lse_local
            else:
                losses += ((1 - smoothing) * (lse - lse_local)
                           + smoothing * (lse - lse_allgather.sum(dim=0)))
            losses.masked_fill_(ignored_mask, 0)

        ctx.save_for_backward(logits, lse, labels_local)
        ctx.smoothing = smoothing
        ctx.ignored_index = ignored_index
        ctx.inplace_backward = inplace_backward
        return losses

    @staticmethod
    def backward(ctx, grad_loss):
        logits, lse, labels = ctx.saved_tensors
        grad_loss = grad_loss.contiguous()
        grad_loss.masked_fill_(labels==ctx.ignored_index, 0)
        grad_logits = xentropy_cuda_lib.backward(grad_loss, logits, lse, labels,
                                                 ctx.smoothing, ctx.inplace_backward,
                                                 ctx.total_classes)
        return grad_logits, None, None, None, None, None, None


class CrossEntropyLoss(nn.Module):

    def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0,
                 inplace_backward=False):
        super().__init__()
        if reduction not in ['mean', 'none']:
            raise NotImplementedError("Only support reduction = 'mean' or 'none'")
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.label_smoothing = label_smoothing
        self.inplace_backward = inplace_backward

    def forward(self, input, target, process_group=None):
        assert input.is_cuda and target.is_cuda
        # SoftmaxCrossEntropyLoss implicitly casts to float
        loss = SoftmaxCrossEntropyLossFn.apply(
            input, target, self.label_smoothing, self.ignore_index, self.inplace_backward,
            process_group
        )
        if self.reduction == 'mean':
            return loss.sum() / (target != self.ignore_index).sum()
        else:
            return loss
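A brief usage sketch of the renamed module (not part of the commit; shapes are illustrative). This mirrors how the Hydra configs above point _target_ at src.losses.cross_entropy.CrossEntropyLoss:

import torch
from src.losses.cross_entropy import CrossEntropyLoss

loss_fn = CrossEntropyLoss(label_smoothing=0.1, inplace_backward=True)

logits = torch.randn(16, 50257, device='cuda', dtype=torch.float16, requires_grad=True)
labels = torch.randint(0, 50257, (16,), device='cuda')
labels[0] = -100                 # ignored positions contribute zero loss

loss = loss_fn(logits, labels)   # single GPU: the full vocab lives on this device
loss.backward()

# Tensor parallel: each rank holds a vocab shard and passes its process group, e.g.
# loss = loss_fn(local_logits, labels, process_group=tp_group)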
@@ -1,51 +0,0 @@
import torch
import torch.nn as nn

import xentropy_cuda_lib


# https://github.com/NVIDIA/apex/blob/master/apex/contrib/xentropy/softmax_xentropy.py
class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, logits, labels, smoothing=0.0, padding_idx=0, inplace_backward=False):
        losses, max_log_sum_exp = xentropy_cuda_lib.forward(
            logits, labels, smoothing)
        losses.masked_fill_(labels==padding_idx, 0)
        ctx.save_for_backward(logits, max_log_sum_exp, labels)
        ctx.smoothing = smoothing
        ctx.padding_idx = padding_idx
        ctx.inplace_backward = inplace_backward
        return losses

    @staticmethod
    def backward(ctx, grad_loss):
        logits, max_log_sum_exp, labels = ctx.saved_tensors
        if not grad_loss.is_contiguous():
            grad_loss = grad_loss.contiguous()
        grad_loss.masked_fill_(labels==ctx.padding_idx, 0)
        grad_logits = xentropy_cuda_lib.backward(grad_loss, logits, max_log_sum_exp, labels,
                                                 ctx.smoothing, ctx.inplace_backward)
        return grad_logits, None, None, None, None


class CrossEntropyLossApex(nn.Module):

    def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0,
                 inplace_backward=False):
        super().__init__()
        if reduction not in ['mean', 'none']:
            raise NotImplementedError("Only support reduction = 'mean' or 'none'")
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.label_smoothing = label_smoothing
        self.inplace_backward = inplace_backward

    def forward(self, input, target):
        assert input.is_cuda and target.is_cuda
        # SoftmaxCrossEntropyLoss implicitly casts to float
        loss = SoftmaxCrossEntropyLossFn.apply(input, target, self.label_smoothing,
                                               self.ignore_index, self.inplace_backward)
        if self.reduction == 'mean':
            return loss.sum() / (target != self.ignore_index).sum()
        else:
            return loss
@@ -1,112 +0,0 @@
# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py
# But we make it much faster: we compute the local loss and the LSE, and by exchanging the LSE and
# the losses we can get the global loss. There's no need to do it step by step
# (compute local max, exchange, compute exp, compute local sum, exchange, etc.)
import torch
import torch.nn as nn

import xentropy_cuda_lib

from apex.transformer.parallel_state import get_tensor_model_parallel_group
from apex.transformer.parallel_state import get_tensor_model_parallel_rank
from apex.transformer.parallel_state import get_tensor_model_parallel_world_size
from apex.transformer.tensor_parallel.utils import VocabUtility

# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
# version of PyTorch. The following 4 lines are for backward comparability with
# older PyTorch.
if "all_gather_into_tensor" not in dir(torch.distributed):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
if "reduce_scatter_tensor" not in dir(torch.distributed):
    torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base


class SoftmaxCrossEntropyLossParallelFn(torch.autograd.Function):

    @staticmethod
    def forward(ctx, logits_parallel, labels, smoothing=0.0, ignored_index=-100,
                inplace_backward=False):
        """
        logits_parallel: (batch, vocab_size / world_size)
        labels: (batch,)
        """
        assert smoothing == 0.0, 'smoothing != 0.0 is not yet implemented, file an issue if you need it'
        batch, partition_vocab_size = logits_parallel.shape
        assert labels.shape == (batch,)
        rank = get_tensor_model_parallel_rank()
        world_size = get_tensor_model_parallel_world_size()
        vocab_start_index, vocab_end_index = VocabUtility.vocab_range_from_per_partition_vocab_size(
            partition_vocab_size, get_tensor_model_parallel_rank(),
            get_tensor_model_parallel_world_size()
        )

        # Create a mask of valid vocab ids (1 means it needs to be masked).
        labels_mask = (labels < vocab_start_index) | (labels >= vocab_end_index)
        ignored_mask = labels == ignored_index
        labels_local = torch.where(ignored_mask, labels, labels - vocab_start_index)
        masked_labels = labels_local.clone()
        masked_labels[labels_mask] = ignored_index

        losses, lse_local = xentropy_cuda_lib.forward(logits_parallel, masked_labels, smoothing)
        assert lse_local.shape == (batch,)
        assert losses.shape == (batch,)
        losses.masked_fill_(masked_labels==ignored_index, 0)

        if world_size > 1:
            lse_allgather = torch.empty(world_size, batch, dtype=lse_local.dtype,
                                        device=lse_local.device)
            torch.distributed.all_gather_into_tensor(lse_allgather, lse_local.contiguous(),
                                                     group=get_tensor_model_parallel_group())
            lse = torch.logsumexp(lse_allgather, dim=0)
            torch.distributed.all_reduce(losses, op=torch.distributed.ReduceOp.SUM,
                                         group=get_tensor_model_parallel_group())
            # The losses are currently lse_local - predicted_logit, we just have to subtract the
            # lse_local and add the lse (global).
            rank_per_sample = labels // partition_vocab_size
            lse_local = lse_allgather[rank_per_sample,
                                      torch.arange(batch, device=lse_allgather.device)]
            losses += lse - lse_local
            losses.masked_fill_(ignored_mask, 0)
        else:
            lse = lse_local

        ctx.save_for_backward(logits_parallel, lse, labels_local)
        ctx.smoothing = smoothing
        ctx.ignored_index = ignored_index
        ctx.inplace_backward = inplace_backward
        return losses

    @staticmethod
    def backward(ctx, grad_loss):
        logits_parallel, lse, labels = ctx.saved_tensors
        if not grad_loss.is_contiguous():
            grad_loss = grad_loss.contiguous()
        grad_loss.masked_fill_(labels==ctx.ignored_index, 0)
        grad_logits = xentropy_cuda_lib.backward(grad_loss, logits_parallel, lse, labels,
                                                 ctx.smoothing, ctx.inplace_backward)
        return grad_logits, None, None, None, None, None


class CrossEntropyLossParallel(nn.Module):

    def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0,
                 inplace_backward=False):
        super().__init__()
        if reduction not in ['mean', 'none']:
            raise NotImplementedError("Only support reduction = 'mean' or 'none'")
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.label_smoothing = label_smoothing
        self.inplace_backward = inplace_backward

    def forward(self, input, target):
        assert input.is_cuda and target.is_cuda
        # SoftmaxCrossEntropyLoss implicitly casts to float
        loss = SoftmaxCrossEntropyLossParallelFn.apply(
            input, target, self.label_smoothing, self.ignore_index, self.inplace_backward
        )
        if self.reduction == 'mean':
            return loss.sum() / (target != self.ignore_index).sum()
        else:
            return loss
@@ -11,7 +11,7 @@ from torch import Tensor
from torchmetrics import Metric

try:
    from src.losses.cross_entropy_apex import CrossEntropyLossApex as CrossEntropyLoss
    from src.losses.cross_entropy import CrossEntropyLoss
except ImportError:
    CrossEntropyLoss = torch.nn.CrossEntropyLoss