Add fused cross entropy loss

This commit is contained in:
Tri Dao 2022-11-12 19:49:33 -08:00
parent 55797f32c9
commit 7c9953815a
7 changed files with 1194 additions and 0 deletions

View File

@ -0,0 +1,51 @@
#include <torch/extension.h>
// CUDA forward declarations
std::vector<at::Tensor> softmax_xentropy_cuda(
const at::Tensor &input,
const at::Tensor &labels,
const float smoothing);
at::Tensor softmax_xentropy_backward_cuda(
const at::Tensor &grad_loss,
at::Tensor &logits,
const at::Tensor &max_log_sum_exp,
const at::Tensor &labels,
const float smoothing,
const bool inplace);
// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<at::Tensor> softmax_xentropy_forward(
const at::Tensor &input,
const at::Tensor &labels,
const float smoothing) {
CHECK_CUDA(input);
CHECK_INPUT(labels);
return softmax_xentropy_cuda(input, labels, smoothing);
}
at::Tensor softmax_xentropy_backward(
const at::Tensor &grad_loss,
at::Tensor &logits,
const at::Tensor &max_log_sum_exp,
const at::Tensor &labels,
const float smoothing,
const bool inplace) {
CHECK_CUDA(grad_loss);
CHECK_CUDA(logits);
CHECK_INPUT(max_log_sum_exp);
CHECK_INPUT(labels);
return softmax_xentropy_backward_cuda(grad_loss, logits, max_log_sum_exp, labels, smoothing, inplace);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &softmax_xentropy_forward, "Softmax cross entropy loss with label smoothing forward (CUDA)");
m.def("backward", &softmax_xentropy_backward, "Softmax cross entropy loss with label smoothing backward (CUDA)");
}

131
csrc/xentropy/setup.py Normal file
View File

@ -0,0 +1,131 @@
# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py
import torch
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
from setuptools import setup, find_packages
import subprocess
import sys
import warnings
import os
# ninja build does not work unless include_dirs are abs path
this_dir = os.path.dirname(os.path.abspath(__file__))
def get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
output = raw_output.split()
release_idx = output.index("release") + 1
release = output[release_idx].split(".")
bare_metal_major = release[0]
bare_metal_minor = release[1][0]
return raw_output, bare_metal_major, bare_metal_minor
def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
torch_binary_major = torch.version.cuda.split(".")[0]
torch_binary_minor = torch.version.cuda.split(".")[1]
print("\nCompiling cuda extensions with")
print(raw_output + "from " + cuda_dir + "/bin\n")
if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor):
raise RuntimeError(
"Cuda extensions are being compiled with a version of Cuda that does "
"not match the version used to compile Pytorch binaries. "
"Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda)
+ "In some cases, a minor-version mismatch will not cause later errors: "
"https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. "
"You can try commenting out this check (at your own risk)."
)
def raise_if_cuda_home_none(global_option: str) -> None:
if CUDA_HOME is not None:
return
raise RuntimeError(
f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
"If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
"only images whose names contain 'devel' will provide nvcc."
)
def append_nvcc_threads(nvcc_extra_args):
_, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
return nvcc_extra_args + ["--threads", "4"]
return nvcc_extra_args
if not torch.cuda.is_available():
# https://github.com/NVIDIA/apex/issues/486
# Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(),
# which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command).
print(
"\nWarning: Torch did not find available GPUs on this system.\n",
"If your intention is to cross-compile, this is not an error.\n"
"By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n"
"Volta (compute capability 7.0), Turing (compute capability 7.5),\n"
"and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n"
"If you wish to cross-compile for a single specific architecture,\n"
'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n',
)
if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
_, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) == 11:
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
if int(bare_metal_minor) > 0:
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
else:
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"
print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split(".")[1])
cmdclass = {}
ext_modules = []
# Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h
# See https://github.com/pytorch/pytorch/pull/70650
generator_flag = []
torch_dir = torch.__path__[0]
if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
generator_flag = ["-DOLD_GENERATOR_PATH"]
raise_if_cuda_home_none("--xentropy")
# Check, if CUDA11 is installed for compute capability 8.0
cc_flag = []
cc_flag.append("-gencode")
cc_flag.append("arch=compute_70,code=sm_70")
cc_flag.append("-gencode")
cc_flag.append("arch=compute_80,code=sm_80")
ext_modules.append(
CUDAExtension(
name="xentropy_cuda_lib",
sources=[
"interface.cpp",
"xentropy_kernel.cu"
],
extra_compile_args={
"cxx": ["-O3"] + generator_flag,
"nvcc": append_nvcc_threads(
["-O3"]
+ generator_flag
+ cc_flag
),
},
include_dirs=[this_dir],
)
)
setup(
name="xentropy_cuda_lib",
version="0.1",
description="Cross-entropy loss",
ext_modules=ext_modules,
cmdclass={"build_ext": BuildExtension} if ext_modules else {},
)

View File

@ -0,0 +1,754 @@
// Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/csrc/xentropy/xentropy_kernel.cu
// TD [2022-09-17]: We make it work for bfloat16, and add an option to do the backward inplace (to save memory).
/**
* From PyTorch:
*
* Copyright (c) 2016- Facebook, Inc (Adam Paszke)
* Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
* Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
* Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
* Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
* Copyright (c) 2011-2013 NYU (Clement Farabet)
* Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
* Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
* Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
*
* From Caffe2:
*
* Copyright (c) 2016-present, Facebook Inc. All rights reserved.
*
* All contributions by Facebook:
* Copyright (c) 2016 Facebook Inc.
*
* All contributions by Google:
* Copyright (c) 2015 Google Inc.
* All rights reserved.
*
* All contributions by Yangqing Jia:
* Copyright (c) 2015 Yangqing Jia
* All rights reserved.
*
* All contributions from Caffe:
* Copyright(c) 2013, 2014, 2015, the respective contributors
* All rights reserved.
*
* All other contributions:
* Copyright(c) 2015, 2016 the respective contributors
* All rights reserved.
*
* Caffe2 uses a copyright model similar to Caffe: each contributor holds
* copyright over their contributions to Caffe2. The project versioning records
* all such contribution and copyright details. If a contributor wants to further
* mark their specific copyright on a particular contribution, they should
* indicate their copyright solely in the commit message of the change when it is
* committed.
*
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
*
* 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
* and IDIAP Research Institute nor the names of its contributors may be
* used to endorse or promote products derived from this software without
* specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
* POSSIBILITY OF SUCH DAMAGE.
*/
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/NumericLimits.cuh>
// https://github.com/NVIDIA/apex/blob/master/csrc/type_shim.h
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_##LEVEL = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
// #else
// #define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, LEVEL, NAME, ...) \
// switch(TYPE) \
// { \
// case at::ScalarType::Float: \
// { \
// using scalar_t_##LEVEL = float; \
// __VA_ARGS__; \
// break; \
// } \
// case at::ScalarType::Half: \
// { \
// using scalar_t_##LEVEL = at::Half; \
// __VA_ARGS__; \
// break; \
// } \
// default: \
// AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
// }
// #endif
#define ALIGN_BYTES 16
using Tensor = at::Tensor;
using TensorList = at::TensorList;
using ScalarType = at::ScalarType;
using at::acc_type;
template<typename T, typename AccumT, typename OutT>
struct LogSoftMaxForwardEpilogue {
__device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT max_input, AccumT sum)
: logsum(max_input + std::log(sum)) {}
__device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT max_log_sum_exp)
: logsum(max_log_sum_exp) {}
__device__ __forceinline__ OutT operator()(T input) const {
return static_cast<OutT>(input - logsum);
}
const AccumT logsum;
};
template<typename T, typename AccumT, typename OutT>
struct LogSoftMaxBackwardEpilogue {
__device__ __forceinline__ LogSoftMaxBackwardEpilogue(AccumT sum)
: sum(sum) {}
__device__ __forceinline__ T operator()(OutT gradOutput, OutT output) const {
return static_cast<T>(gradOutput - std::exp(static_cast<AccumT>(output)) * sum);
}
const AccumT sum;
};
const int max_threads = 1024;
inline dim3 SoftMax_getBlockSize(int ILP, uint64_t dim_size) {
uint64_t block_size = 1;
uint64_t max_block_size = std::min(dim_size / ILP, static_cast<uint64_t>(max_threads));
while (block_size < (max_block_size/2)) block_size *= 2;
// Launch at least a single warp - the kernel assumes that.
block_size = std::max(block_size, static_cast<uint64_t>(32));
return dim3(block_size);
}
template<typename T>
struct Add {
__device__ __forceinline__ T operator()(T a, T b) const {
return a + b;
}
};
template<typename T>
struct Max {
__device__ __forceinline__ T operator()(T a, T b) const {
return a < b ? b : a;
}
};
////////////////////////////////////////////////////////////////////////////////
// Regular kernel (fast when dim_size is large; requires inner_size == 1)
////////////////////////////////////////////////////////////////////////////////
template <typename T, typename AccumT>
struct MaxFloat
{
__device__ __forceinline__ AccumT operator()(AccumT max, T v) const {
return ::max(max, (AccumT)v);
}
};
template<typename T, typename AccumT>
struct AddFloat
{
__device__ __forceinline__ AccumT operator()(AccumT sum, T v) const {
return sum + v;
}
};
template<typename T, typename AccumT>
struct SumExpFloat
{
__device__ __forceinline__ SumExpFloat(AccumT v)
: max_k(v) {}
__device__ __forceinline__ AccumT operator()(AccumT sum, T v) const {
return sum + std::exp(v - max_k);
}
const AccumT max_k;
};
template <template<typename> class Reduction, typename AccumT>
__device__ __forceinline__ AccumT
blockReduce(AccumT* smem, AccumT val,
const Reduction<AccumT>& r,
AccumT defaultVal)
{
// To avoid RaW races from chaining blockReduce calls together, we need a sync here
__syncthreads();
smem[threadIdx.x] = val;
__syncthreads();
AccumT warpVal = defaultVal;
// First warp will perform per-warp reductions for the remaining warps
uint32_t mask = (((uint64_t)1) << (blockDim.x / 32)) - 1;
if (threadIdx.x < 32) {
int lane = threadIdx.x % 32;
if (lane < blockDim.x / 32) {
#pragma unroll
for (int i = 0; i < 32; ++i) {
warpVal = r(warpVal, smem[lane * 32 + i]);
}
__syncwarp(mask);
smem[lane] = warpVal;
}
}
__syncthreads();
// First thread will perform a reduction of the above per-warp reductions
AccumT blockVal = defaultVal;
if (threadIdx.x == 0) {
for (int i = 0; i < blockDim.x / 32; ++i) {
blockVal = r(blockVal, smem[i]);
}
smem[0] = blockVal;
}
// Sync and broadcast
__syncthreads();
return smem[0];
}
template <template<typename> class Reduction1, template<typename> class Reduction2, typename AccumT>
__device__ __forceinline__ void
blockReduce(AccumT* smem,
AccumT* reducVal1,
AccumT val1,
const Reduction1<AccumT>& r1,
AccumT defaultVal1,
AccumT* reducVal2,
AccumT val2,
const Reduction2<AccumT>& r2,
AccumT defaultVal2)
{
// To avoid RaW races from chaining blockReduce calls together, we need a sync here
__syncthreads();
smem[threadIdx.x] = val1;
smem[blockDim.x + threadIdx.x] = val2;
__syncthreads();
AccumT warpVal1 = defaultVal1;
AccumT warpVal2 = defaultVal2;
// First warp will perform per-warp reductions for the remaining warps
uint32_t mask = (((uint64_t)1) << (blockDim.x / 32)) - 1;
if (threadIdx.x < 32) {
int lane = threadIdx.x % 32;
if (lane < blockDim.x / 32) {
#pragma unroll
for (int i = 0; i < 32; ++i) {
warpVal1 = r1(warpVal1, smem[lane * 32 + i]);
warpVal2 = r2(warpVal2, smem[lane * 32 + i + blockDim.x]);
}
__syncwarp(mask);
smem[lane] = warpVal1;
smem[lane + blockDim.x] = warpVal2;
}
}
__syncthreads();
// First thread will perform a reduction of the above per-warp reductions
AccumT blockVal1 = defaultVal1;
AccumT blockVal2 = defaultVal2;
if (threadIdx.x == 0) {
for (int i = 0; i < blockDim.x / 32; ++i) {
blockVal1 = r1(blockVal1, smem[i]);
blockVal2 = r2(blockVal2, smem[i + blockDim.x]);
}
smem[0] = blockVal1;
smem[blockDim.x] = blockVal2;
}
// Sync and broadcast
__syncthreads();
*reducVal1 = smem[0];
*reducVal2 = smem[blockDim.x];
__syncthreads();
}
template <template<typename, typename> class Reduction, int ILP, typename T, typename AccumT>
__device__ __forceinline__ AccumT
ilpReduce(int shift,
T* data,
int size,
const Reduction<T, AccumT>& r,
AccumT defaultVal)
{
typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LoadT;
AccumT threadVal = defaultVal;
int offset = threadIdx.x;
// shift and do 1
if(shift > 0){
data -= shift;
size += shift;
if(threadIdx.x >= shift){
threadVal = r(threadVal, data[offset]);
}
size -= blockDim.x;
data += blockDim.x;
}
int last = size % (ILP * blockDim.x);
T v[ILP];
LoadT* value = reinterpret_cast<LoadT*>(&v);
for (; offset * ILP < (size - last); offset += blockDim.x) {
*value = reinterpret_cast<LoadT*>(data)[offset];
for (int j = 0; j < ILP; ++j) {
threadVal = r(threadVal, v[j]);
}
}
offset = size - last + threadIdx.x;
// Epilogue
for (; offset < size; offset += blockDim.x)
threadVal = r(threadVal, data[offset]);
return threadVal;
}
template <template<typename, typename> class Reduction1, template<typename, typename> class Reduction2, int ILP, typename T, typename AccumT>
__device__ __forceinline__ void
ilpReduce(int shift,
T* data,
int size,
AccumT* reducVal1,
const Reduction1<T, AccumT>& r1,
AccumT defaultVal1,
AccumT* reducVal2,
const Reduction2<T, AccumT>& r2,
AccumT defaultVal2)
{
typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LoadT;
AccumT threadVal1 = defaultVal1;
AccumT threadVal2 = defaultVal2;
int offset = threadIdx.x;
// shift and do 1
if(shift > 0){
data -= shift;
size += shift;
if(threadIdx.x >= shift){
threadVal1 = r1(threadVal1, data[offset]);
threadVal2 = r2(threadVal2, data[offset]);
}
size -= blockDim.x;
data += blockDim.x;
}
int last = size % (ILP * blockDim.x);
T v[ILP];
LoadT* value = reinterpret_cast<LoadT*>(&v);
for (; offset * ILP < (size - last); offset += blockDim.x) {
*value = reinterpret_cast<LoadT*>(data)[offset];
for (int j = 0; j < ILP; ++j) {
threadVal1 = r1(threadVal1, v[j]);
threadVal2 = r2(threadVal2, v[j]);
}
}
offset = size - last + threadIdx.x;
// Epilogue
for (; offset < size; offset += blockDim.x) {
threadVal1 = r1(threadVal1, data[offset]);
threadVal2 = r2(threadVal2, data[offset]);
}
*reducVal1 = threadVal1;
*reducVal2 = threadVal2;
}
template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t, template <typename, typename, typename> class Epilogue>
__global__ void
cunn_SoftMaxXEntropyForward(
accscalar_t *losses,
outscalar_t *max_log_sum_exp,
scalar_t *input,
int64_t *labels,
int64_t classes,
const float smoothing)
{
extern __shared__ unsigned char smem[];
auto sdata = reinterpret_cast<accscalar_t*>(smem);
// forward pointers to batch[blockIdx.x]
// each block handles a sample in the mini-batch
input += blockIdx.x * classes;
//output += blockIdx.x * classes;
const int shift = ((uint64_t)input) % ALIGN_BYTES / sizeof(scalar_t);
int64_t label = labels[blockIdx.x];
// find the max and sum
accscalar_t threadMax, threadSum, max_k, sum_k;
ilpReduce<MaxFloat, AddFloat, ILP, scalar_t, accscalar_t>(
shift, input, classes,
&threadMax, MaxFloat<scalar_t, accscalar_t>(),
-at::numeric_limits<accscalar_t>::max(),
&threadSum, AddFloat<scalar_t, accscalar_t>(),
static_cast<accscalar_t>(0));
blockReduce<Max, Add, accscalar_t>(
sdata,
&max_k, threadMax, Max<accscalar_t>(),
-at::numeric_limits<accscalar_t>::max(),
&sum_k, threadSum, Add<accscalar_t>(),
static_cast<accscalar_t>(0));
accscalar_t threadExp = ilpReduce<SumExpFloat, ILP, scalar_t, accscalar_t>(shift, input, classes, SumExpFloat<scalar_t, accscalar_t>(max_k), static_cast<accscalar_t>(0));
accscalar_t sumAll = blockReduce<Add, accscalar_t>(
sdata, threadExp, Add<accscalar_t>(), static_cast<accscalar_t>(0));
Epilogue<scalar_t, accscalar_t, outscalar_t> epilogue(max_k, sumAll);
// calculate per element loss with label smoothing
// reserve max + log_sum_exp for bprop
if (threadIdx.x == 0) {
accscalar_t lse = max_k + std::log(sumAll);
if ((label >= 0) && (label < classes)) {
accscalar_t log_prob = epilogue(static_cast<accscalar_t>(input[label]));
losses[blockIdx.x] = (lse - sum_k / classes) * smoothing - log_prob * (1 - smoothing);
} else {
losses[blockIdx.x] = outscalar_t(0.f);
}
max_log_sum_exp[blockIdx.x] = lse;
}
}
template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t>
__device__ __forceinline__ void
apply(scalar_t *gradInput,
scalar_t *logits,
outscalar_t *max_log_sum_exp,
outscalar_t *gradOutput,
int64_t *labels,
const float smoothing,
int classes)
{
accscalar_t smooth_positives = 1.0 - smoothing;
accscalar_t smooth_negatives = smoothing / classes;
accscalar_t tmpGradOutput = gradOutput[blockIdx.x];
int64_t label = labels[blockIdx.x];
accscalar_t coeff = max_log_sum_exp[blockIdx.x];
int offset = threadIdx.x;
int last = classes % (ILP * blockDim.x);
for (; offset < classes - last; offset += blockDim.x * ILP) {
accscalar_t tmpLogits[ILP];
#pragma unroll
for (int j = 0; j < ILP; ++j) {
tmpLogits[j] = static_cast<accscalar_t>(logits[offset + j * blockDim.x]);
}
#pragma unroll
for (int j = 0; j < ILP; ++j)
gradInput[offset + j * blockDim.x] = tmpGradOutput * (
std::exp(tmpLogits[j] - coeff) - static_cast<accscalar_t>(
(offset + j * blockDim.x == label) ? 1 : 0) *
smooth_positives - smooth_negatives);
}
for (; offset < classes; offset += blockDim.x)
gradInput[offset] = tmpGradOutput * (std::exp(
static_cast<accscalar_t>(logits[offset]) - coeff) -
static_cast<accscalar_t>((offset == label) ? 1 : 0) *
smooth_positives - smooth_negatives);
}
template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t>
__device__ __forceinline__ void
aligned_apply(int shift,
scalar_t *gradInput,
scalar_t *logits,
outscalar_t *max_log_sum_exp,
outscalar_t *gradOutput,
int64_t *labels,
const float smoothing,
int classes)
{
accscalar_t smooth_positives = 1.0 - smoothing;
accscalar_t smooth_negatives = smoothing / classes;
accscalar_t tmpGradOutput = gradOutput[blockIdx.x];
int64_t label = labels[blockIdx.x];
accscalar_t coeff = max_log_sum_exp[blockIdx.x];
int offset = threadIdx.x;
// shift and do 1
if(shift > 0){
logits -= shift;
gradInput -= shift;
classes += shift;
if(threadIdx.x >= shift){
gradInput[offset] = tmpGradOutput * (std::exp(
static_cast<accscalar_t>(logits[offset]) - coeff) -
static_cast<accscalar_t>(((offset - shift) == label) ? 1 : 0) *
smooth_positives - smooth_negatives);
}
classes -= blockDim.x;
gradInput += blockDim.x;
logits += blockDim.x;
shift -= blockDim.x;
}
int last = classes % (ILP * blockDim.x);
typedef typename std::aligned_storage<ILP*sizeof(scalar_t), ILP*alignof(scalar_t)>::type LoadT;
// input
scalar_t v[ILP];
LoadT* value = reinterpret_cast<LoadT*>(&v);
// output
scalar_t r[ILP];
LoadT* result = reinterpret_cast<LoadT*>(&r);
for (; offset * ILP < (classes - last); offset += blockDim.x) {
*value = reinterpret_cast<LoadT*>(logits)[offset];
#pragma unroll
for (int j = 0; j < ILP; ++j) {
r[j] = tmpGradOutput * (std::exp(
static_cast<accscalar_t>(v[j]) - coeff) -
static_cast<accscalar_t>(((ILP * offset + j - shift) == label) ? 1 : 0) *
smooth_positives - smooth_negatives);
}
reinterpret_cast<LoadT*>(gradInput)[offset] = *result;
}
offset = classes - last + threadIdx.x;
for (; offset < classes; offset += blockDim.x)
gradInput[offset] = tmpGradOutput * (std::exp(
static_cast<accscalar_t>(logits[offset]) - coeff) -
static_cast<accscalar_t>(((offset - shift) == label) ? 1 : 0) *
smooth_positives - smooth_negatives);
}
template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t, template<typename, typename, typename> class Epilogue>
__global__ void
cunn_SoftMaxXEntropyBackward(
scalar_t *gradInput,
scalar_t *logits,
outscalar_t *max_log_sum_exp,
outscalar_t *gradOutput,
int64_t *labels,
const float smoothing,
int classes)
{
gradInput += blockIdx.x * classes;
logits += blockIdx.x * classes;
// Do vectorized load/store when input/output have same alignment
const int shift = ((uint64_t)logits) % ALIGN_BYTES / sizeof(scalar_t);
const int shift_ = ((uint64_t)gradInput) % ALIGN_BYTES / sizeof(scalar_t);
if (shift == shift_){
aligned_apply<ILP, scalar_t, accscalar_t, outscalar_t>(shift, gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes);
}
else {
apply<ILP, scalar_t, accscalar_t, outscalar_t>(gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes);
}
}
template<template<typename, typename, typename> class Epilogue>
std::vector<Tensor> host_softmax_xentropy(
const Tensor & input_,
const Tensor & labels_,
const float smoothing){
AT_ASSERTM(labels_.scalar_type() == ScalarType::Long,"Label type should be CUDA Long");
// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at::cuda::CUDAGuard device_guard{(char)input_.get_device()};
auto input = input_.contiguous();
Tensor max_log_sum_exp = at::empty_like(labels_, input.options().dtype(ScalarType::Float));
Tensor losses = at::empty_like(labels_, input_.options().dtype(ScalarType::Float));
static_assert(std::is_same<acc_type<at::Half, true>, float>::value ||
std::is_same<acc_type<at::Half, true>, double>::value,
"accscalar_t for half should be float or double");
AT_ASSERTM(input.dim() == 2, "Currently only 2 dim input supported");
AT_ASSERTM(labels_.dim() == 1, "Labels should be 1 dimensional");
AT_ASSERTM(input.size(0) == labels_.size(0), "Input and label should have same number of examples");
AT_ASSERTM(input.numel() > 0, "Number of classes in input should not be 0");
const int64_t dim = 1;
int64_t outer_size = 1;
int64_t dim_size = input.size(dim);
int64_t inner_size = 1;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
for (int64_t i = 0; i < dim; ++i)
outer_size *= input.size(i);
for (int64_t i = dim + 1; i < input.dim(); ++i)
inner_size *= input.size(i);
// This kernel spawns a block per each element in the batch.
// XXX: it assumes that inner_size == 1
TORCH_CHECK(inner_size == 1, "Currently only inner size 1 supported");
dim3 grid(outer_size);
using namespace at;
DISPATCH_FLOAT_AND_HALF_AND_BF16(input.scalar_type(), 0, "host_softmax_xentropy",
using accscalar_t = at::acc_type<scalar_t_0, true>;
const int ILP = sizeof(float4)/sizeof(scalar_t_0);
dim3 block = SoftMax_getBlockSize(ILP, dim_size);
cunn_SoftMaxXEntropyForward<ILP, scalar_t_0, accscalar_t, accscalar_t, Epilogue>
<<<grid, block, 2 * block.x * sizeof(accscalar_t), stream>>>(
losses.data_ptr<accscalar_t>(), max_log_sum_exp.data_ptr<accscalar_t>(),
input.data_ptr<scalar_t_0>(), labels_.data_ptr<int64_t>(),
dim_size, smoothing
);
);
C10_CUDA_CHECK(cudaGetLastError());
std::vector<at::Tensor> ret = {losses, max_log_sum_exp};
return ret;
}
template<template<typename, typename, typename> class Epilogue>
Tensor host_softmax_xentropy_backward(
const at::Tensor &grad_loss,
at::Tensor &logits_,
const at::Tensor &max_log_sum_exp,
const at::Tensor &labels,
const float smoothing,
bool inplace) {
// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at::cuda::CUDAGuard device_guard{(char)grad_loss.get_device()};
const int64_t dim = 1;
Tensor gI = inplace ? logits_ : at::empty_like(logits_);
if (grad_loss.numel() == 0) {
return gI;
}
auto grad = grad_loss.contiguous();
auto logits = logits_.contiguous();
static_assert(std::is_same<acc_type<at::Half, true>, float>::value ||
std::is_same<acc_type<at::Half, true>, double>::value,
"accscalar_t for half should be float or double");
if (grad.dim() == 0) grad = grad.view(1);
AT_ASSERTM(logits_.dim() == 2, "Currently only 2 dim input supported");
AT_ASSERTM(labels.dim() == 1, "Labels should be 1 dimensional");
AT_ASSERTM(logits_.numel() > 0, "Number of classes in input should not be 0");
AT_ASSERTM(logits_.size(0) == labels.size(0), "Input and label should have same number of examples");
AT_ASSERTM(labels.size(0) == grad.size(0), "Label and loss should have same number of examples");
int64_t outer_size = 1;
int64_t dim_size = logits.size(dim);
int64_t inner_size = 1;
for (int64_t i = 0; i < dim; ++i)
outer_size *= logits.size(i);
for (int64_t i = dim + 1; i < logits.dim(); ++i)
inner_size *= logits.size(i);
// See descriptions of kernels above.
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
TORCH_CHECK(inner_size == 1, "Currently only inner size 1 supported");
dim3 grid(outer_size);
DISPATCH_FLOAT_AND_HALF_AND_BF16(gI.scalar_type(), 0, "host_softmax_xentropy_backward",
using accscalar_t = acc_type<scalar_t_0, true>;
const int ILP = sizeof(float4)/sizeof(scalar_t_0);
dim3 block = SoftMax_getBlockSize(ILP, dim_size);
cunn_SoftMaxXEntropyBackward<ILP, scalar_t_0, accscalar_t, accscalar_t, Epilogue>
<<<grid, block, block.x * sizeof(accscalar_t), stream>>>(
gI.data_ptr<scalar_t_0>(), logits.data_ptr<scalar_t_0>(),
max_log_sum_exp.data_ptr<accscalar_t>(),
grad.data_ptr<accscalar_t>(), labels.data_ptr<int64_t>(),
smoothing, dim_size
);
);
C10_CUDA_CHECK(cudaGetLastError());
return gI;
}
std::vector<Tensor> softmax_xentropy_cuda(const Tensor &input, const Tensor &labels, const float smoothing){
return host_softmax_xentropy<LogSoftMaxForwardEpilogue>(input, labels, smoothing);
}
at::Tensor softmax_xentropy_backward_cuda(
const at::Tensor &grad_loss,
at::Tensor &logits,
const at::Tensor &max_log_sum_exp,
const at::Tensor &labels,
const float smoothing,
const bool inplace) {
AT_ASSERTM((grad_loss.scalar_type() == ScalarType::Float), "expected grad types to be at::Float");
return host_softmax_xentropy_backward<LogSoftMaxBackwardEpilogue>(grad_loss, logits, max_log_sum_exp, labels, smoothing, inplace);
}

View File

@ -0,0 +1,51 @@
import torch
import torch.nn as nn
import xentropy_cuda_lib
# https://github.com/NVIDIA/apex/blob/master/apex/contrib/xentropy/softmax_xentropy.py
class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
@staticmethod
def forward(ctx, logits, labels, smoothing=0.0, padding_idx=0, inplace_backward=False):
losses, max_log_sum_exp = xentropy_cuda_lib.forward(
logits, labels, smoothing)
losses.masked_fill_(labels==padding_idx, 0)
ctx.save_for_backward(logits, max_log_sum_exp, labels)
ctx.smoothing = smoothing
ctx.padding_idx = padding_idx
ctx.inplace_backward = inplace_backward
return losses
@staticmethod
def backward(ctx, grad_loss):
logits, max_log_sum_exp, labels = ctx.saved_tensors
if not grad_loss.is_contiguous():
grad_loss = grad_loss.contiguous()
grad_loss.masked_fill_(labels==ctx.padding_idx, 0)
grad_logits = xentropy_cuda_lib.backward(grad_loss, logits, max_log_sum_exp, labels,
ctx.smoothing, ctx.inplace_backward)
return grad_logits, None, None, None, None
class CrossEntropyLossApex(nn.Module):
def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0,
inplace_backward=False):
super().__init__()
if reduction not in ['mean', 'none']:
raise NotImplementedError("Only support reduction = 'mean' or 'none'")
self.ignore_index = ignore_index
self.reduction = reduction
self.label_smoothing = label_smoothing
self.inplace_backward = inplace_backward
def forward(self, input, target):
assert input.is_cuda and target.is_cuda
# SoftmaxCrossEntropyLoss implicitly casts to float
loss = SoftmaxCrossEntropyLossFn.apply(input, target, self.label_smoothing,
self.ignore_index, self.inplace_backward)
if self.reduction == 'mean':
return loss.sum() / (target != self.ignore_index).sum()
else:
return loss

View File

@ -0,0 +1,112 @@
# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py
# But we make it much faster: we compute the local loss and the LSE, and by exchanging the LSE and
# the losses we can get the global loss. There's no need to do it step by step
# (compute local max, exchange, compute exp, compute local sum, exchange, etc.)
import torch
import torch.nn as nn
import xentropy_cuda_lib
from apex.transformer.parallel_state import get_tensor_model_parallel_group
from apex.transformer.parallel_state import get_tensor_model_parallel_rank
from apex.transformer.parallel_state import get_tensor_model_parallel_world_size
from apex.transformer.tensor_parallel.utils import VocabUtility
# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
# version of PyTorch. The following 4 lines are for backward comparability with
# older PyTorch.
if "all_gather_into_tensor" not in dir(torch.distributed):
torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
if "reduce_scatter_tensor" not in dir(torch.distributed):
torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base
class SoftmaxCrossEntropyLossParallelFn(torch.autograd.Function):
@staticmethod
def forward(ctx, logits_parallel, labels, smoothing=0.0, ignored_index=-100,
inplace_backward=False):
"""
logits_parallel: (batch, vocab_size / world_size)
labels: (batch,)
"""
assert smoothing == 0.0, 'smoothing != 0.0 is not yet implemented, file an issue if you need it'
batch, partition_vocab_size = logits_parallel.shape
assert labels.shape == (batch,)
rank = get_tensor_model_parallel_rank()
world_size = get_tensor_model_parallel_world_size()
vocab_start_index, vocab_end_index = VocabUtility.vocab_range_from_per_partition_vocab_size(
partition_vocab_size, get_tensor_model_parallel_rank(),
get_tensor_model_parallel_world_size()
)
# Create a mask of valid vocab ids (1 means it needs to be masked).
labels_mask = (labels < vocab_start_index) | (labels >= vocab_end_index)
ignored_mask = labels == ignored_index
labels_local = torch.where(ignored_mask, labels, labels - vocab_start_index)
masked_labels = labels_local.clone()
masked_labels[labels_mask] = ignored_index
losses, lse_local = xentropy_cuda_lib.forward(logits_parallel, masked_labels, smoothing)
assert lse_local.shape == (batch,)
assert losses.shape == (batch,)
losses.masked_fill_(masked_labels==ignored_index, 0)
if world_size > 1:
lse_allgather = torch.empty(world_size, batch, dtype=lse_local.dtype,
device=lse_local.device)
torch.distributed.all_gather_into_tensor(lse_allgather, lse_local.contiguous(),
group=get_tensor_model_parallel_group())
lse = torch.logsumexp(lse_allgather, dim=0)
torch.distributed.all_reduce(losses, op=torch.distributed.ReduceOp.SUM,
group=get_tensor_model_parallel_group())
# The losses are currently lse_local - predicted_logit, we just have to subtract the
# lse_local and add the lse (global).
rank_per_sample = labels // partition_vocab_size
lse_local = lse_allgather[rank_per_sample,
torch.arange(batch, device=lse_allgather.device)]
losses += lse - lse_local
losses.masked_fill_(ignored_mask, 0)
else:
lse = lse_local
ctx.save_for_backward(logits_parallel, lse, labels_local)
ctx.smoothing = smoothing
ctx.ignored_index = ignored_index
ctx.inplace_backward = inplace_backward
return losses
@staticmethod
def backward(ctx, grad_loss):
logits_parallel, lse, labels = ctx.saved_tensors
if not grad_loss.is_contiguous():
grad_loss = grad_loss.contiguous()
grad_loss.masked_fill_(labels==ctx.ignored_index, 0)
grad_logits = xentropy_cuda_lib.backward(grad_loss, logits_parallel, lse, labels,
ctx.smoothing, ctx.inplace_backward)
return grad_logits, None, None, None, None, None
class CrossEntropyLossParallel(nn.Module):
def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0,
inplace_backward=False):
super().__init__()
if reduction not in ['mean', 'none']:
raise NotImplementedError("Only support reduction = 'mean' or 'none'")
self.ignore_index = ignore_index
self.reduction = reduction
self.label_smoothing = label_smoothing
self.inplace_backward = inplace_backward
def forward(self, input, target):
assert input.is_cuda and target.is_cuda
# SoftmaxCrossEntropyLoss implicitly casts to float
loss = SoftmaxCrossEntropyLossParallelFn.apply(
input, target, self.label_smoothing, self.ignore_index, self.inplace_backward
)
if self.reduction == 'mean':
return loss.sum() / (target != self.ignore_index).sum()
else:
return loss

View File

@ -0,0 +1,39 @@
import math
import torch
import torch.nn.functional as F
import pytest
from einops import rearrange
from src.losses.cross_entropy_apex import CrossEntropyLossApex
is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
@pytest.mark.parametrize('dtype', [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('inplace_backward', [False, True])
# @pytest.mark.parametrize('inplace_backward', [False])
@pytest.mark.parametrize('vocab_size', [50257])
def test_cross_entropy_loss_apex(vocab_size, inplace_backward, dtype):
device = 'cuda'
rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)
# set seed
torch.random.manual_seed(0)
batch_size = 8
seqlen = 128
x_pt = torch.randn(batch_size * seqlen, vocab_size, device=device, dtype=dtype, requires_grad=True)
x = x_pt.detach().clone().requires_grad_()
y = torch.randint(0, vocab_size, (batch_size * seqlen,), dtype=torch.long, device=device)
y[torch.randperm(batch_size * seqlen)[:10]] = -100
model_pt = torch.nn.CrossEntropyLoss()
model = CrossEntropyLossApex(inplace_backward=inplace_backward)
out = model(x, y)
out_pt = model_pt(x_pt.float(), y)
assert torch.allclose(out, out_pt, rtol=rtol, atol=atol)
g = torch.randn_like(out)
out_pt.backward(g)
out.backward(g)
assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)

View File

@ -0,0 +1,56 @@
# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/losses/test_cross_entropy_parallel.py
import math
import torch
import torch.nn.functional as F
import pytest
from apex.transformer import parallel_state
from apex.transformer import tensor_parallel
from src.losses.cross_entropy_parallel import CrossEntropyLossParallel
is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
@pytest.mark.parametrize('dtype', [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize('inplace_backward', [False, True])
# @pytest.mark.parametrize('inplace_backward', [False])
@pytest.mark.parametrize('vocab_size', [50264])
@pytest.mark.parametrize('world_size', [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
def test_cross_entropy_loss_apex(vocab_size, world_size, inplace_backward, dtype):
assert vocab_size % world_size == 0
rtol, atol = ((1e-5, 1e-6) if dtype == torch.float32
else ((1e-3, 1e-4) if dtype == torch.float16 else (1e-2, 3e-3)))
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend='nccl', init_method='env://')
partition_vocab_size = vocab_size // world_size
device = f'cuda:{torch.distributed.get_rank()}'
assert world_size <= torch.distributed.get_world_size()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank()
# set seed
torch.random.manual_seed(0)
batch_size = 8
seqlen = 128
x_pt = (torch.randn(batch_size * seqlen, vocab_size, device=device,
dtype=dtype) * 10).requires_grad_()
x = tensor_parallel.scatter_to_tensor_model_parallel_region(x_pt).detach().clone().requires_grad_()
y = torch.randint(0, vocab_size, (batch_size * seqlen,), dtype=torch.long, device=device)
y[torch.randperm(batch_size * seqlen)[:10]] = -100
model_pt = torch.nn.CrossEntropyLoss(reduction='none')
model = CrossEntropyLossParallel(reduction='none', inplace_backward=inplace_backward)
out = model(x, y)
out_pt = model_pt(x_pt.float(), y)
assert torch.allclose(out, out_pt, rtol=1e-5, atol=1e-6)
g = torch.randn_like(out)
out_pt.backward(g)
out.backward(g)
assert torch.allclose(x.grad, x_pt.grad[:, (rank * partition_vocab_size):(rank + 1) * partition_vocab_size], rtol=rtol, atol=atol)
parallel_state.destroy_model_parallel()