diff --git a/csrc/xentropy/interface.cpp b/csrc/xentropy/interface.cpp
new file mode 100644
index 0000000..715790d
--- /dev/null
+++ b/csrc/xentropy/interface.cpp
@@ -0,0 +1,51 @@
+#include <torch/extension.h>
+
+// CUDA forward declarations
+std::vector<at::Tensor> softmax_xentropy_cuda(
+    const at::Tensor &input,
+    const at::Tensor &labels,
+    const float smoothing);
+
+at::Tensor softmax_xentropy_backward_cuda(
+    const at::Tensor &grad_loss,
+    at::Tensor &logits,
+    const at::Tensor &max_log_sum_exp,
+    const at::Tensor &labels,
+    const float smoothing,
+    const bool inplace);
+
+// C++ interface
+
+#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+std::vector<at::Tensor> softmax_xentropy_forward(
+    const at::Tensor &input,
+    const at::Tensor &labels,
+    const float smoothing) {
+    CHECK_CUDA(input);
+    CHECK_INPUT(labels);
+
+    return softmax_xentropy_cuda(input, labels, smoothing);
+}
+
+at::Tensor softmax_xentropy_backward(
+    const at::Tensor &grad_loss,
+    at::Tensor &logits,
+    const at::Tensor &max_log_sum_exp,
+    const at::Tensor &labels,
+    const float smoothing,
+    const bool inplace) {
+    CHECK_CUDA(grad_loss);
+    CHECK_CUDA(logits);
+    CHECK_INPUT(max_log_sum_exp);
+    CHECK_INPUT(labels);
+
+    return softmax_xentropy_backward_cuda(grad_loss, logits, max_log_sum_exp, labels, smoothing, inplace);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("forward", &softmax_xentropy_forward, "Softmax cross entropy loss with label smoothing forward (CUDA)");
+    m.def("backward", &softmax_xentropy_backward, "Softmax cross entropy loss with label smoothing backward (CUDA)");
+}
diff --git a/csrc/xentropy/setup.py b/csrc/xentropy/setup.py
new file mode 100644
index 0000000..ca61835
--- /dev/null
+++ b/csrc/xentropy/setup.py
@@ -0,0 +1,131 @@
+# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py
+import torch
+from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
+from setuptools import setup, find_packages
+import subprocess
+
+import sys
+import warnings
+import os
+
+# ninja build does not work unless include_dirs are abs path
+this_dir = os.path.dirname(os.path.abspath(__file__))
+
+
+def get_cuda_bare_metal_version(cuda_dir):
+    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
+    output = raw_output.split()
+    release_idx = output.index("release") + 1
+    release = output[release_idx].split(".")
+    bare_metal_major = release[0]
+    bare_metal_minor = release[1][0]
+
+    return raw_output, bare_metal_major, bare_metal_minor
+
+
+def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
+    raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
+    torch_binary_major = torch.version.cuda.split(".")[0]
+    torch_binary_minor = torch.version.cuda.split(".")[1]
+
+    print("\nCompiling cuda extensions with")
+    print(raw_output + "from " + cuda_dir + "/bin\n")
+
+    if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor):
+        raise RuntimeError(
+            "Cuda extensions are being compiled with a version of Cuda that does "
+            "not match the version used to compile Pytorch binaries. "
+            "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda)
+            + "In some cases, a minor-version mismatch will not cause later errors: "
+            "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. "
+            "You can try commenting out this check (at your own risk)."
+        )
+
+
+def raise_if_cuda_home_none(global_option: str) -> None:
+    if CUDA_HOME is not None:
+        return
+    raise RuntimeError(
+        f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
+        "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
+        "only images whose names contain 'devel' will provide nvcc."
+    )
+
+
+def append_nvcc_threads(nvcc_extra_args):
+    _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
+    if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
+        return nvcc_extra_args + ["--threads", "4"]
+    return nvcc_extra_args
+
+
+if not torch.cuda.is_available():
+    # https://github.com/NVIDIA/apex/issues/486
+    # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(),
+    # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command).
+    print(
+        "\nWarning: Torch did not find available GPUs on this system.\n",
+        "If your intention is to cross-compile, this is not an error.\n"
+        "By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n"
+        "Volta (compute capability 7.0), Turing (compute capability 7.5),\n"
+        "and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n"
+        "If you wish to cross-compile for a single specific architecture,\n"
+        'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n',
+    )
+    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
+        _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
+        if int(bare_metal_major) == 11:
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
+            if int(bare_metal_minor) > 0:
+                os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
+        else:
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"
+
+print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
+TORCH_MAJOR = int(torch.__version__.split(".")[0])
+TORCH_MINOR = int(torch.__version__.split(".")[1])
+
+cmdclass = {}
+ext_modules = []
+
+# Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h
+# See https://github.com/pytorch/pytorch/pull/70650
+generator_flag = []
+torch_dir = torch.__path__[0]
+if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
+    generator_flag = ["-DOLD_GENERATOR_PATH"]
+
+raise_if_cuda_home_none("--xentropy")
+# Check, if CUDA11 is installed for compute capability 8.0
+cc_flag = []
+cc_flag.append("-gencode")
+cc_flag.append("arch=compute_70,code=sm_70")
+cc_flag.append("-gencode")
+cc_flag.append("arch=compute_80,code=sm_80")
+
+ext_modules.append(
+    CUDAExtension(
+        name="xentropy_cuda_lib",
+        sources=[
+            "interface.cpp",
+            "xentropy_kernel.cu"
+        ],
+        extra_compile_args={
+            "cxx": ["-O3"] + generator_flag,
+            "nvcc": append_nvcc_threads(
+                ["-O3"]
+                + generator_flag
+                + cc_flag
+            ),
+        },
+        include_dirs=[this_dir],
+    )
+)
+
+setup(
+    name="xentropy_cuda_lib",
+    version="0.1",
+    description="Cross-entropy loss",
+    ext_modules=ext_modules,
+    cmdclass={"build_ext": BuildExtension} if ext_modules else {},
+)
diff --git a/csrc/xentropy/xentropy_kernel.cu b/csrc/xentropy/xentropy_kernel.cu
new file mode 100644
index 0000000..b1ebf70
--- /dev/null
+++ b/csrc/xentropy/xentropy_kernel.cu
@@ -0,0 +1,754 @@
+// Adapted from
https://github.com/NVIDIA/apex/blob/master/apex/contrib/csrc/xentropy/xentropy_kernel.cu +// TD [2022-09-17]: We make it work for bfloat16, and add an option to do the backward inplace (to save memory). +/** + * From PyTorch: + * + * Copyright (c) 2016- Facebook, Inc (Adam Paszke) + * Copyright (c) 2014- Facebook, Inc (Soumith Chintala) + * Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) + * Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) + * Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) + * Copyright (c) 2011-2013 NYU (Clement Farabet) + * Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) + * Copyright (c) 2006 Idiap Research Institute (Samy Bengio) + * Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) + * + * From Caffe2: + * + * Copyright (c) 2016-present, Facebook Inc. All rights reserved. + * + * All contributions by Facebook: + * Copyright (c) 2016 Facebook Inc. + * + * All contributions by Google: + * Copyright (c) 2015 Google Inc. + * All rights reserved. + * + * All contributions by Yangqing Jia: + * Copyright (c) 2015 Yangqing Jia + * All rights reserved. + * + * All contributions from Caffe: + * Copyright(c) 2013, 2014, 2015, the respective contributors + * All rights reserved. + * + * All other contributions: + * Copyright(c) 2015, 2016 the respective contributors + * All rights reserved. + * + * Caffe2 uses a copyright model similar to Caffe: each contributor holds + * copyright over their contributions to Caffe2. The project versioning records + * all such contribution and copyright details. If a contributor wants to further + * mark their specific copyright on a particular contribution, they should + * indicate their copyright solely in the commit message of the change when it is + * committed. + * + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America + * and IDIAP Research Institute nor the names of its contributors may be + * used to endorse or promote products derived from this software without + * specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. 
+ */ +#include +#include +#include + +#include +#include + +// https://github.com/NVIDIA/apex/blob/master/csrc/type_shim.h +// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, LEVEL, NAME, ...) \ + switch(TYPE) \ + { \ + case at::ScalarType::Float: \ + { \ + using scalar_t_##LEVEL = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t_##LEVEL = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t_##LEVEL = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } +// #else +// #define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, LEVEL, NAME, ...) \ +// switch(TYPE) \ +// { \ +// case at::ScalarType::Float: \ +// { \ +// using scalar_t_##LEVEL = float; \ +// __VA_ARGS__; \ +// break; \ +// } \ +// case at::ScalarType::Half: \ +// { \ +// using scalar_t_##LEVEL = at::Half; \ +// __VA_ARGS__; \ +// break; \ +// } \ +// default: \ +// AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ +// } +// #endif + +#define ALIGN_BYTES 16 + +using Tensor = at::Tensor; +using TensorList = at::TensorList; +using ScalarType = at::ScalarType; +using at::acc_type; + +template +struct LogSoftMaxForwardEpilogue { + __device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT max_input, AccumT sum) + : logsum(max_input + std::log(sum)) {} + + __device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT max_log_sum_exp) + : logsum(max_log_sum_exp) {} + + __device__ __forceinline__ OutT operator()(T input) const { + return static_cast(input - logsum); + } + + const AccumT logsum; +}; + +template +struct LogSoftMaxBackwardEpilogue { + __device__ __forceinline__ LogSoftMaxBackwardEpilogue(AccumT sum) + : sum(sum) {} + + __device__ __forceinline__ T operator()(OutT gradOutput, OutT output) const { + return static_cast(gradOutput - std::exp(static_cast(output)) * sum); + } + + const AccumT sum; +}; + + + +const int max_threads = 1024; + +inline dim3 SoftMax_getBlockSize(int ILP, uint64_t dim_size) { + uint64_t block_size = 1; + uint64_t max_block_size = std::min(dim_size / ILP, static_cast(max_threads)); + while (block_size < (max_block_size/2)) block_size *= 2; + // Launch at least a single warp - the kernel assumes that. + block_size = std::max(block_size, static_cast(32)); + return dim3(block_size); +} + +template +struct Add { + __device__ __forceinline__ T operator()(T a, T b) const { + return a + b; + } +}; + +template +struct Max { + __device__ __forceinline__ T operator()(T a, T b) const { + return a < b ? 
b : a; + } +}; + + +//////////////////////////////////////////////////////////////////////////////// +// Regular kernel (fast when dim_size is large; requires inner_size == 1) +//////////////////////////////////////////////////////////////////////////////// + + +template +struct MaxFloat +{ + __device__ __forceinline__ AccumT operator()(AccumT max, T v) const { + return ::max(max, (AccumT)v); + } +}; + +template +struct AddFloat +{ + __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const { + return sum + v; + } +}; + +template +struct SumExpFloat +{ + __device__ __forceinline__ SumExpFloat(AccumT v) + : max_k(v) {} + + __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const { + return sum + std::exp(v - max_k); + } + + const AccumT max_k; +}; + +template class Reduction, typename AccumT> +__device__ __forceinline__ AccumT +blockReduce(AccumT* smem, AccumT val, + const Reduction& r, + AccumT defaultVal) +{ + // To avoid RaW races from chaining blockReduce calls together, we need a sync here + __syncthreads(); + + smem[threadIdx.x] = val; + + __syncthreads(); + + AccumT warpVal = defaultVal; + + // First warp will perform per-warp reductions for the remaining warps + uint32_t mask = (((uint64_t)1) << (blockDim.x / 32)) - 1; + if (threadIdx.x < 32) { + int lane = threadIdx.x % 32; + if (lane < blockDim.x / 32) { +#pragma unroll + for (int i = 0; i < 32; ++i) { + warpVal = r(warpVal, smem[lane * 32 + i]); + } + __syncwarp(mask); + smem[lane] = warpVal; + } + } + + __syncthreads(); + + // First thread will perform a reduction of the above per-warp reductions + AccumT blockVal = defaultVal; + + if (threadIdx.x == 0) { + for (int i = 0; i < blockDim.x / 32; ++i) { + blockVal = r(blockVal, smem[i]); + } + smem[0] = blockVal; + } + + // Sync and broadcast + __syncthreads(); + return smem[0]; +} + +template class Reduction1, template class Reduction2, typename AccumT> +__device__ __forceinline__ void +blockReduce(AccumT* smem, + AccumT* reducVal1, + AccumT val1, + const Reduction1& r1, + AccumT defaultVal1, + AccumT* reducVal2, + AccumT val2, + const Reduction2& r2, + AccumT defaultVal2) +{ + // To avoid RaW races from chaining blockReduce calls together, we need a sync here + __syncthreads(); + + smem[threadIdx.x] = val1; + smem[blockDim.x + threadIdx.x] = val2; + + __syncthreads(); + + AccumT warpVal1 = defaultVal1; + AccumT warpVal2 = defaultVal2; + + // First warp will perform per-warp reductions for the remaining warps + uint32_t mask = (((uint64_t)1) << (blockDim.x / 32)) - 1; + if (threadIdx.x < 32) { + int lane = threadIdx.x % 32; + if (lane < blockDim.x / 32) { +#pragma unroll + for (int i = 0; i < 32; ++i) { + warpVal1 = r1(warpVal1, smem[lane * 32 + i]); + warpVal2 = r2(warpVal2, smem[lane * 32 + i + blockDim.x]); + } + __syncwarp(mask); + smem[lane] = warpVal1; + smem[lane + blockDim.x] = warpVal2; + } + } + + __syncthreads(); + + // First thread will perform a reduction of the above per-warp reductions + AccumT blockVal1 = defaultVal1; + AccumT blockVal2 = defaultVal2; + + if (threadIdx.x == 0) { + for (int i = 0; i < blockDim.x / 32; ++i) { + blockVal1 = r1(blockVal1, smem[i]); + blockVal2 = r2(blockVal2, smem[i + blockDim.x]); + } + smem[0] = blockVal1; + smem[blockDim.x] = blockVal2; + } + + // Sync and broadcast + __syncthreads(); + *reducVal1 = smem[0]; + *reducVal2 = smem[blockDim.x]; + __syncthreads(); +} + +template class Reduction, int ILP, typename T, typename AccumT> +__device__ __forceinline__ AccumT +ilpReduce(int shift, + T* data, + int size, + 
const Reduction& r, + AccumT defaultVal) +{ + typedef typename std::aligned_storage::type LoadT; + AccumT threadVal = defaultVal; + int offset = threadIdx.x; + + // shift and do 1 + if(shift > 0){ + data -= shift; + size += shift; + if(threadIdx.x >= shift){ + threadVal = r(threadVal, data[offset]); + } + size -= blockDim.x; + data += blockDim.x; + } + int last = size % (ILP * blockDim.x); + + T v[ILP]; + LoadT* value = reinterpret_cast(&v); + + for (; offset * ILP < (size - last); offset += blockDim.x) { + *value = reinterpret_cast(data)[offset]; + + for (int j = 0; j < ILP; ++j) { + threadVal = r(threadVal, v[j]); + } + } + + offset = size - last + threadIdx.x; + // Epilogue + for (; offset < size; offset += blockDim.x) + threadVal = r(threadVal, data[offset]); + + return threadVal; +} + +template class Reduction1, template class Reduction2, int ILP, typename T, typename AccumT> +__device__ __forceinline__ void +ilpReduce(int shift, + T* data, + int size, + AccumT* reducVal1, + const Reduction1& r1, + AccumT defaultVal1, + AccumT* reducVal2, + const Reduction2& r2, + AccumT defaultVal2) +{ + typedef typename std::aligned_storage::type LoadT; + + AccumT threadVal1 = defaultVal1; + AccumT threadVal2 = defaultVal2; + int offset = threadIdx.x; + + // shift and do 1 + if(shift > 0){ + data -= shift; + size += shift; + if(threadIdx.x >= shift){ + threadVal1 = r1(threadVal1, data[offset]); + threadVal2 = r2(threadVal2, data[offset]); + } + size -= blockDim.x; + data += blockDim.x; + } + int last = size % (ILP * blockDim.x); + + T v[ILP]; + LoadT* value = reinterpret_cast(&v); + + for (; offset * ILP < (size - last); offset += blockDim.x) { + *value = reinterpret_cast(data)[offset]; + + for (int j = 0; j < ILP; ++j) { + threadVal1 = r1(threadVal1, v[j]); + threadVal2 = r2(threadVal2, v[j]); + } + } + + offset = size - last + threadIdx.x; + // Epilogue + for (; offset < size; offset += blockDim.x) { + threadVal1 = r1(threadVal1, data[offset]); + threadVal2 = r2(threadVal2, data[offset]); + } + + *reducVal1 = threadVal1; + *reducVal2 = threadVal2; +} + +template class Epilogue> +__global__ void +cunn_SoftMaxXEntropyForward( + accscalar_t *losses, + outscalar_t *max_log_sum_exp, + scalar_t *input, + int64_t *labels, + int64_t classes, + const float smoothing) +{ + extern __shared__ unsigned char smem[]; + auto sdata = reinterpret_cast(smem); + // forward pointers to batch[blockIdx.x] + // each block handles a sample in the mini-batch + input += blockIdx.x * classes; + //output += blockIdx.x * classes; + const int shift = ((uint64_t)input) % ALIGN_BYTES / sizeof(scalar_t); + + int64_t label = labels[blockIdx.x]; + + // find the max and sum + accscalar_t threadMax, threadSum, max_k, sum_k; + ilpReduce( + shift, input, classes, + &threadMax, MaxFloat(), + -at::numeric_limits::max(), + &threadSum, AddFloat(), + static_cast(0)); + + blockReduce( + sdata, + &max_k, threadMax, Max(), + -at::numeric_limits::max(), + &sum_k, threadSum, Add(), + static_cast(0)); + + accscalar_t threadExp = ilpReduce(shift, input, classes, SumExpFloat(max_k), static_cast(0)); + accscalar_t sumAll = blockReduce( + sdata, threadExp, Add(), static_cast(0)); + + Epilogue epilogue(max_k, sumAll); + + // calculate per element loss with label smoothing + // reserve max + log_sum_exp for bprop + if (threadIdx.x == 0) { + accscalar_t lse = max_k + std::log(sumAll); + if ((label >= 0) && (label < classes)) { + accscalar_t log_prob = epilogue(static_cast(input[label])); + losses[blockIdx.x] = (lse - sum_k / classes) * smoothing - 
log_prob * (1 - smoothing); + } else { + losses[blockIdx.x] = outscalar_t(0.f); + } + max_log_sum_exp[blockIdx.x] = lse; + } +} + +template +__device__ __forceinline__ void +apply(scalar_t *gradInput, + scalar_t *logits, + outscalar_t *max_log_sum_exp, + outscalar_t *gradOutput, + int64_t *labels, + const float smoothing, + int classes) +{ + accscalar_t smooth_positives = 1.0 - smoothing; + accscalar_t smooth_negatives = smoothing / classes; + accscalar_t tmpGradOutput = gradOutput[blockIdx.x]; + int64_t label = labels[blockIdx.x]; + accscalar_t coeff = max_log_sum_exp[blockIdx.x]; + + int offset = threadIdx.x; + int last = classes % (ILP * blockDim.x); + + for (; offset < classes - last; offset += blockDim.x * ILP) { + accscalar_t tmpLogits[ILP]; + +#pragma unroll + for (int j = 0; j < ILP; ++j) { + tmpLogits[j] = static_cast(logits[offset + j * blockDim.x]); + } + +#pragma unroll + for (int j = 0; j < ILP; ++j) + gradInput[offset + j * blockDim.x] = tmpGradOutput * ( + std::exp(tmpLogits[j] - coeff) - static_cast( + (offset + j * blockDim.x == label) ? 1 : 0) * + smooth_positives - smooth_negatives); + } + + for (; offset < classes; offset += blockDim.x) + gradInput[offset] = tmpGradOutput * (std::exp( + static_cast(logits[offset]) - coeff) - + static_cast((offset == label) ? 1 : 0) * + smooth_positives - smooth_negatives); +} + + +template +__device__ __forceinline__ void +aligned_apply(int shift, + scalar_t *gradInput, + scalar_t *logits, + outscalar_t *max_log_sum_exp, + outscalar_t *gradOutput, + int64_t *labels, + const float smoothing, + int classes) +{ + accscalar_t smooth_positives = 1.0 - smoothing; + accscalar_t smooth_negatives = smoothing / classes; + accscalar_t tmpGradOutput = gradOutput[blockIdx.x]; + int64_t label = labels[blockIdx.x]; + accscalar_t coeff = max_log_sum_exp[blockIdx.x]; + + int offset = threadIdx.x; + + // shift and do 1 + if(shift > 0){ + logits -= shift; + gradInput -= shift; + classes += shift; + if(threadIdx.x >= shift){ + gradInput[offset] = tmpGradOutput * (std::exp( + static_cast(logits[offset]) - coeff) - + static_cast(((offset - shift) == label) ? 1 : 0) * + smooth_positives - smooth_negatives); + } + classes -= blockDim.x; + gradInput += blockDim.x; + logits += blockDim.x; + shift -= blockDim.x; + } + + int last = classes % (ILP * blockDim.x); + + typedef typename std::aligned_storage::type LoadT; + // input + scalar_t v[ILP]; + LoadT* value = reinterpret_cast(&v); + // output + scalar_t r[ILP]; + LoadT* result = reinterpret_cast(&r); + + for (; offset * ILP < (classes - last); offset += blockDim.x) { + *value = reinterpret_cast(logits)[offset]; + +#pragma unroll + for (int j = 0; j < ILP; ++j) { + r[j] = tmpGradOutput * (std::exp( + static_cast(v[j]) - coeff) - + static_cast(((ILP * offset + j - shift) == label) ? 1 : 0) * + smooth_positives - smooth_negatives); + } + reinterpret_cast(gradInput)[offset] = *result; + } + + offset = classes - last + threadIdx.x; + for (; offset < classes; offset += blockDim.x) + gradInput[offset] = tmpGradOutput * (std::exp( + static_cast(logits[offset]) - coeff) - + static_cast(((offset - shift) == label) ? 
1 : 0) * + smooth_positives - smooth_negatives); + +} + +template class Epilogue> +__global__ void +cunn_SoftMaxXEntropyBackward( + scalar_t *gradInput, + scalar_t *logits, + outscalar_t *max_log_sum_exp, + outscalar_t *gradOutput, + int64_t *labels, + const float smoothing, + int classes) +{ + gradInput += blockIdx.x * classes; + logits += blockIdx.x * classes; + + // Do vectorized load/store when input/output have same alignment + const int shift = ((uint64_t)logits) % ALIGN_BYTES / sizeof(scalar_t); + const int shift_ = ((uint64_t)gradInput) % ALIGN_BYTES / sizeof(scalar_t); + if (shift == shift_){ + aligned_apply(shift, gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes); + } + else { + apply(gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes); + } + +} + +template class Epilogue> +std::vector host_softmax_xentropy( + const Tensor & input_, + const Tensor & labels_, + const float smoothing){ + AT_ASSERTM(labels_.scalar_type() == ScalarType::Long,"Label type should be CUDA Long"); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)input_.get_device()}; + + auto input = input_.contiguous(); + Tensor max_log_sum_exp = at::empty_like(labels_, input.options().dtype(ScalarType::Float)); + Tensor losses = at::empty_like(labels_, input_.options().dtype(ScalarType::Float)); + + static_assert(std::is_same, float>::value || + std::is_same, double>::value, + "accscalar_t for half should be float or double"); + AT_ASSERTM(input.dim() == 2, "Currently only 2 dim input supported"); + AT_ASSERTM(labels_.dim() == 1, "Labels should be 1 dimensional"); + AT_ASSERTM(input.size(0) == labels_.size(0), "Input and label should have same number of examples"); + AT_ASSERTM(input.numel() > 0, "Number of classes in input should not be 0"); + + const int64_t dim = 1; + int64_t outer_size = 1; + int64_t dim_size = input.size(dim); + int64_t inner_size = 1; + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + for (int64_t i = 0; i < dim; ++i) + outer_size *= input.size(i); + for (int64_t i = dim + 1; i < input.dim(); ++i) + inner_size *= input.size(i); + // This kernel spawns a block per each element in the batch. + // XXX: it assumes that inner_size == 1 + TORCH_CHECK(inner_size == 1, "Currently only inner size 1 supported"); + + dim3 grid(outer_size); + + using namespace at; + DISPATCH_FLOAT_AND_HALF_AND_BF16(input.scalar_type(), 0, "host_softmax_xentropy", + using accscalar_t = at::acc_type; + const int ILP = sizeof(float4)/sizeof(scalar_t_0); + dim3 block = SoftMax_getBlockSize(ILP, dim_size); + cunn_SoftMaxXEntropyForward + <<>>( + losses.data_ptr(), max_log_sum_exp.data_ptr(), + input.data_ptr(), labels_.data_ptr(), + dim_size, smoothing + ); + ); + + C10_CUDA_CHECK(cudaGetLastError()); + + std::vector ret = {losses, max_log_sum_exp}; + return ret; +} + +template class Epilogue> +Tensor host_softmax_xentropy_backward( + const at::Tensor &grad_loss, + at::Tensor &logits_, + const at::Tensor &max_log_sum_exp, + const at::Tensor &labels, + const float smoothing, + bool inplace) { + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)grad_loss.get_device()}; + + const int64_t dim = 1; + Tensor gI = inplace ? 
logits_ : at::empty_like(logits_); + if (grad_loss.numel() == 0) { + return gI; + } + + auto grad = grad_loss.contiguous(); + auto logits = logits_.contiguous(); + + static_assert(std::is_same, float>::value || + std::is_same, double>::value, + "accscalar_t for half should be float or double"); + if (grad.dim() == 0) grad = grad.view(1); + + AT_ASSERTM(logits_.dim() == 2, "Currently only 2 dim input supported"); + AT_ASSERTM(labels.dim() == 1, "Labels should be 1 dimensional"); + AT_ASSERTM(logits_.numel() > 0, "Number of classes in input should not be 0"); + AT_ASSERTM(logits_.size(0) == labels.size(0), "Input and label should have same number of examples"); + AT_ASSERTM(labels.size(0) == grad.size(0), "Label and loss should have same number of examples"); + + int64_t outer_size = 1; + int64_t dim_size = logits.size(dim); + int64_t inner_size = 1; + for (int64_t i = 0; i < dim; ++i) + outer_size *= logits.size(i); + for (int64_t i = dim + 1; i < logits.dim(); ++i) + inner_size *= logits.size(i); + // See descriptions of kernels above. + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + TORCH_CHECK(inner_size == 1, "Currently only inner size 1 supported"); + + dim3 grid(outer_size); + + DISPATCH_FLOAT_AND_HALF_AND_BF16(gI.scalar_type(), 0, "host_softmax_xentropy_backward", + using accscalar_t = acc_type; + const int ILP = sizeof(float4)/sizeof(scalar_t_0); + dim3 block = SoftMax_getBlockSize(ILP, dim_size); + cunn_SoftMaxXEntropyBackward + <<>>( + gI.data_ptr(), logits.data_ptr(), + max_log_sum_exp.data_ptr(), + grad.data_ptr(), labels.data_ptr(), + smoothing, dim_size + ); + ); + + C10_CUDA_CHECK(cudaGetLastError()); + return gI; +} + +std::vector softmax_xentropy_cuda(const Tensor &input, const Tensor &labels, const float smoothing){ + return host_softmax_xentropy(input, labels, smoothing); +} + +at::Tensor softmax_xentropy_backward_cuda( + const at::Tensor &grad_loss, + at::Tensor &logits, + const at::Tensor &max_log_sum_exp, + const at::Tensor &labels, + const float smoothing, + const bool inplace) { + AT_ASSERTM((grad_loss.scalar_type() == ScalarType::Float), "expected grad types to be at::Float"); + return host_softmax_xentropy_backward(grad_loss, logits, max_log_sum_exp, labels, smoothing, inplace); +} diff --git a/flash_attn/losses/cross_entropy_apex.py b/flash_attn/losses/cross_entropy_apex.py new file mode 100644 index 0000000..ef70946 --- /dev/null +++ b/flash_attn/losses/cross_entropy_apex.py @@ -0,0 +1,51 @@ +import torch +import torch.nn as nn + +import xentropy_cuda_lib + + +# https://github.com/NVIDIA/apex/blob/master/apex/contrib/xentropy/softmax_xentropy.py +class SoftmaxCrossEntropyLossFn(torch.autograd.Function): + @staticmethod + def forward(ctx, logits, labels, smoothing=0.0, padding_idx=0, inplace_backward=False): + losses, max_log_sum_exp = xentropy_cuda_lib.forward( + logits, labels, smoothing) + losses.masked_fill_(labels==padding_idx, 0) + ctx.save_for_backward(logits, max_log_sum_exp, labels) + ctx.smoothing = smoothing + ctx.padding_idx = padding_idx + ctx.inplace_backward = inplace_backward + return losses + + @staticmethod + def backward(ctx, grad_loss): + logits, max_log_sum_exp, labels = ctx.saved_tensors + if not grad_loss.is_contiguous(): + grad_loss = grad_loss.contiguous() + grad_loss.masked_fill_(labels==ctx.padding_idx, 0) + grad_logits = xentropy_cuda_lib.backward(grad_loss, logits, max_log_sum_exp, labels, + ctx.smoothing, ctx.inplace_backward) + return grad_logits, None, None, None, None + + +class CrossEntropyLossApex(nn.Module): + + def 
__init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0, + inplace_backward=False): + super().__init__() + if reduction not in ['mean', 'none']: + raise NotImplementedError("Only support reduction = 'mean' or 'none'") + self.ignore_index = ignore_index + self.reduction = reduction + self.label_smoothing = label_smoothing + self.inplace_backward = inplace_backward + + def forward(self, input, target): + assert input.is_cuda and target.is_cuda + # SoftmaxCrossEntropyLoss implicitly casts to float + loss = SoftmaxCrossEntropyLossFn.apply(input, target, self.label_smoothing, + self.ignore_index, self.inplace_backward) + if self.reduction == 'mean': + return loss.sum() / (target != self.ignore_index).sum() + else: + return loss diff --git a/flash_attn/losses/cross_entropy_parallel.py b/flash_attn/losses/cross_entropy_parallel.py new file mode 100644 index 0000000..84fe82d --- /dev/null +++ b/flash_attn/losses/cross_entropy_parallel.py @@ -0,0 +1,112 @@ +# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py +# But we make it much faster: we compute the local loss and the LSE, and by exchanging the LSE and +# the losses we can get the global loss. There's no need to do it step by step +# (compute local max, exchange, compute exp, compute local sum, exchange, etc.) +import torch +import torch.nn as nn + +import xentropy_cuda_lib + +from apex.transformer.parallel_state import get_tensor_model_parallel_group +from apex.transformer.parallel_state import get_tensor_model_parallel_rank +from apex.transformer.parallel_state import get_tensor_model_parallel_world_size +from apex.transformer.tensor_parallel.utils import VocabUtility + +# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for +# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent +# version of PyTorch. The following 4 lines are for backward comparability with +# older PyTorch. +if "all_gather_into_tensor" not in dir(torch.distributed): + torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base +if "reduce_scatter_tensor" not in dir(torch.distributed): + torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base + + +class SoftmaxCrossEntropyLossParallelFn(torch.autograd.Function): + + @staticmethod + def forward(ctx, logits_parallel, labels, smoothing=0.0, ignored_index=-100, + inplace_backward=False): + """ + logits_parallel: (batch, vocab_size / world_size) + labels: (batch,) + """ + assert smoothing == 0.0, 'smoothing != 0.0 is not yet implemented, file an issue if you need it' + batch, partition_vocab_size = logits_parallel.shape + assert labels.shape == (batch,) + rank = get_tensor_model_parallel_rank() + world_size = get_tensor_model_parallel_world_size() + vocab_start_index, vocab_end_index = VocabUtility.vocab_range_from_per_partition_vocab_size( + partition_vocab_size, get_tensor_model_parallel_rank(), + get_tensor_model_parallel_world_size() + ) + + # Create a mask of valid vocab ids (1 means it needs to be masked). 
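+        # Illustrative example (numbers taken from the unit test later in this diff, the
+        # arithmetic is ours): with vocab_size = 50264 and world_size = 8, each rank owns
+        # 6283 ids, so rank 3 owns [18849, 25132). A label of 20000 becomes local id 1151
+        # on rank 3 and is replaced by ignored_index on every other rank, which therefore
+        # contributes only its gathered log-sum-exp, not a predicted-logit term.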
+ labels_mask = (labels < vocab_start_index) | (labels >= vocab_end_index) + ignored_mask = labels == ignored_index + labels_local = torch.where(ignored_mask, labels, labels - vocab_start_index) + masked_labels = labels_local.clone() + masked_labels[labels_mask] = ignored_index + + losses, lse_local = xentropy_cuda_lib.forward(logits_parallel, masked_labels, smoothing) + assert lse_local.shape == (batch,) + assert losses.shape == (batch,) + losses.masked_fill_(masked_labels==ignored_index, 0) + + if world_size > 1: + lse_allgather = torch.empty(world_size, batch, dtype=lse_local.dtype, + device=lse_local.device) + torch.distributed.all_gather_into_tensor(lse_allgather, lse_local.contiguous(), + group=get_tensor_model_parallel_group()) + lse = torch.logsumexp(lse_allgather, dim=0) + torch.distributed.all_reduce(losses, op=torch.distributed.ReduceOp.SUM, + group=get_tensor_model_parallel_group()) + # The losses are currently lse_local - predicted_logit, we just have to subtract the + # lse_local and add the lse (global). + rank_per_sample = labels // partition_vocab_size + lse_local = lse_allgather[rank_per_sample, + torch.arange(batch, device=lse_allgather.device)] + losses += lse - lse_local + losses.masked_fill_(ignored_mask, 0) + else: + lse = lse_local + + ctx.save_for_backward(logits_parallel, lse, labels_local) + ctx.smoothing = smoothing + ctx.ignored_index = ignored_index + ctx.inplace_backward = inplace_backward + return losses + + @staticmethod + def backward(ctx, grad_loss): + logits_parallel, lse, labels = ctx.saved_tensors + if not grad_loss.is_contiguous(): + grad_loss = grad_loss.contiguous() + grad_loss.masked_fill_(labels==ctx.ignored_index, 0) + grad_logits = xentropy_cuda_lib.backward(grad_loss, logits_parallel, lse, labels, + ctx.smoothing, ctx.inplace_backward) + return grad_logits, None, None, None, None, None + + +class CrossEntropyLossParallel(nn.Module): + + def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0, + inplace_backward=False): + super().__init__() + if reduction not in ['mean', 'none']: + raise NotImplementedError("Only support reduction = 'mean' or 'none'") + self.ignore_index = ignore_index + self.reduction = reduction + self.label_smoothing = label_smoothing + self.inplace_backward = inplace_backward + + def forward(self, input, target): + assert input.is_cuda and target.is_cuda + # SoftmaxCrossEntropyLoss implicitly casts to float + loss = SoftmaxCrossEntropyLossParallelFn.apply( + input, target, self.label_smoothing, self.ignore_index, self.inplace_backward + ) + if self.reduction == 'mean': + return loss.sum() / (target != self.ignore_index).sum() + else: + return loss diff --git a/tests/losses/test_cross_entropy_apex.py b/tests/losses/test_cross_entropy_apex.py new file mode 100644 index 0000000..e5e170f --- /dev/null +++ b/tests/losses/test_cross_entropy_apex.py @@ -0,0 +1,39 @@ +import math + +import torch +import torch.nn.functional as F +import pytest + +from einops import rearrange + +from src.losses.cross_entropy_apex import CrossEntropyLossApex + +is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8 + + +@pytest.mark.parametrize('dtype', [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else [])) +# @pytest.mark.parametrize('dtype', [torch.float16]) +@pytest.mark.parametrize('inplace_backward', [False, True]) +# @pytest.mark.parametrize('inplace_backward', [False]) +@pytest.mark.parametrize('vocab_size', [50257]) +def test_cross_entropy_loss_apex(vocab_size, inplace_backward, dtype): + device = 'cuda' 
+    rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 8
+    seqlen = 128
+    x_pt = torch.randn(batch_size * seqlen, vocab_size, device=device, dtype=dtype, requires_grad=True)
+    x = x_pt.detach().clone().requires_grad_()
+    y = torch.randint(0, vocab_size, (batch_size * seqlen,), dtype=torch.long, device=device)
+    y[torch.randperm(batch_size * seqlen)[:10]] = -100
+    model_pt = torch.nn.CrossEntropyLoss()
+    model = CrossEntropyLossApex(inplace_backward=inplace_backward)
+    out = model(x, y)
+    out_pt = model_pt(x_pt.float(), y)
+    assert torch.allclose(out, out_pt, rtol=rtol, atol=atol)
+
+    g = torch.randn_like(out)
+    out_pt.backward(g)
+    out.backward(g)
+    assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
diff --git a/tests/losses/test_cross_entropy_parallel.py b/tests/losses/test_cross_entropy_parallel.py
new file mode 100644
index 0000000..71fe89f
--- /dev/null
+++ b/tests/losses/test_cross_entropy_parallel.py
@@ -0,0 +1,56 @@
+# Run test with:
+# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/losses/test_cross_entropy_parallel.py
+
+import math
+
+import torch
+import torch.nn.functional as F
+import pytest
+
+from apex.transformer import parallel_state
+from apex.transformer import tensor_parallel
+
+from src.losses.cross_entropy_parallel import CrossEntropyLossParallel
+
+is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
+
+
+@pytest.mark.parametrize('dtype', [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else []))
+# @pytest.mark.parametrize('dtype', [torch.bfloat16])
+@pytest.mark.parametrize('inplace_backward', [False, True])
+# @pytest.mark.parametrize('inplace_backward', [False])
+@pytest.mark.parametrize('vocab_size', [50264])
+@pytest.mark.parametrize('world_size', [1, 2, 4, 8])
+# @pytest.mark.parametrize('world_size', [2])
+def test_cross_entropy_loss_apex(vocab_size, world_size, inplace_backward, dtype):
+    assert vocab_size % world_size == 0
+    rtol, atol = ((1e-5, 1e-6) if dtype == torch.float32
+                  else ((1e-3, 1e-4) if dtype == torch.float16 else (1e-2, 3e-3)))
+    if not torch.distributed.is_initialized():
+        torch.distributed.init_process_group(backend='nccl', init_method='env://')
+    partition_vocab_size = vocab_size // world_size
+    device = f'cuda:{torch.distributed.get_rank()}'
+    assert world_size <= torch.distributed.get_world_size()
+    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
+    rank = parallel_state.get_tensor_model_parallel_rank()
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 8
+    seqlen = 128
+    x_pt = (torch.randn(batch_size * seqlen, vocab_size, device=device,
+                        dtype=dtype) * 10).requires_grad_()
+    x = tensor_parallel.scatter_to_tensor_model_parallel_region(x_pt).detach().clone().requires_grad_()
+    y = torch.randint(0, vocab_size, (batch_size * seqlen,), dtype=torch.long, device=device)
+    y[torch.randperm(batch_size * seqlen)[:10]] = -100
+    model_pt = torch.nn.CrossEntropyLoss(reduction='none')
+    model = CrossEntropyLossParallel(reduction='none', inplace_backward=inplace_backward)
+    out = model(x, y)
+    out_pt = model_pt(x_pt.float(), y)
+    assert torch.allclose(out, out_pt, rtol=1e-5, atol=1e-6)
+
+    g = torch.randn_like(out)
+    out_pt.backward(g)
+    out.backward(g)
+    assert torch.allclose(x.grad, x_pt.grad[:, (rank * partition_vocab_size):(rank + 1) * partition_vocab_size], rtol=rtol, atol=atol)
+
+    parallel_state.destroy_model_parallel()
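
Usage sketch (not part of the diff above, a hedged illustration only): it assumes the extension has been built from csrc/xentropy, for example by running pip install . in that directory, and that the flash_attn.losses.cross_entropy_apex module added by this PR is on the import path. Shapes mirror the unit test; everything unrelated to the loss call is omitted.

    import torch
    from flash_attn.losses.cross_entropy_apex import CrossEntropyLossApex

    # (batch * seqlen, vocab_size) logits and (batch * seqlen,) integer labels, both on the GPU
    logits = torch.randn(8 * 128, 50257, device='cuda', dtype=torch.float16, requires_grad=True)
    labels = torch.randint(0, 50257, (8 * 128,), device='cuda', dtype=torch.long)

    # inplace_backward=True reuses the logits buffer for the gradient to save memory
    loss_fn = CrossEntropyLossApex(label_smoothing=0.0, inplace_backward=True)
    loss = loss_fn(logits, labels)  # scalar: mean over non-ignored positions
    loss.backward()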