Add fused cross entropy loss
This commit is contained in:
parent
55797f32c9
commit
7c9953815a
51
csrc/xentropy/interface.cpp
Normal file
51
csrc/xentropy/interface.cpp
Normal file
@ -0,0 +1,51 @@
|
||||
#include <torch/extension.h>
|
||||
|
||||
// CUDA forward declarations
|
||||
std::vector<at::Tensor> softmax_xentropy_cuda(
|
||||
const at::Tensor &input,
|
||||
const at::Tensor &labels,
|
||||
const float smoothing);
|
||||
|
||||
at::Tensor softmax_xentropy_backward_cuda(
|
||||
const at::Tensor &grad_loss,
|
||||
at::Tensor &logits,
|
||||
const at::Tensor &max_log_sum_exp,
|
||||
const at::Tensor &labels,
|
||||
const float smoothing,
|
||||
const bool inplace);
|
||||
|
||||
// C++ interface
|
||||
|
||||
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
||||
|
||||
std::vector<at::Tensor> softmax_xentropy_forward(
|
||||
const at::Tensor &input,
|
||||
const at::Tensor &labels,
|
||||
const float smoothing) {
|
||||
CHECK_CUDA(input);
|
||||
CHECK_INPUT(labels);
|
||||
|
||||
return softmax_xentropy_cuda(input, labels, smoothing);
|
||||
}
|
||||
|
||||
at::Tensor softmax_xentropy_backward(
|
||||
const at::Tensor &grad_loss,
|
||||
at::Tensor &logits,
|
||||
const at::Tensor &max_log_sum_exp,
|
||||
const at::Tensor &labels,
|
||||
const float smoothing,
|
||||
const bool inplace) {
|
||||
CHECK_CUDA(grad_loss);
|
||||
CHECK_CUDA(logits);
|
||||
CHECK_INPUT(max_log_sum_exp);
|
||||
CHECK_INPUT(labels);
|
||||
|
||||
return softmax_xentropy_backward_cuda(grad_loss, logits, max_log_sum_exp, labels, smoothing, inplace);
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("forward", &softmax_xentropy_forward, "Softmax cross entropy loss with label smoothing forward (CUDA)");
|
||||
m.def("backward", &softmax_xentropy_backward, "Softmax cross entropy loss with label smoothing backward (CUDA)");
|
||||
}
|
||||
131
csrc/xentropy/setup.py
Normal file
131
csrc/xentropy/setup.py
Normal file
@ -0,0 +1,131 @@
|
||||
# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py
|
||||
import torch
|
||||
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
|
||||
from setuptools import setup, find_packages
|
||||
import subprocess
|
||||
|
||||
import sys
|
||||
import warnings
|
||||
import os
|
||||
|
||||
# ninja build does not work unless include_dirs are abs path
|
||||
this_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
|
||||
def get_cuda_bare_metal_version(cuda_dir):
|
||||
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
|
||||
output = raw_output.split()
|
||||
release_idx = output.index("release") + 1
|
||||
release = output[release_idx].split(".")
|
||||
bare_metal_major = release[0]
|
||||
bare_metal_minor = release[1][0]
|
||||
|
||||
return raw_output, bare_metal_major, bare_metal_minor
|
||||
|
||||
|
||||
def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
|
||||
raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
|
||||
torch_binary_major = torch.version.cuda.split(".")[0]
|
||||
torch_binary_minor = torch.version.cuda.split(".")[1]
|
||||
|
||||
print("\nCompiling cuda extensions with")
|
||||
print(raw_output + "from " + cuda_dir + "/bin\n")
|
||||
|
||||
if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor):
|
||||
raise RuntimeError(
|
||||
"Cuda extensions are being compiled with a version of Cuda that does "
|
||||
"not match the version used to compile Pytorch binaries. "
|
||||
"Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda)
|
||||
+ "In some cases, a minor-version mismatch will not cause later errors: "
|
||||
"https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. "
|
||||
"You can try commenting out this check (at your own risk)."
|
||||
)
|
||||
|
||||
|
||||
def raise_if_cuda_home_none(global_option: str) -> None:
|
||||
if CUDA_HOME is not None:
|
||||
return
|
||||
raise RuntimeError(
|
||||
f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
|
||||
"If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
|
||||
"only images whose names contain 'devel' will provide nvcc."
|
||||
)
|
||||
|
||||
|
||||
def append_nvcc_threads(nvcc_extra_args):
|
||||
_, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
|
||||
if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
|
||||
return nvcc_extra_args + ["--threads", "4"]
|
||||
return nvcc_extra_args
|
||||
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
# https://github.com/NVIDIA/apex/issues/486
|
||||
# Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(),
|
||||
# which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command).
|
||||
print(
|
||||
"\nWarning: Torch did not find available GPUs on this system.\n",
|
||||
"If your intention is to cross-compile, this is not an error.\n"
|
||||
"By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n"
|
||||
"Volta (compute capability 7.0), Turing (compute capability 7.5),\n"
|
||||
"and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n"
|
||||
"If you wish to cross-compile for a single specific architecture,\n"
|
||||
'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n',
|
||||
)
|
||||
if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
|
||||
_, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
|
||||
if int(bare_metal_major) == 11:
|
||||
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
|
||||
if int(bare_metal_minor) > 0:
|
||||
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
|
||||
else:
|
||||
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"
|
||||
|
||||
print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
|
||||
TORCH_MAJOR = int(torch.__version__.split(".")[0])
|
||||
TORCH_MINOR = int(torch.__version__.split(".")[1])
|
||||
|
||||
cmdclass = {}
|
||||
ext_modules = []
|
||||
|
||||
# Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h
|
||||
# See https://github.com/pytorch/pytorch/pull/70650
|
||||
generator_flag = []
|
||||
torch_dir = torch.__path__[0]
|
||||
if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
|
||||
generator_flag = ["-DOLD_GENERATOR_PATH"]
|
||||
|
||||
raise_if_cuda_home_none("--xentropy")
|
||||
# Check, if CUDA11 is installed for compute capability 8.0
|
||||
cc_flag = []
|
||||
cc_flag.append("-gencode")
|
||||
cc_flag.append("arch=compute_70,code=sm_70")
|
||||
cc_flag.append("-gencode")
|
||||
cc_flag.append("arch=compute_80,code=sm_80")
|
||||
|
||||
ext_modules.append(
|
||||
CUDAExtension(
|
||||
name="xentropy_cuda_lib",
|
||||
sources=[
|
||||
"interface.cpp",
|
||||
"xentropy_kernel.cu"
|
||||
],
|
||||
extra_compile_args={
|
||||
"cxx": ["-O3"] + generator_flag,
|
||||
"nvcc": append_nvcc_threads(
|
||||
["-O3"]
|
||||
+ generator_flag
|
||||
+ cc_flag
|
||||
),
|
||||
},
|
||||
include_dirs=[this_dir],
|
||||
)
|
||||
)
|
||||
|
||||
setup(
|
||||
name="xentropy_cuda_lib",
|
||||
version="0.1",
|
||||
description="Cross-entropy loss",
|
||||
ext_modules=ext_modules,
|
||||
cmdclass={"build_ext": BuildExtension} if ext_modules else {},
|
||||
)
|
||||
754
csrc/xentropy/xentropy_kernel.cu
Normal file
754
csrc/xentropy/xentropy_kernel.cu
Normal file
@ -0,0 +1,754 @@
|
||||
// Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/csrc/xentropy/xentropy_kernel.cu
|
||||
// TD [2022-09-17]: We make it work for bfloat16, and add an option to do the backward inplace (to save memory).
|
||||
/**
|
||||
* From PyTorch:
|
||||
*
|
||||
* Copyright (c) 2016- Facebook, Inc (Adam Paszke)
|
||||
* Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
|
||||
* Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
|
||||
* Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
|
||||
* Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
|
||||
* Copyright (c) 2011-2013 NYU (Clement Farabet)
|
||||
* Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
|
||||
* Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
|
||||
* Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
|
||||
*
|
||||
* From Caffe2:
|
||||
*
|
||||
* Copyright (c) 2016-present, Facebook Inc. All rights reserved.
|
||||
*
|
||||
* All contributions by Facebook:
|
||||
* Copyright (c) 2016 Facebook Inc.
|
||||
*
|
||||
* All contributions by Google:
|
||||
* Copyright (c) 2015 Google Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* All contributions by Yangqing Jia:
|
||||
* Copyright (c) 2015 Yangqing Jia
|
||||
* All rights reserved.
|
||||
*
|
||||
* All contributions from Caffe:
|
||||
* Copyright(c) 2013, 2014, 2015, the respective contributors
|
||||
* All rights reserved.
|
||||
*
|
||||
* All other contributions:
|
||||
* Copyright(c) 2015, 2016 the respective contributors
|
||||
* All rights reserved.
|
||||
*
|
||||
* Caffe2 uses a copyright model similar to Caffe: each contributor holds
|
||||
* copyright over their contributions to Caffe2. The project versioning records
|
||||
* all such contribution and copyright details. If a contributor wants to further
|
||||
* mark their specific copyright on a particular contribution, they should
|
||||
* indicate their copyright solely in the commit message of the change when it is
|
||||
* committed.
|
||||
*
|
||||
* All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
|
||||
* and IDIAP Research Institute nor the names of its contributors may be
|
||||
* used to endorse or promote products derived from this software without
|
||||
* specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
|
||||
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
* POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include <ATen/AccumulateType.h>
|
||||
#include <ATen/cuda/NumericLimits.cuh>
|
||||
|
||||
// https://github.com/NVIDIA/apex/blob/master/csrc/type_shim.h
|
||||
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
#define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, LEVEL, NAME, ...) \
|
||||
switch(TYPE) \
|
||||
{ \
|
||||
case at::ScalarType::Float: \
|
||||
{ \
|
||||
using scalar_t_##LEVEL = float; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::Half: \
|
||||
{ \
|
||||
using scalar_t_##LEVEL = at::Half; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::BFloat16: \
|
||||
{ \
|
||||
using scalar_t_##LEVEL = at::BFloat16; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
default: \
|
||||
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
||||
}
|
||||
// #else
|
||||
// #define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, LEVEL, NAME, ...) \
|
||||
// switch(TYPE) \
|
||||
// { \
|
||||
// case at::ScalarType::Float: \
|
||||
// { \
|
||||
// using scalar_t_##LEVEL = float; \
|
||||
// __VA_ARGS__; \
|
||||
// break; \
|
||||
// } \
|
||||
// case at::ScalarType::Half: \
|
||||
// { \
|
||||
// using scalar_t_##LEVEL = at::Half; \
|
||||
// __VA_ARGS__; \
|
||||
// break; \
|
||||
// } \
|
||||
// default: \
|
||||
// AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
||||
// }
|
||||
// #endif
|
||||
|
||||
#define ALIGN_BYTES 16
|
||||
|
||||
using Tensor = at::Tensor;
|
||||
using TensorList = at::TensorList;
|
||||
using ScalarType = at::ScalarType;
|
||||
using at::acc_type;
|
||||
|
||||
template<typename T, typename AccumT, typename OutT>
|
||||
struct LogSoftMaxForwardEpilogue {
|
||||
__device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT max_input, AccumT sum)
|
||||
: logsum(max_input + std::log(sum)) {}
|
||||
|
||||
__device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT max_log_sum_exp)
|
||||
: logsum(max_log_sum_exp) {}
|
||||
|
||||
__device__ __forceinline__ OutT operator()(T input) const {
|
||||
return static_cast<OutT>(input - logsum);
|
||||
}
|
||||
|
||||
const AccumT logsum;
|
||||
};
|
||||
|
||||
template<typename T, typename AccumT, typename OutT>
|
||||
struct LogSoftMaxBackwardEpilogue {
|
||||
__device__ __forceinline__ LogSoftMaxBackwardEpilogue(AccumT sum)
|
||||
: sum(sum) {}
|
||||
|
||||
__device__ __forceinline__ T operator()(OutT gradOutput, OutT output) const {
|
||||
return static_cast<T>(gradOutput - std::exp(static_cast<AccumT>(output)) * sum);
|
||||
}
|
||||
|
||||
const AccumT sum;
|
||||
};
|
||||
|
||||
|
||||
|
||||
const int max_threads = 1024;
|
||||
|
||||
inline dim3 SoftMax_getBlockSize(int ILP, uint64_t dim_size) {
|
||||
uint64_t block_size = 1;
|
||||
uint64_t max_block_size = std::min(dim_size / ILP, static_cast<uint64_t>(max_threads));
|
||||
while (block_size < (max_block_size/2)) block_size *= 2;
|
||||
// Launch at least a single warp - the kernel assumes that.
|
||||
block_size = std::max(block_size, static_cast<uint64_t>(32));
|
||||
return dim3(block_size);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
struct Add {
|
||||
__device__ __forceinline__ T operator()(T a, T b) const {
|
||||
return a + b;
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
struct Max {
|
||||
__device__ __forceinline__ T operator()(T a, T b) const {
|
||||
return a < b ? b : a;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Regular kernel (fast when dim_size is large; requires inner_size == 1)
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
template <typename T, typename AccumT>
|
||||
struct MaxFloat
|
||||
{
|
||||
__device__ __forceinline__ AccumT operator()(AccumT max, T v) const {
|
||||
return ::max(max, (AccumT)v);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T, typename AccumT>
|
||||
struct AddFloat
|
||||
{
|
||||
__device__ __forceinline__ AccumT operator()(AccumT sum, T v) const {
|
||||
return sum + v;
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T, typename AccumT>
|
||||
struct SumExpFloat
|
||||
{
|
||||
__device__ __forceinline__ SumExpFloat(AccumT v)
|
||||
: max_k(v) {}
|
||||
|
||||
__device__ __forceinline__ AccumT operator()(AccumT sum, T v) const {
|
||||
return sum + std::exp(v - max_k);
|
||||
}
|
||||
|
||||
const AccumT max_k;
|
||||
};
|
||||
|
||||
template <template<typename> class Reduction, typename AccumT>
|
||||
__device__ __forceinline__ AccumT
|
||||
blockReduce(AccumT* smem, AccumT val,
|
||||
const Reduction<AccumT>& r,
|
||||
AccumT defaultVal)
|
||||
{
|
||||
// To avoid RaW races from chaining blockReduce calls together, we need a sync here
|
||||
__syncthreads();
|
||||
|
||||
smem[threadIdx.x] = val;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
AccumT warpVal = defaultVal;
|
||||
|
||||
// First warp will perform per-warp reductions for the remaining warps
|
||||
uint32_t mask = (((uint64_t)1) << (blockDim.x / 32)) - 1;
|
||||
if (threadIdx.x < 32) {
|
||||
int lane = threadIdx.x % 32;
|
||||
if (lane < blockDim.x / 32) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 32; ++i) {
|
||||
warpVal = r(warpVal, smem[lane * 32 + i]);
|
||||
}
|
||||
__syncwarp(mask);
|
||||
smem[lane] = warpVal;
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// First thread will perform a reduction of the above per-warp reductions
|
||||
AccumT blockVal = defaultVal;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
for (int i = 0; i < blockDim.x / 32; ++i) {
|
||||
blockVal = r(blockVal, smem[i]);
|
||||
}
|
||||
smem[0] = blockVal;
|
||||
}
|
||||
|
||||
// Sync and broadcast
|
||||
__syncthreads();
|
||||
return smem[0];
|
||||
}
|
||||
|
||||
template <template<typename> class Reduction1, template<typename> class Reduction2, typename AccumT>
|
||||
__device__ __forceinline__ void
|
||||
blockReduce(AccumT* smem,
|
||||
AccumT* reducVal1,
|
||||
AccumT val1,
|
||||
const Reduction1<AccumT>& r1,
|
||||
AccumT defaultVal1,
|
||||
AccumT* reducVal2,
|
||||
AccumT val2,
|
||||
const Reduction2<AccumT>& r2,
|
||||
AccumT defaultVal2)
|
||||
{
|
||||
// To avoid RaW races from chaining blockReduce calls together, we need a sync here
|
||||
__syncthreads();
|
||||
|
||||
smem[threadIdx.x] = val1;
|
||||
smem[blockDim.x + threadIdx.x] = val2;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
AccumT warpVal1 = defaultVal1;
|
||||
AccumT warpVal2 = defaultVal2;
|
||||
|
||||
// First warp will perform per-warp reductions for the remaining warps
|
||||
uint32_t mask = (((uint64_t)1) << (blockDim.x / 32)) - 1;
|
||||
if (threadIdx.x < 32) {
|
||||
int lane = threadIdx.x % 32;
|
||||
if (lane < blockDim.x / 32) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 32; ++i) {
|
||||
warpVal1 = r1(warpVal1, smem[lane * 32 + i]);
|
||||
warpVal2 = r2(warpVal2, smem[lane * 32 + i + blockDim.x]);
|
||||
}
|
||||
__syncwarp(mask);
|
||||
smem[lane] = warpVal1;
|
||||
smem[lane + blockDim.x] = warpVal2;
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// First thread will perform a reduction of the above per-warp reductions
|
||||
AccumT blockVal1 = defaultVal1;
|
||||
AccumT blockVal2 = defaultVal2;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
for (int i = 0; i < blockDim.x / 32; ++i) {
|
||||
blockVal1 = r1(blockVal1, smem[i]);
|
||||
blockVal2 = r2(blockVal2, smem[i + blockDim.x]);
|
||||
}
|
||||
smem[0] = blockVal1;
|
||||
smem[blockDim.x] = blockVal2;
|
||||
}
|
||||
|
||||
// Sync and broadcast
|
||||
__syncthreads();
|
||||
*reducVal1 = smem[0];
|
||||
*reducVal2 = smem[blockDim.x];
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
template <template<typename, typename> class Reduction, int ILP, typename T, typename AccumT>
|
||||
__device__ __forceinline__ AccumT
|
||||
ilpReduce(int shift,
|
||||
T* data,
|
||||
int size,
|
||||
const Reduction<T, AccumT>& r,
|
||||
AccumT defaultVal)
|
||||
{
|
||||
typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LoadT;
|
||||
AccumT threadVal = defaultVal;
|
||||
int offset = threadIdx.x;
|
||||
|
||||
// shift and do 1
|
||||
if(shift > 0){
|
||||
data -= shift;
|
||||
size += shift;
|
||||
if(threadIdx.x >= shift){
|
||||
threadVal = r(threadVal, data[offset]);
|
||||
}
|
||||
size -= blockDim.x;
|
||||
data += blockDim.x;
|
||||
}
|
||||
int last = size % (ILP * blockDim.x);
|
||||
|
||||
T v[ILP];
|
||||
LoadT* value = reinterpret_cast<LoadT*>(&v);
|
||||
|
||||
for (; offset * ILP < (size - last); offset += blockDim.x) {
|
||||
*value = reinterpret_cast<LoadT*>(data)[offset];
|
||||
|
||||
for (int j = 0; j < ILP; ++j) {
|
||||
threadVal = r(threadVal, v[j]);
|
||||
}
|
||||
}
|
||||
|
||||
offset = size - last + threadIdx.x;
|
||||
// Epilogue
|
||||
for (; offset < size; offset += blockDim.x)
|
||||
threadVal = r(threadVal, data[offset]);
|
||||
|
||||
return threadVal;
|
||||
}
|
||||
|
||||
template <template<typename, typename> class Reduction1, template<typename, typename> class Reduction2, int ILP, typename T, typename AccumT>
|
||||
__device__ __forceinline__ void
|
||||
ilpReduce(int shift,
|
||||
T* data,
|
||||
int size,
|
||||
AccumT* reducVal1,
|
||||
const Reduction1<T, AccumT>& r1,
|
||||
AccumT defaultVal1,
|
||||
AccumT* reducVal2,
|
||||
const Reduction2<T, AccumT>& r2,
|
||||
AccumT defaultVal2)
|
||||
{
|
||||
typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LoadT;
|
||||
|
||||
AccumT threadVal1 = defaultVal1;
|
||||
AccumT threadVal2 = defaultVal2;
|
||||
int offset = threadIdx.x;
|
||||
|
||||
// shift and do 1
|
||||
if(shift > 0){
|
||||
data -= shift;
|
||||
size += shift;
|
||||
if(threadIdx.x >= shift){
|
||||
threadVal1 = r1(threadVal1, data[offset]);
|
||||
threadVal2 = r2(threadVal2, data[offset]);
|
||||
}
|
||||
size -= blockDim.x;
|
||||
data += blockDim.x;
|
||||
}
|
||||
int last = size % (ILP * blockDim.x);
|
||||
|
||||
T v[ILP];
|
||||
LoadT* value = reinterpret_cast<LoadT*>(&v);
|
||||
|
||||
for (; offset * ILP < (size - last); offset += blockDim.x) {
|
||||
*value = reinterpret_cast<LoadT*>(data)[offset];
|
||||
|
||||
for (int j = 0; j < ILP; ++j) {
|
||||
threadVal1 = r1(threadVal1, v[j]);
|
||||
threadVal2 = r2(threadVal2, v[j]);
|
||||
}
|
||||
}
|
||||
|
||||
offset = size - last + threadIdx.x;
|
||||
// Epilogue
|
||||
for (; offset < size; offset += blockDim.x) {
|
||||
threadVal1 = r1(threadVal1, data[offset]);
|
||||
threadVal2 = r2(threadVal2, data[offset]);
|
||||
}
|
||||
|
||||
*reducVal1 = threadVal1;
|
||||
*reducVal2 = threadVal2;
|
||||
}
|
||||
|
||||
template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t, template <typename, typename, typename> class Epilogue>
|
||||
__global__ void
|
||||
cunn_SoftMaxXEntropyForward(
|
||||
accscalar_t *losses,
|
||||
outscalar_t *max_log_sum_exp,
|
||||
scalar_t *input,
|
||||
int64_t *labels,
|
||||
int64_t classes,
|
||||
const float smoothing)
|
||||
{
|
||||
extern __shared__ unsigned char smem[];
|
||||
auto sdata = reinterpret_cast<accscalar_t*>(smem);
|
||||
// forward pointers to batch[blockIdx.x]
|
||||
// each block handles a sample in the mini-batch
|
||||
input += blockIdx.x * classes;
|
||||
//output += blockIdx.x * classes;
|
||||
const int shift = ((uint64_t)input) % ALIGN_BYTES / sizeof(scalar_t);
|
||||
|
||||
int64_t label = labels[blockIdx.x];
|
||||
|
||||
// find the max and sum
|
||||
accscalar_t threadMax, threadSum, max_k, sum_k;
|
||||
ilpReduce<MaxFloat, AddFloat, ILP, scalar_t, accscalar_t>(
|
||||
shift, input, classes,
|
||||
&threadMax, MaxFloat<scalar_t, accscalar_t>(),
|
||||
-at::numeric_limits<accscalar_t>::max(),
|
||||
&threadSum, AddFloat<scalar_t, accscalar_t>(),
|
||||
static_cast<accscalar_t>(0));
|
||||
|
||||
blockReduce<Max, Add, accscalar_t>(
|
||||
sdata,
|
||||
&max_k, threadMax, Max<accscalar_t>(),
|
||||
-at::numeric_limits<accscalar_t>::max(),
|
||||
&sum_k, threadSum, Add<accscalar_t>(),
|
||||
static_cast<accscalar_t>(0));
|
||||
|
||||
accscalar_t threadExp = ilpReduce<SumExpFloat, ILP, scalar_t, accscalar_t>(shift, input, classes, SumExpFloat<scalar_t, accscalar_t>(max_k), static_cast<accscalar_t>(0));
|
||||
accscalar_t sumAll = blockReduce<Add, accscalar_t>(
|
||||
sdata, threadExp, Add<accscalar_t>(), static_cast<accscalar_t>(0));
|
||||
|
||||
Epilogue<scalar_t, accscalar_t, outscalar_t> epilogue(max_k, sumAll);
|
||||
|
||||
// calculate per element loss with label smoothing
|
||||
// reserve max + log_sum_exp for bprop
|
||||
if (threadIdx.x == 0) {
|
||||
accscalar_t lse = max_k + std::log(sumAll);
|
||||
if ((label >= 0) && (label < classes)) {
|
||||
accscalar_t log_prob = epilogue(static_cast<accscalar_t>(input[label]));
|
||||
losses[blockIdx.x] = (lse - sum_k / classes) * smoothing - log_prob * (1 - smoothing);
|
||||
} else {
|
||||
losses[blockIdx.x] = outscalar_t(0.f);
|
||||
}
|
||||
max_log_sum_exp[blockIdx.x] = lse;
|
||||
}
|
||||
}
|
||||
|
||||
template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t>
|
||||
__device__ __forceinline__ void
|
||||
apply(scalar_t *gradInput,
|
||||
scalar_t *logits,
|
||||
outscalar_t *max_log_sum_exp,
|
||||
outscalar_t *gradOutput,
|
||||
int64_t *labels,
|
||||
const float smoothing,
|
||||
int classes)
|
||||
{
|
||||
accscalar_t smooth_positives = 1.0 - smoothing;
|
||||
accscalar_t smooth_negatives = smoothing / classes;
|
||||
accscalar_t tmpGradOutput = gradOutput[blockIdx.x];
|
||||
int64_t label = labels[blockIdx.x];
|
||||
accscalar_t coeff = max_log_sum_exp[blockIdx.x];
|
||||
|
||||
int offset = threadIdx.x;
|
||||
int last = classes % (ILP * blockDim.x);
|
||||
|
||||
for (; offset < classes - last; offset += blockDim.x * ILP) {
|
||||
accscalar_t tmpLogits[ILP];
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ILP; ++j) {
|
||||
tmpLogits[j] = static_cast<accscalar_t>(logits[offset + j * blockDim.x]);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ILP; ++j)
|
||||
gradInput[offset + j * blockDim.x] = tmpGradOutput * (
|
||||
std::exp(tmpLogits[j] - coeff) - static_cast<accscalar_t>(
|
||||
(offset + j * blockDim.x == label) ? 1 : 0) *
|
||||
smooth_positives - smooth_negatives);
|
||||
}
|
||||
|
||||
for (; offset < classes; offset += blockDim.x)
|
||||
gradInput[offset] = tmpGradOutput * (std::exp(
|
||||
static_cast<accscalar_t>(logits[offset]) - coeff) -
|
||||
static_cast<accscalar_t>((offset == label) ? 1 : 0) *
|
||||
smooth_positives - smooth_negatives);
|
||||
}
|
||||
|
||||
|
||||
template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t>
|
||||
__device__ __forceinline__ void
|
||||
aligned_apply(int shift,
|
||||
scalar_t *gradInput,
|
||||
scalar_t *logits,
|
||||
outscalar_t *max_log_sum_exp,
|
||||
outscalar_t *gradOutput,
|
||||
int64_t *labels,
|
||||
const float smoothing,
|
||||
int classes)
|
||||
{
|
||||
accscalar_t smooth_positives = 1.0 - smoothing;
|
||||
accscalar_t smooth_negatives = smoothing / classes;
|
||||
accscalar_t tmpGradOutput = gradOutput[blockIdx.x];
|
||||
int64_t label = labels[blockIdx.x];
|
||||
accscalar_t coeff = max_log_sum_exp[blockIdx.x];
|
||||
|
||||
int offset = threadIdx.x;
|
||||
|
||||
// shift and do 1
|
||||
if(shift > 0){
|
||||
logits -= shift;
|
||||
gradInput -= shift;
|
||||
classes += shift;
|
||||
if(threadIdx.x >= shift){
|
||||
gradInput[offset] = tmpGradOutput * (std::exp(
|
||||
static_cast<accscalar_t>(logits[offset]) - coeff) -
|
||||
static_cast<accscalar_t>(((offset - shift) == label) ? 1 : 0) *
|
||||
smooth_positives - smooth_negatives);
|
||||
}
|
||||
classes -= blockDim.x;
|
||||
gradInput += blockDim.x;
|
||||
logits += blockDim.x;
|
||||
shift -= blockDim.x;
|
||||
}
|
||||
|
||||
int last = classes % (ILP * blockDim.x);
|
||||
|
||||
typedef typename std::aligned_storage<ILP*sizeof(scalar_t), ILP*alignof(scalar_t)>::type LoadT;
|
||||
// input
|
||||
scalar_t v[ILP];
|
||||
LoadT* value = reinterpret_cast<LoadT*>(&v);
|
||||
// output
|
||||
scalar_t r[ILP];
|
||||
LoadT* result = reinterpret_cast<LoadT*>(&r);
|
||||
|
||||
for (; offset * ILP < (classes - last); offset += blockDim.x) {
|
||||
*value = reinterpret_cast<LoadT*>(logits)[offset];
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ILP; ++j) {
|
||||
r[j] = tmpGradOutput * (std::exp(
|
||||
static_cast<accscalar_t>(v[j]) - coeff) -
|
||||
static_cast<accscalar_t>(((ILP * offset + j - shift) == label) ? 1 : 0) *
|
||||
smooth_positives - smooth_negatives);
|
||||
}
|
||||
reinterpret_cast<LoadT*>(gradInput)[offset] = *result;
|
||||
}
|
||||
|
||||
offset = classes - last + threadIdx.x;
|
||||
for (; offset < classes; offset += blockDim.x)
|
||||
gradInput[offset] = tmpGradOutput * (std::exp(
|
||||
static_cast<accscalar_t>(logits[offset]) - coeff) -
|
||||
static_cast<accscalar_t>(((offset - shift) == label) ? 1 : 0) *
|
||||
smooth_positives - smooth_negatives);
|
||||
|
||||
}
|
||||
|
||||
template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t, template<typename, typename, typename> class Epilogue>
|
||||
__global__ void
|
||||
cunn_SoftMaxXEntropyBackward(
|
||||
scalar_t *gradInput,
|
||||
scalar_t *logits,
|
||||
outscalar_t *max_log_sum_exp,
|
||||
outscalar_t *gradOutput,
|
||||
int64_t *labels,
|
||||
const float smoothing,
|
||||
int classes)
|
||||
{
|
||||
gradInput += blockIdx.x * classes;
|
||||
logits += blockIdx.x * classes;
|
||||
|
||||
// Do vectorized load/store when input/output have same alignment
|
||||
const int shift = ((uint64_t)logits) % ALIGN_BYTES / sizeof(scalar_t);
|
||||
const int shift_ = ((uint64_t)gradInput) % ALIGN_BYTES / sizeof(scalar_t);
|
||||
if (shift == shift_){
|
||||
aligned_apply<ILP, scalar_t, accscalar_t, outscalar_t>(shift, gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes);
|
||||
}
|
||||
else {
|
||||
apply<ILP, scalar_t, accscalar_t, outscalar_t>(gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
template<template<typename, typename, typename> class Epilogue>
|
||||
std::vector<Tensor> host_softmax_xentropy(
|
||||
const Tensor & input_,
|
||||
const Tensor & labels_,
|
||||
const float smoothing){
|
||||
AT_ASSERTM(labels_.scalar_type() == ScalarType::Long,"Label type should be CUDA Long");
|
||||
|
||||
// Otherwise the kernel will be launched from cuda:0 device
|
||||
// Cast to char to avoid compiler warning about narrowing
|
||||
at::cuda::CUDAGuard device_guard{(char)input_.get_device()};
|
||||
|
||||
auto input = input_.contiguous();
|
||||
Tensor max_log_sum_exp = at::empty_like(labels_, input.options().dtype(ScalarType::Float));
|
||||
Tensor losses = at::empty_like(labels_, input_.options().dtype(ScalarType::Float));
|
||||
|
||||
static_assert(std::is_same<acc_type<at::Half, true>, float>::value ||
|
||||
std::is_same<acc_type<at::Half, true>, double>::value,
|
||||
"accscalar_t for half should be float or double");
|
||||
AT_ASSERTM(input.dim() == 2, "Currently only 2 dim input supported");
|
||||
AT_ASSERTM(labels_.dim() == 1, "Labels should be 1 dimensional");
|
||||
AT_ASSERTM(input.size(0) == labels_.size(0), "Input and label should have same number of examples");
|
||||
AT_ASSERTM(input.numel() > 0, "Number of classes in input should not be 0");
|
||||
|
||||
const int64_t dim = 1;
|
||||
int64_t outer_size = 1;
|
||||
int64_t dim_size = input.size(dim);
|
||||
int64_t inner_size = 1;
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
for (int64_t i = 0; i < dim; ++i)
|
||||
outer_size *= input.size(i);
|
||||
for (int64_t i = dim + 1; i < input.dim(); ++i)
|
||||
inner_size *= input.size(i);
|
||||
// This kernel spawns a block per each element in the batch.
|
||||
// XXX: it assumes that inner_size == 1
|
||||
TORCH_CHECK(inner_size == 1, "Currently only inner size 1 supported");
|
||||
|
||||
dim3 grid(outer_size);
|
||||
|
||||
using namespace at;
|
||||
DISPATCH_FLOAT_AND_HALF_AND_BF16(input.scalar_type(), 0, "host_softmax_xentropy",
|
||||
using accscalar_t = at::acc_type<scalar_t_0, true>;
|
||||
const int ILP = sizeof(float4)/sizeof(scalar_t_0);
|
||||
dim3 block = SoftMax_getBlockSize(ILP, dim_size);
|
||||
cunn_SoftMaxXEntropyForward<ILP, scalar_t_0, accscalar_t, accscalar_t, Epilogue>
|
||||
<<<grid, block, 2 * block.x * sizeof(accscalar_t), stream>>>(
|
||||
losses.data_ptr<accscalar_t>(), max_log_sum_exp.data_ptr<accscalar_t>(),
|
||||
input.data_ptr<scalar_t_0>(), labels_.data_ptr<int64_t>(),
|
||||
dim_size, smoothing
|
||||
);
|
||||
);
|
||||
|
||||
C10_CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
std::vector<at::Tensor> ret = {losses, max_log_sum_exp};
|
||||
return ret;
|
||||
}
|
||||
|
||||
template<template<typename, typename, typename> class Epilogue>
|
||||
Tensor host_softmax_xentropy_backward(
|
||||
const at::Tensor &grad_loss,
|
||||
at::Tensor &logits_,
|
||||
const at::Tensor &max_log_sum_exp,
|
||||
const at::Tensor &labels,
|
||||
const float smoothing,
|
||||
bool inplace) {
|
||||
// Otherwise the kernel will be launched from cuda:0 device
|
||||
// Cast to char to avoid compiler warning about narrowing
|
||||
at::cuda::CUDAGuard device_guard{(char)grad_loss.get_device()};
|
||||
|
||||
const int64_t dim = 1;
|
||||
Tensor gI = inplace ? logits_ : at::empty_like(logits_);
|
||||
if (grad_loss.numel() == 0) {
|
||||
return gI;
|
||||
}
|
||||
|
||||
auto grad = grad_loss.contiguous();
|
||||
auto logits = logits_.contiguous();
|
||||
|
||||
static_assert(std::is_same<acc_type<at::Half, true>, float>::value ||
|
||||
std::is_same<acc_type<at::Half, true>, double>::value,
|
||||
"accscalar_t for half should be float or double");
|
||||
if (grad.dim() == 0) grad = grad.view(1);
|
||||
|
||||
AT_ASSERTM(logits_.dim() == 2, "Currently only 2 dim input supported");
|
||||
AT_ASSERTM(labels.dim() == 1, "Labels should be 1 dimensional");
|
||||
AT_ASSERTM(logits_.numel() > 0, "Number of classes in input should not be 0");
|
||||
AT_ASSERTM(logits_.size(0) == labels.size(0), "Input and label should have same number of examples");
|
||||
AT_ASSERTM(labels.size(0) == grad.size(0), "Label and loss should have same number of examples");
|
||||
|
||||
int64_t outer_size = 1;
|
||||
int64_t dim_size = logits.size(dim);
|
||||
int64_t inner_size = 1;
|
||||
for (int64_t i = 0; i < dim; ++i)
|
||||
outer_size *= logits.size(i);
|
||||
for (int64_t i = dim + 1; i < logits.dim(); ++i)
|
||||
inner_size *= logits.size(i);
|
||||
// See descriptions of kernels above.
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
TORCH_CHECK(inner_size == 1, "Currently only inner size 1 supported");
|
||||
|
||||
dim3 grid(outer_size);
|
||||
|
||||
DISPATCH_FLOAT_AND_HALF_AND_BF16(gI.scalar_type(), 0, "host_softmax_xentropy_backward",
|
||||
using accscalar_t = acc_type<scalar_t_0, true>;
|
||||
const int ILP = sizeof(float4)/sizeof(scalar_t_0);
|
||||
dim3 block = SoftMax_getBlockSize(ILP, dim_size);
|
||||
cunn_SoftMaxXEntropyBackward<ILP, scalar_t_0, accscalar_t, accscalar_t, Epilogue>
|
||||
<<<grid, block, block.x * sizeof(accscalar_t), stream>>>(
|
||||
gI.data_ptr<scalar_t_0>(), logits.data_ptr<scalar_t_0>(),
|
||||
max_log_sum_exp.data_ptr<accscalar_t>(),
|
||||
grad.data_ptr<accscalar_t>(), labels.data_ptr<int64_t>(),
|
||||
smoothing, dim_size
|
||||
);
|
||||
);
|
||||
|
||||
C10_CUDA_CHECK(cudaGetLastError());
|
||||
return gI;
|
||||
}
|
||||
|
||||
std::vector<Tensor> softmax_xentropy_cuda(const Tensor &input, const Tensor &labels, const float smoothing){
|
||||
return host_softmax_xentropy<LogSoftMaxForwardEpilogue>(input, labels, smoothing);
|
||||
}
|
||||
|
||||
at::Tensor softmax_xentropy_backward_cuda(
|
||||
const at::Tensor &grad_loss,
|
||||
at::Tensor &logits,
|
||||
const at::Tensor &max_log_sum_exp,
|
||||
const at::Tensor &labels,
|
||||
const float smoothing,
|
||||
const bool inplace) {
|
||||
AT_ASSERTM((grad_loss.scalar_type() == ScalarType::Float), "expected grad types to be at::Float");
|
||||
return host_softmax_xentropy_backward<LogSoftMaxBackwardEpilogue>(grad_loss, logits, max_log_sum_exp, labels, smoothing, inplace);
|
||||
}
|
||||
51
flash_attn/losses/cross_entropy_apex.py
Normal file
51
flash_attn/losses/cross_entropy_apex.py
Normal file
@ -0,0 +1,51 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import xentropy_cuda_lib
|
||||
|
||||
|
||||
# https://github.com/NVIDIA/apex/blob/master/apex/contrib/xentropy/softmax_xentropy.py
|
||||
class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, logits, labels, smoothing=0.0, padding_idx=0, inplace_backward=False):
|
||||
losses, max_log_sum_exp = xentropy_cuda_lib.forward(
|
||||
logits, labels, smoothing)
|
||||
losses.masked_fill_(labels==padding_idx, 0)
|
||||
ctx.save_for_backward(logits, max_log_sum_exp, labels)
|
||||
ctx.smoothing = smoothing
|
||||
ctx.padding_idx = padding_idx
|
||||
ctx.inplace_backward = inplace_backward
|
||||
return losses
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_loss):
|
||||
logits, max_log_sum_exp, labels = ctx.saved_tensors
|
||||
if not grad_loss.is_contiguous():
|
||||
grad_loss = grad_loss.contiguous()
|
||||
grad_loss.masked_fill_(labels==ctx.padding_idx, 0)
|
||||
grad_logits = xentropy_cuda_lib.backward(grad_loss, logits, max_log_sum_exp, labels,
|
||||
ctx.smoothing, ctx.inplace_backward)
|
||||
return grad_logits, None, None, None, None
|
||||
|
||||
|
||||
class CrossEntropyLossApex(nn.Module):
|
||||
|
||||
def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0,
|
||||
inplace_backward=False):
|
||||
super().__init__()
|
||||
if reduction not in ['mean', 'none']:
|
||||
raise NotImplementedError("Only support reduction = 'mean' or 'none'")
|
||||
self.ignore_index = ignore_index
|
||||
self.reduction = reduction
|
||||
self.label_smoothing = label_smoothing
|
||||
self.inplace_backward = inplace_backward
|
||||
|
||||
def forward(self, input, target):
|
||||
assert input.is_cuda and target.is_cuda
|
||||
# SoftmaxCrossEntropyLoss implicitly casts to float
|
||||
loss = SoftmaxCrossEntropyLossFn.apply(input, target, self.label_smoothing,
|
||||
self.ignore_index, self.inplace_backward)
|
||||
if self.reduction == 'mean':
|
||||
return loss.sum() / (target != self.ignore_index).sum()
|
||||
else:
|
||||
return loss
|
||||
112
flash_attn/losses/cross_entropy_parallel.py
Normal file
112
flash_attn/losses/cross_entropy_parallel.py
Normal file
@ -0,0 +1,112 @@
|
||||
# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py
|
||||
# But we make it much faster: we compute the local loss and the LSE, and by exchanging the LSE and
|
||||
# the losses we can get the global loss. There's no need to do it step by step
|
||||
# (compute local max, exchange, compute exp, compute local sum, exchange, etc.)
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import xentropy_cuda_lib
|
||||
|
||||
from apex.transformer.parallel_state import get_tensor_model_parallel_group
|
||||
from apex.transformer.parallel_state import get_tensor_model_parallel_rank
|
||||
from apex.transformer.parallel_state import get_tensor_model_parallel_world_size
|
||||
from apex.transformer.tensor_parallel.utils import VocabUtility
|
||||
|
||||
# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
|
||||
# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
|
||||
# version of PyTorch. The following 4 lines are for backward comparability with
|
||||
# older PyTorch.
|
||||
if "all_gather_into_tensor" not in dir(torch.distributed):
|
||||
torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
|
||||
if "reduce_scatter_tensor" not in dir(torch.distributed):
|
||||
torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base
|
||||
|
||||
|
||||
class SoftmaxCrossEntropyLossParallelFn(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, logits_parallel, labels, smoothing=0.0, ignored_index=-100,
|
||||
inplace_backward=False):
|
||||
"""
|
||||
logits_parallel: (batch, vocab_size / world_size)
|
||||
labels: (batch,)
|
||||
"""
|
||||
assert smoothing == 0.0, 'smoothing != 0.0 is not yet implemented, file an issue if you need it'
|
||||
batch, partition_vocab_size = logits_parallel.shape
|
||||
assert labels.shape == (batch,)
|
||||
rank = get_tensor_model_parallel_rank()
|
||||
world_size = get_tensor_model_parallel_world_size()
|
||||
vocab_start_index, vocab_end_index = VocabUtility.vocab_range_from_per_partition_vocab_size(
|
||||
partition_vocab_size, get_tensor_model_parallel_rank(),
|
||||
get_tensor_model_parallel_world_size()
|
||||
)
|
||||
|
||||
# Create a mask of valid vocab ids (1 means it needs to be masked).
|
||||
labels_mask = (labels < vocab_start_index) | (labels >= vocab_end_index)
|
||||
ignored_mask = labels == ignored_index
|
||||
labels_local = torch.where(ignored_mask, labels, labels - vocab_start_index)
|
||||
masked_labels = labels_local.clone()
|
||||
masked_labels[labels_mask] = ignored_index
|
||||
|
||||
losses, lse_local = xentropy_cuda_lib.forward(logits_parallel, masked_labels, smoothing)
|
||||
assert lse_local.shape == (batch,)
|
||||
assert losses.shape == (batch,)
|
||||
losses.masked_fill_(masked_labels==ignored_index, 0)
|
||||
|
||||
if world_size > 1:
|
||||
lse_allgather = torch.empty(world_size, batch, dtype=lse_local.dtype,
|
||||
device=lse_local.device)
|
||||
torch.distributed.all_gather_into_tensor(lse_allgather, lse_local.contiguous(),
|
||||
group=get_tensor_model_parallel_group())
|
||||
lse = torch.logsumexp(lse_allgather, dim=0)
|
||||
torch.distributed.all_reduce(losses, op=torch.distributed.ReduceOp.SUM,
|
||||
group=get_tensor_model_parallel_group())
|
||||
# The losses are currently lse_local - predicted_logit, we just have to subtract the
|
||||
# lse_local and add the lse (global).
|
||||
rank_per_sample = labels // partition_vocab_size
|
||||
lse_local = lse_allgather[rank_per_sample,
|
||||
torch.arange(batch, device=lse_allgather.device)]
|
||||
losses += lse - lse_local
|
||||
losses.masked_fill_(ignored_mask, 0)
|
||||
else:
|
||||
lse = lse_local
|
||||
|
||||
ctx.save_for_backward(logits_parallel, lse, labels_local)
|
||||
ctx.smoothing = smoothing
|
||||
ctx.ignored_index = ignored_index
|
||||
ctx.inplace_backward = inplace_backward
|
||||
return losses
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_loss):
|
||||
logits_parallel, lse, labels = ctx.saved_tensors
|
||||
if not grad_loss.is_contiguous():
|
||||
grad_loss = grad_loss.contiguous()
|
||||
grad_loss.masked_fill_(labels==ctx.ignored_index, 0)
|
||||
grad_logits = xentropy_cuda_lib.backward(grad_loss, logits_parallel, lse, labels,
|
||||
ctx.smoothing, ctx.inplace_backward)
|
||||
return grad_logits, None, None, None, None, None
|
||||
|
||||
|
||||
class CrossEntropyLossParallel(nn.Module):
|
||||
|
||||
def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0,
|
||||
inplace_backward=False):
|
||||
super().__init__()
|
||||
if reduction not in ['mean', 'none']:
|
||||
raise NotImplementedError("Only support reduction = 'mean' or 'none'")
|
||||
self.ignore_index = ignore_index
|
||||
self.reduction = reduction
|
||||
self.label_smoothing = label_smoothing
|
||||
self.inplace_backward = inplace_backward
|
||||
|
||||
def forward(self, input, target):
|
||||
assert input.is_cuda and target.is_cuda
|
||||
# SoftmaxCrossEntropyLoss implicitly casts to float
|
||||
loss = SoftmaxCrossEntropyLossParallelFn.apply(
|
||||
input, target, self.label_smoothing, self.ignore_index, self.inplace_backward
|
||||
)
|
||||
if self.reduction == 'mean':
|
||||
return loss.sum() / (target != self.ignore_index).sum()
|
||||
else:
|
||||
return loss
|
||||
39
tests/losses/test_cross_entropy_apex.py
Normal file
39
tests/losses/test_cross_entropy_apex.py
Normal file
@ -0,0 +1,39 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import pytest
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
from src.losses.cross_entropy_apex import CrossEntropyLossApex
|
||||
|
||||
is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
|
||||
|
||||
|
||||
@pytest.mark.parametrize('dtype', [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else []))
|
||||
# @pytest.mark.parametrize('dtype', [torch.float16])
|
||||
@pytest.mark.parametrize('inplace_backward', [False, True])
|
||||
# @pytest.mark.parametrize('inplace_backward', [False])
|
||||
@pytest.mark.parametrize('vocab_size', [50257])
|
||||
def test_cross_entropy_loss_apex(vocab_size, inplace_backward, dtype):
|
||||
device = 'cuda'
|
||||
rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)
|
||||
# set seed
|
||||
torch.random.manual_seed(0)
|
||||
batch_size = 8
|
||||
seqlen = 128
|
||||
x_pt = torch.randn(batch_size * seqlen, vocab_size, device=device, dtype=dtype, requires_grad=True)
|
||||
x = x_pt.detach().clone().requires_grad_()
|
||||
y = torch.randint(0, vocab_size, (batch_size * seqlen,), dtype=torch.long, device=device)
|
||||
y[torch.randperm(batch_size * seqlen)[:10]] = -100
|
||||
model_pt = torch.nn.CrossEntropyLoss()
|
||||
model = CrossEntropyLossApex(inplace_backward=inplace_backward)
|
||||
out = model(x, y)
|
||||
out_pt = model_pt(x_pt.float(), y)
|
||||
assert torch.allclose(out, out_pt, rtol=rtol, atol=atol)
|
||||
|
||||
g = torch.randn_like(out)
|
||||
out_pt.backward(g)
|
||||
out.backward(g)
|
||||
assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
|
||||
56
tests/losses/test_cross_entropy_parallel.py
Normal file
56
tests/losses/test_cross_entropy_parallel.py
Normal file
@ -0,0 +1,56 @@
|
||||
# Run test with:
|
||||
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/losses/test_cross_entropy_parallel.py
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import pytest
|
||||
|
||||
from apex.transformer import parallel_state
|
||||
from apex.transformer import tensor_parallel
|
||||
|
||||
from src.losses.cross_entropy_parallel import CrossEntropyLossParallel
|
||||
|
||||
is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
|
||||
|
||||
|
||||
@pytest.mark.parametrize('dtype', [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else []))
|
||||
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
|
||||
@pytest.mark.parametrize('inplace_backward', [False, True])
|
||||
# @pytest.mark.parametrize('inplace_backward', [False])
|
||||
@pytest.mark.parametrize('vocab_size', [50264])
|
||||
@pytest.mark.parametrize('world_size', [1, 2, 4, 8])
|
||||
# @pytest.mark.parametrize('world_size', [2])
|
||||
def test_cross_entropy_loss_apex(vocab_size, world_size, inplace_backward, dtype):
|
||||
assert vocab_size % world_size == 0
|
||||
rtol, atol = ((1e-5, 1e-6) if dtype == torch.float32
|
||||
else ((1e-3, 1e-4) if dtype == torch.float16 else (1e-2, 3e-3)))
|
||||
if not torch.distributed.is_initialized():
|
||||
torch.distributed.init_process_group(backend='nccl', init_method='env://')
|
||||
partition_vocab_size = vocab_size // world_size
|
||||
device = f'cuda:{torch.distributed.get_rank()}'
|
||||
assert world_size <= torch.distributed.get_world_size()
|
||||
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
|
||||
rank = parallel_state.get_tensor_model_parallel_rank()
|
||||
# set seed
|
||||
torch.random.manual_seed(0)
|
||||
batch_size = 8
|
||||
seqlen = 128
|
||||
x_pt = (torch.randn(batch_size * seqlen, vocab_size, device=device,
|
||||
dtype=dtype) * 10).requires_grad_()
|
||||
x = tensor_parallel.scatter_to_tensor_model_parallel_region(x_pt).detach().clone().requires_grad_()
|
||||
y = torch.randint(0, vocab_size, (batch_size * seqlen,), dtype=torch.long, device=device)
|
||||
y[torch.randperm(batch_size * seqlen)[:10]] = -100
|
||||
model_pt = torch.nn.CrossEntropyLoss(reduction='none')
|
||||
model = CrossEntropyLossParallel(reduction='none', inplace_backward=inplace_backward)
|
||||
out = model(x, y)
|
||||
out_pt = model_pt(x_pt.float(), y)
|
||||
assert torch.allclose(out, out_pt, rtol=1e-5, atol=1e-6)
|
||||
|
||||
g = torch.randn_like(out)
|
||||
out_pt.backward(g)
|
||||
out.backward(g)
|
||||
assert torch.allclose(x.grad, x_pt.grad[:, (rank * partition_vocab_size):(rank + 1) * partition_vocab_size], rtol=rtol, atol=atol)
|
||||
|
||||
parallel_state.destroy_model_parallel()
|
||||
Loading…
Reference in New Issue
Block a user