import torch
import torch.nn as nn

import xentropy_cuda_lib
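
# NOTE (added annotation): `xentropy_cuda_lib` is not part of a stock PyTorch
# install. It is the fused softmax cross-entropy CUDA extension from NVIDIA
# apex (built from apex/contrib/csrc/xentropy) and must be compiled and
# importable before this module can be used.
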
# https://github.com/NVIDIA/apex/blob/master/apex/contrib/xentropy/softmax_xentropy.py
class SoftmaxCrossEntropyLossFn(torch.autograd.Function):

    @staticmethod
    def forward(ctx, logits, labels, smoothing=0.0, padding_idx=0, inplace_backward=False):
        losses, max_log_sum_exp = xentropy_cuda_lib.forward(logits, labels, smoothing)
        # Zero the loss at padded positions so they contribute nothing.
        losses.masked_fill_(labels == padding_idx, 0)
        ctx.save_for_backward(logits, max_log_sum_exp, labels)
        ctx.smoothing = smoothing
        ctx.padding_idx = padding_idx
        ctx.inplace_backward = inplace_backward
        return losses

    @staticmethod
    def backward(ctx, grad_loss):
        logits, max_log_sum_exp, labels = ctx.saved_tensors
        if not grad_loss.is_contiguous():
            grad_loss = grad_loss.contiguous()
        # Padded positions receive no gradient.
        grad_loss.masked_fill_(labels == ctx.padding_idx, 0)
        grad_logits = xentropy_cuda_lib.backward(grad_loss, logits, max_log_sum_exp, labels,
                                                 ctx.smoothing, ctx.inplace_backward)
        return grad_logits, None, None, None, None
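
# Notes on the less obvious pieces of SoftmaxCrossEntropyLossFn (added
# annotation, inferred from the code above rather than stated by the source):
# - `padding_idx`: forward zeroes the loss at padded positions and backward
#   zeroes their incoming gradients, so padding tokens affect neither pass.
# - `inplace_backward`: when True, the CUDA kernel is presumably allowed to
#   overwrite the saved `logits` tensor with `grad_logits`, lowering peak
#   memory; only safe if `logits` is not needed after the backward pass.
# - backward returns one value per forward argument; the four `None`s stand in
#   for `labels`, `smoothing`, `padding_idx`, and `inplace_backward`, none of
#   which need gradients.
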
class CrossEntropyLossApex(nn.Module):

    def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0,
                 inplace_backward=False):
        super().__init__()
        if reduction not in ['mean', 'none']:
            raise NotImplementedError("Only support reduction = 'mean' or 'none'")
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.label_smoothing = label_smoothing
        self.inplace_backward = inplace_backward

    def forward(self, input, target):
        assert input.is_cuda and target.is_cuda
        # SoftmaxCrossEntropyLoss implicitly casts to float
        loss = SoftmaxCrossEntropyLossFn.apply(input, target, self.label_smoothing,
                                               self.ignore_index, self.inplace_backward)
        if self.reduction == 'mean':
            # Average over non-ignored positions only.
            return loss.sum() / (target != self.ignore_index).sum()
        else:
            return loss
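

if __name__ == '__main__':
    # Minimal usage sketch (an illustration, not part of the original module):
    # assumes a CUDA device and a successfully built xentropy_cuda_lib. With
    # label_smoothing=0.0 the fused loss should match nn.CrossEntropyLoss up
    # to floating-point tolerance.
    torch.manual_seed(0)
    batch, vocab_size = 8, 32000
    logits = torch.randn(batch, vocab_size, device='cuda', requires_grad=True)
    labels = torch.randint(0, vocab_size, (batch,), device='cuda')

    fused = CrossEntropyLossApex(reduction='mean')(logits, labels)
    fused.backward()

    reference = nn.CrossEntropyLoss(reduction='mean')(logits.detach().float(), labels)
    print(f'fused: {fused.item():.6f}  reference: {reference.item():.6f}')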