123 lines
5.7 KiB
Python
123 lines
5.7 KiB
Python
# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py
|
|
# But we make it much faster: we compute the local loss and the LSE, and by exchanging the LSE and
|
|
# the losses we can get the global loss. There's no need to do it step by step
|
|
# (compute local max, exchange, compute exp, compute local sum, exchange, etc.)
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
import xentropy_cuda_lib
|
|
|
|
from apex.transformer.parallel_state import get_tensor_model_parallel_group
|
|
from apex.transformer.parallel_state import get_tensor_model_parallel_rank
|
|
from apex.transformer.parallel_state import get_tensor_model_parallel_world_size
|
|
from apex.transformer.tensor_parallel.utils import VocabUtility
|
|
|
|
# `all_gather_into_tensor` and `reduce_scatter_tensor` are the newer public names for
# `_all_gather_base` and `_reduce_scatter_base`; they are only available in recent
# versions of PyTorch. For backward compatibility with older PyTorch, alias the new
# names to the old private functions whenever the new names are missing.
if not hasattr(torch.distributed, "all_gather_into_tensor"):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
if not hasattr(torch.distributed, "reduce_scatter_tensor"):
    torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base
|
|
|
|
|
|
class SoftmaxCrossEntropyLossParallelFn(torch.autograd.Function):
    """Per-sample cross-entropy over logits whose vocab dimension is split across
    the tensor-model-parallel group.

    Every rank runs the fused ``xentropy_cuda_lib`` kernel on its own vocab shard to
    get a local loss and a local log-sum-exp (LSE). The global loss is then recovered
    with a single exchange of the LSEs (all-gather) and the losses (all-reduce).
    """

    @staticmethod
    def forward(ctx, logits_parallel, labels, smoothing=0.0, ignored_index=-100,
                inplace_backward=False):
        """
        logits_parallel: (batch, vocab_size / world_size)
        labels: (batch,)

        smoothing: label smoothing factor; only 0.0 is supported.
        ignored_index: label value whose samples get a loss (and gradient) of 0.
        inplace_backward: forwarded unchanged to ``xentropy_cuda_lib.backward``
            (presumably lets the kernel write the gradient into the logits buffer;
            confirm against the extension).

        Returns the per-sample losses, shape (batch,).
        """
        assert smoothing == 0.0, 'smoothing != 0.0 is not yet implemented, file an issue if you need it'
        batch, partition_vocab_size = logits_parallel.shape
        assert labels.shape == (batch,)
        rank = get_tensor_model_parallel_rank()
        world_size = get_tensor_model_parallel_world_size()

        if world_size == 1:
            # Single shard: the local loss and LSE are already the global ones.
            losses, lse = xentropy_cuda_lib.forward(logits_parallel, labels, smoothing)
            losses.masked_fill_(labels == ignored_index, 0)
            labels_local = labels
        else:
            vocab_start_index, vocab_end_index = VocabUtility.vocab_range_from_per_partition_vocab_size(
                partition_vocab_size, rank, world_size
            )

            # True where the label falls outside this rank's vocab range, i.e. where this
            # rank must not contribute a loss.
            labels_mask = (labels < vocab_start_index) | (labels >= vocab_end_index)
            ignored_mask = labels == ignored_index
            # Shift labels into this shard's local index space, but keep ignored samples
            # recognisable by leaving their value at `ignored_index`.
            labels_local = torch.where(ignored_mask, labels, labels - vocab_start_index)
            masked_labels = labels_local.clone()
            masked_labels[labels_mask] = ignored_index

            losses, lse_local = xentropy_cuda_lib.forward(logits_parallel, masked_labels, smoothing)
            assert lse_local.shape == (batch,)
            assert losses.shape == (batch,)
            # Only the rank that owns the label keeps a non-zero loss for a sample.
            losses.masked_fill_(masked_labels == ignored_index, 0)

            # Launch both collectives asynchronously so they can overlap.
            lse_allgather = torch.empty(world_size, batch, dtype=lse_local.dtype,
                                        device=lse_local.device)
            handle_lse = torch.distributed.all_gather_into_tensor(
                lse_allgather, lse_local.contiguous(),
                group=get_tensor_model_parallel_group(), async_op=True
            )
            handle_losses = torch.distributed.all_reduce(
                losses, op=torch.distributed.ReduceOp.SUM,
                group=get_tensor_model_parallel_group(), async_op=True
            )
            handle_lse.wait()
            lse = torch.logsumexp(lse_allgather, dim=0)
            # The losses are going to be lse_local - predicted_logit, we just have to subtract
            # the lse_local and add the lse (global).
            rank_per_sample = torch.div(labels, partition_vocab_size, rounding_mode='floor')
            # `ignored_index` need not map to an existing rank (it may index out of range
            # in either direction). Point ignored samples at rank 0 so the gather below is
            # always in bounds; their loss is zeroed afterwards anyway.
            rank_per_sample = rank_per_sample.masked_fill(ignored_mask, 0)
            lse_local = lse_allgather[rank_per_sample,
                                      torch.arange(batch, device=lse_allgather.device)]

            handle_losses.wait()
            losses += lse - lse_local
            losses.masked_fill_(ignored_mask, 0)

        ctx.save_for_backward(logits_parallel, lse, labels_local)
        ctx.smoothing = smoothing
        ctx.ignored_index = ignored_index
        ctx.inplace_backward = inplace_backward
        return losses

    @staticmethod
    def backward(ctx, grad_loss):
        """Return the gradient w.r.t. ``logits_parallel``; all other inputs get ``None``.

        grad_loss: gradient of the objective w.r.t. the per-sample losses.
        """
        logits_parallel, lse, labels = ctx.saved_tensors
        # Work on a private contiguous copy: the incoming gradient belongs to the caller
        # (and may be shared with other graph nodes), so it must not be masked in place.
        grad_loss = grad_loss.clone(memory_format=torch.contiguous_format)
        # Ignored samples contribute no gradient.
        grad_loss.masked_fill_(labels == ctx.ignored_index, 0)
        grad_logits = xentropy_cuda_lib.backward(grad_loss, logits_parallel, lse, labels,
                                                 ctx.smoothing, ctx.inplace_backward)
        # One entry per forward input: logits_parallel, labels, smoothing, ignored_index,
        # inplace_backward.
        return grad_logits, None, None, None, None
|
|
|
|
|
|
class CrossEntropyLossParallel(nn.Module):
    """Module wrapper around ``SoftmaxCrossEntropyLossParallelFn``.

    Supports ``reduction='mean'`` (average over the targets that are not
    ``ignore_index``) and ``reduction='none'`` (per-sample losses).
    """

    def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0,
                 inplace_backward=False):
        """Store the loss configuration.

        Raises NotImplementedError for any reduction other than 'mean' or 'none'.
        """
        super().__init__()
        if reduction not in ('mean', 'none'):
            raise NotImplementedError("Only support reduction = 'mean' or 'none'")
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.label_smoothing = label_smoothing
        self.inplace_backward = inplace_backward

    def forward(self, input, target):
        """Compute the loss for CUDA tensors ``input`` (logits) and ``target`` (labels).

        Returns a scalar when ``reduction == 'mean'``, otherwise the unreduced losses
        as produced by ``SoftmaxCrossEntropyLossParallelFn``.
        """
        assert input.is_cuda and target.is_cuda
        # SoftmaxCrossEntropyLoss implicitly casts to float
        per_sample_loss = SoftmaxCrossEntropyLossParallelFn.apply(
            input, target, self.label_smoothing, self.ignore_index, self.inplace_backward
        )
        if self.reduction != 'mean':
            return per_sample_loss
        # Average only over the samples that are not ignored.
        num_valid = (target != self.ignore_index).sum()
        return per_sample_loss.sum() / num_valid
|