761 lines
25 KiB
Plaintext
761 lines
25 KiB
Plaintext
// Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/csrc/xentropy/xentropy_kernel.cu
|
|
// TD [2022-09-17]: We make it work for bfloat16, and add an option to do the backward inplace (to save memory).
|
|
/**
|
|
* From PyTorch:
|
|
*
|
|
* Copyright (c) 2016- Facebook, Inc (Adam Paszke)
|
|
* Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
|
|
* Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
|
|
* Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
|
|
* Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
|
|
* Copyright (c) 2011-2013 NYU (Clement Farabet)
|
|
* Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
|
|
* Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
|
|
* Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
|
|
*
|
|
* From Caffe2:
|
|
*
|
|
* Copyright (c) 2016-present, Facebook Inc. All rights reserved.
|
|
*
|
|
* All contributions by Facebook:
|
|
* Copyright (c) 2016 Facebook Inc.
|
|
*
|
|
* All contributions by Google:
|
|
* Copyright (c) 2015 Google Inc.
|
|
* All rights reserved.
|
|
*
|
|
* All contributions by Yangqing Jia:
|
|
* Copyright (c) 2015 Yangqing Jia
|
|
* All rights reserved.
|
|
*
|
|
* All contributions from Caffe:
|
|
* Copyright(c) 2013, 2014, 2015, the respective contributors
|
|
* All rights reserved.
|
|
*
|
|
* All other contributions:
|
|
* Copyright(c) 2015, 2016 the respective contributors
|
|
* All rights reserved.
|
|
*
|
|
* Caffe2 uses a copyright model similar to Caffe: each contributor holds
|
|
* copyright over their contributions to Caffe2. The project versioning records
|
|
* all such contribution and copyright details. If a contributor wants to further
|
|
* mark their specific copyright on a particular contribution, they should
|
|
* indicate their copyright solely in the commit message of the change when it is
|
|
* committed.
|
|
*
|
|
* All rights reserved.
|
|
*
|
|
* Redistribution and use in source and binary forms, with or without
|
|
* modification, are permitted provided that the following conditions are met:
|
|
*
|
|
* 1. Redistributions of source code must retain the above copyright
|
|
* notice, this list of conditions and the following disclaimer.
|
|
*
|
|
* 2. Redistributions in binary form must reproduce the above copyright
|
|
* notice, this list of conditions and the following disclaimer in the
|
|
* documentation and/or other materials provided with the distribution.
|
|
*
|
|
* 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
|
|
* and IDIAP Research Institute nor the names of its contributors may be
|
|
* used to endorse or promote products derived from this software without
|
|
* specific prior written permission.
|
|
*
|
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
|
* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
|
|
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
|
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
|
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
|
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
|
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
|
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
|
* POSSIBILITY OF SUCH DAMAGE.
|
|
*/
|
|
#include <ATen/ATen.h>
|
|
#include <ATen/cuda/CUDAContext.h>
|
|
#include <c10/cuda/CUDAGuard.h>
|
|
|
|
#include <ATen/AccumulateType.h>
|
|
#include <ATen/cuda/NumericLimits.cuh>
|
|
|
|
// https://github.com/NVIDIA/apex/blob/master/csrc/type_shim.h
|
|
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
|
// Runtime dtype dispatch for float / half / bfloat16.
//
//   TYPE  - an at::ScalarType value to switch on.
//   LEVEL - suffix pasted onto the alias name, so the body sees the selected
//           C++ type as `scalar_t_<LEVEL>`.
//   NAME  - operation name, stringized into the error message.
//   ...   - statement(s) executed inside the matching case.
//
// Any other dtype raises a c10::Error. TORCH_CHECK(false, ...) is used instead
// of the deprecated AT_ERROR alias; the message text is unchanged.
#define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, LEVEL, NAME, ...)                   \
  switch (TYPE) {                                                                 \
    case at::ScalarType::Float: {                                                 \
      using scalar_t_##LEVEL = float;                                             \
      __VA_ARGS__;                                                                \
      break;                                                                      \
    }                                                                             \
    case at::ScalarType::Half: {                                                  \
      using scalar_t_##LEVEL = at::Half;                                          \
      __VA_ARGS__;                                                                \
      break;                                                                      \
    }                                                                             \
    case at::ScalarType::BFloat16: {                                              \
      using scalar_t_##LEVEL = at::BFloat16;                                      \
      __VA_ARGS__;                                                                \
      break;                                                                      \
    }                                                                             \
    default:                                                                      \
      TORCH_CHECK(false, #NAME, " not implemented for '", toString(TYPE), "'");   \
  }
|
|
// #else
|
|
// #define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, LEVEL, NAME, ...) \
|
|
// switch(TYPE) \
|
|
// { \
|
|
// case at::ScalarType::Float: \
|
|
// { \
|
|
// using scalar_t_##LEVEL = float; \
|
|
// __VA_ARGS__; \
|
|
// break; \
|
|
// } \
|
|
// case at::ScalarType::Half: \
|
|
// { \
|
|
// using scalar_t_##LEVEL = at::Half; \
|
|
// __VA_ARGS__; \
|
|
// break; \
|
|
// } \
|
|
// default: \
|
|
// AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
|
// }
|
|
// #endif
|
|
|
|
// Byte boundary against which row pointers are tested before the kernels
// switch to vectorized (ILP-wide) loads and stores.
#define ALIGN_BYTES 16

// Local shorthands for the ATen names used throughout this file.
using at::acc_type;
using ScalarType = at::ScalarType;
using Tensor = at::Tensor;
using TensorList = at::TensorList;
|
|
|
|
// Functor that turns a raw input value into `input - logsum`, where `logsum`
// is fixed at construction time.
template <typename T, typename AccumT, typename OutT>
struct LogSoftMaxForwardEpilogue {
  // Normalizer subtracted from every input.
  const AccumT logsum;

  // Build the normalizer from a maximum and a sum: max_input + log(sum).
  __device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT max_input, AccumT sum)
      : logsum(max_input + std::log(sum)) {}

  // Use an already-combined normalizer as is.
  __device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT max_log_sum_exp)
      : logsum(max_log_sum_exp) {}

  // Returns (input - logsum) converted to OutT.
  __device__ __forceinline__ OutT operator()(T input) const {
    const auto shifted = input - logsum;
    return static_cast<OutT>(shifted);
  }
};
|
|
|
|
// Functor computing `gradOutput - exp(output) * sum` for a fixed `sum`.
template <typename T, typename AccumT, typename OutT>
struct LogSoftMaxBackwardEpilogue {
  // Scale applied to exp(output).
  const AccumT sum;

  __device__ __forceinline__ LogSoftMaxBackwardEpilogue(AccumT sum)
      : sum(sum) {}

  // The exponential is evaluated in AccumT; the result is converted to T.
  __device__ __forceinline__ T operator()(OutT gradOutput, OutT output) const {
    const AccumT expOutput = std::exp(static_cast<AccumT>(output));
    return static_cast<T>(gradOutput - expOutput * sum);
  }
};
|
|
|
|
|
|
|
|
// Upper bound on the number of threads in a block.
const int max_threads = 1024;

// Picks a 1-D block size for a row of `dim_size` elements when every thread
// handles `ILP` elements per vectorized step.
//
// The size starts at 1 and doubles while it is below half of
// min(dim_size / ILP, max_threads); the result is then raised to at least 32
// because the kernels assume at least one full warp.
inline dim3 SoftMax_getBlockSize(int ILP, uint64_t dim_size) {
  const uint64_t warp_size = 32;
  const uint64_t cap = std::min(dim_size / ILP, static_cast<uint64_t>(max_threads));

  uint64_t threads = 1;
  for (; threads < (cap / 2); threads *= 2) {
  }

  // Launch at least a single warp - the kernel assumes that.
  if (threads < warp_size) {
    threads = warp_size;
  }
  return dim3(threads);
}
|
|
|
|
// Binary sum functor used as a block-level reduction operator.
template <typename T>
struct Add {
  __device__ __forceinline__ T operator()(T a, T b) const {
    const T total = a + b;
    return total;
  }
};
|
|
|
|
// Binary maximum functor used as a block-level reduction operator.
// Returns `b` only when `a < b` holds; otherwise returns `a`.
template <typename T>
struct Max {
  __device__ __forceinline__ T operator()(T a, T b) const {
    if (a < b) {
      return b;
    }
    return a;
  }
};
|
|
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
// Regular kernel (fast when dim_size is large; requires inner_size == 1)
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
// Per-thread reduction step: folds one element of type T into a running
// maximum kept in AccumT.
template <typename T, typename AccumT>
struct MaxFloat
{
  __device__ __forceinline__ AccumT operator()(AccumT max, T v) const {
    const AccumT candidate = (AccumT)v;
    return ::max(max, candidate);
  }
};
|
|
|
|
// Per-thread reduction step: adds one element of type T to a running sum
// kept in AccumT.
template <typename T, typename AccumT>
struct AddFloat
{
  __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const {
    const AccumT updated = sum + v;
    return updated;
  }
};
|
|
|
|
// Per-thread reduction step: accumulates exp(v - max_k) into a running sum,
// where max_k is the offset supplied at construction.
template <typename T, typename AccumT>
struct SumExpFloat
{
  // Offset subtracted from every element before exponentiation.
  const AccumT max_k;

  __device__ __forceinline__ SumExpFloat(AccumT v)
      : max_k(v) {}

  __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const {
    const auto term = std::exp(v - max_k);
    return sum + term;
  }
};
|
|
|
|
// Block-wide reduction of one value per thread.
//
//   smem       - scratch buffer indexed by threadIdx.x (at least blockDim.x
//                AccumT entries are touched).
//   val        - this thread's contribution.
//   r          - binary reduction operator.
//   defaultVal - starting value for the partial reductions.
//
// Every thread of the block must call this function (it contains block-wide
// barriers). The reduced value is returned to all threads.
template <template <typename> class Reduction, typename AccumT>
__device__ __forceinline__ AccumT
blockReduce(AccumT* smem,
            AccumT val,
            const Reduction<AccumT>& r,
            AccumT defaultVal)
{
  const unsigned int numWarps = blockDim.x / 32;

  // To avoid RaW races from chaining blockReduce calls together, we need a sync here
  __syncthreads();

  smem[threadIdx.x] = val;

  __syncthreads();

  AccumT partial = defaultVal;

  // One bit per lane that takes part in the first stage (lanes 0..numWarps-1).
  uint32_t laneMask = (((uint64_t)1) << numWarps) - 1;

  // Stage 1: lane k of the first warp folds the 32 entries written by warp k.
  if (threadIdx.x < 32) {
    const unsigned int lane = threadIdx.x % 32;
    if (lane < numWarps) {
      const unsigned int base = lane * 32;
#pragma unroll
      for (int i = 0; i < 32; ++i) {
        partial = r(partial, smem[base + i]);
      }
      // All participating lanes must finish reading before any slot is overwritten.
      __syncwarp(laneMask);
      smem[lane] = partial;
    }
  }

  __syncthreads();

  // Stage 2: thread 0 folds the per-warp partials.
  AccumT total = defaultVal;

  if (threadIdx.x == 0) {
    for (unsigned int w = 0; w < numWarps; ++w) {
      total = r(total, smem[w]);
    }
    smem[0] = total;
  }

  // Sync and broadcast
  __syncthreads();
  return smem[0];
}
|
|
|
|
// Block-wide reduction of two values per thread in a single pass.
//
//   smem                      - scratch buffer; the first blockDim.x entries hold
//                               the first quantity and the next blockDim.x entries
//                               the second (2 * blockDim.x AccumT entries total).
//   reducVal1 / reducVal2     - outputs, written by every thread.
//   val1 / val2               - this thread's contributions.
//   r1 / r2                   - binary reduction operators.
//   defaultVal1 / defaultVal2 - starting values for the partial reductions.
//
// Every thread of the block must call this function (it contains block-wide
// barriers).
template <template <typename> class Reduction1, template <typename> class Reduction2, typename AccumT>
__device__ __forceinline__ void
blockReduce(AccumT* smem,
            AccumT* reducVal1,
            AccumT val1,
            const Reduction1<AccumT>& r1,
            AccumT defaultVal1,
            AccumT* reducVal2,
            AccumT val2,
            const Reduction2<AccumT>& r2,
            AccumT defaultVal2)
{
  const unsigned int numWarps = blockDim.x / 32;
  // Offset of the second quantity's region inside smem.
  const unsigned int second = blockDim.x;

  // To avoid RaW races from chaining blockReduce calls together, we need a sync here
  __syncthreads();

  smem[threadIdx.x] = val1;
  smem[second + threadIdx.x] = val2;

  __syncthreads();

  AccumT partial1 = defaultVal1;
  AccumT partial2 = defaultVal2;

  // One bit per lane that takes part in the first stage (lanes 0..numWarps-1).
  uint32_t laneMask = (((uint64_t)1) << numWarps) - 1;

  // Stage 1: lane k of the first warp folds the 32 entries written by warp k.
  if (threadIdx.x < 32) {
    const unsigned int lane = threadIdx.x % 32;
    if (lane < numWarps) {
      const unsigned int base = lane * 32;
#pragma unroll
      for (int i = 0; i < 32; ++i) {
        partial1 = r1(partial1, smem[base + i]);
        partial2 = r2(partial2, smem[base + i + second]);
      }
      // All participating lanes must finish reading before any slot is overwritten.
      __syncwarp(laneMask);
      smem[lane] = partial1;
      smem[lane + second] = partial2;
    }
  }

  __syncthreads();

  // Stage 2: thread 0 folds the per-warp partials.
  AccumT total1 = defaultVal1;
  AccumT total2 = defaultVal2;

  if (threadIdx.x == 0) {
    for (unsigned int w = 0; w < numWarps; ++w) {
      total1 = r1(total1, smem[w]);
      total2 = r2(total2, smem[w + second]);
    }
    smem[0] = total1;
    smem[second] = total2;
  }

  // Sync and broadcast
  __syncthreads();
  *reducVal1 = smem[0];
  *reducVal2 = smem[second];
  __syncthreads();
}
|
|
|
|
// Per-thread partial reduction over `size` elements starting at `data`,
// using ILP-wide vectorized loads for the bulk of the row.
//
//   shift      - number of elements by which `data` is past the previous
//                vector-aligned address (0 when `data` is already aligned).
//   data       - first element of the row.
//   size       - number of elements in the row.
//   r          - reduction step: r(accumulator, element) -> accumulator.
//   defaultVal - initial accumulator value.
//
// Returns this thread's partial result; callers combine the partials with a
// block-level reduction.
//
// Layout of the work:
//   1. If the row is misaligned, the pointer is moved back to the aligned
//      address and one scalar pass handles the first blockDim.x slots. Slots
//      below `shift` precede the row and slots at or beyond the row end are
//      past it, so both are skipped.
//   2. The aligned remainder is processed ILP elements at a time.
//   3. A scalar tail handles what does not fill a whole ILP * blockDim.x tile.
//
// Fix over the previous version: step 1 did not bound the slot from above, so
// rows shorter than blockDim.x - shift read past the end of the row, and the
// then-negative `size` fed a mixed signed/unsigned `%` that made the tail loop
// read at negative offsets. The upper bound and the `size <= 0` exit remove
// both; for rows at least blockDim.x - shift long the behaviour is unchanged.
template <template<typename, typename> class Reduction, int ILP, typename T, typename AccumT>
__device__ __forceinline__ AccumT
ilpReduce(int shift,
          T* data,
          int size,
          const Reduction<T, AccumT>& r,
          AccumT defaultVal)
{
  typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LoadT;

  const int threads = static_cast<int>(blockDim.x);
  const int tid = static_cast<int>(threadIdx.x);
  AccumT threadVal = defaultVal;

  // shift and do 1
  if (shift > 0) {
    data -= shift;
    size += shift;
    if (tid >= shift && tid < size) {
      threadVal = r(threadVal, data[tid]);
    }
    size -= threads;
    data += threads;
  }

  // The scalar pass above already consumed the whole row (or the row is empty).
  if (size <= 0) {
    return threadVal;
  }

  const int last = size % (ILP * threads);
  const int vecEnd = size - last;

  T v[ILP];
  LoadT* value = reinterpret_cast<LoadT*>(&v);

  // Vectorized body: `offset` counts ILP-wide vectors, not elements.
  for (int offset = tid; offset * ILP < vecEnd; offset += threads) {
    *value = reinterpret_cast<LoadT*>(data)[offset];

    for (int j = 0; j < ILP; ++j) {
      threadVal = r(threadVal, v[j]);
    }
  }

  // Epilogue
  for (int offset = vecEnd + tid; offset < size; offset += threads) {
    threadVal = r(threadVal, data[offset]);
  }

  return threadVal;
}
|
|
|
|
// Per-thread partial reduction that applies two reduction steps to every
// element of a row in a single sweep, using ILP-wide vectorized loads.
//
//   shift                     - number of elements by which `data` is past the
//                               previous vector-aligned address (0 if aligned).
//   data                      - first element of the row.
//   size                      - number of elements in the row.
//   reducVal1 / reducVal2     - outputs: this thread's two partial results.
//   r1 / r2                   - reduction steps: r(accumulator, element).
//   defaultVal1 / defaultVal2 - initial accumulator values.
//
// Fix over the previous version: the scalar pass that handles a misaligned
// row start did not bound the slot from above, so rows shorter than
// blockDim.x - shift read past the end of the row, and the then-negative
// `size` fed a mixed signed/unsigned `%` that made the tail loop read at
// negative offsets. The pass is now bounded and the remaining loops are
// skipped when nothing is left. For rows at least blockDim.x - shift long the
// behaviour is unchanged.
template <template<typename, typename> class Reduction1, template<typename, typename> class Reduction2, int ILP, typename T, typename AccumT>
__device__ __forceinline__ void
ilpReduce(int shift,
          T* data,
          int size,
          AccumT* reducVal1,
          const Reduction1<T, AccumT>& r1,
          AccumT defaultVal1,
          AccumT* reducVal2,
          const Reduction2<T, AccumT>& r2,
          AccumT defaultVal2)
{
  typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LoadT;

  const int threads = static_cast<int>(blockDim.x);
  const int tid = static_cast<int>(threadIdx.x);

  AccumT threadVal1 = defaultVal1;
  AccumT threadVal2 = defaultVal2;

  // shift and do 1
  if (shift > 0) {
    data -= shift;
    size += shift;
    // Slots below `shift` precede the row; slots >= size are past its end.
    if (tid >= shift && tid < size) {
      threadVal1 = r1(threadVal1, data[tid]);
      threadVal2 = r2(threadVal2, data[tid]);
    }
    size -= threads;
    data += threads;
  }

  // Nothing is left when the scalar pass above consumed the whole row.
  if (size > 0) {
    const int last = size % (ILP * threads);
    const int vecEnd = size - last;

    T v[ILP];
    LoadT* value = reinterpret_cast<LoadT*>(&v);

    // Vectorized body: `offset` counts ILP-wide vectors, not elements.
    for (int offset = tid; offset * ILP < vecEnd; offset += threads) {
      *value = reinterpret_cast<LoadT*>(data)[offset];

      for (int j = 0; j < ILP; ++j) {
        threadVal1 = r1(threadVal1, v[j]);
        threadVal2 = r2(threadVal2, v[j]);
      }
    }

    // Epilogue
    for (int offset = vecEnd + tid; offset < size; offset += threads) {
      threadVal1 = r1(threadVal1, data[offset]);
      threadVal2 = r2(threadVal2, data[offset]);
    }
  }

  *reducVal1 = threadVal1;
  *reducVal2 = threadVal2;
}
|
|
|
|
// Forward kernel: one block per row of `input`.
//
//   losses          - out, one entry per row: the label-smoothed loss.
//   max_log_sum_exp - out, one entry per row: max + log(sum(exp(x - max))),
//                     kept for the backward pass.
//   input           - logits; row `blockIdx.x` starts at blockIdx.x * classes.
//   labels          - one target index per row.
//   classes         - number of entries in each row.
//   smoothing       - label-smoothing factor.
//   total_classes   - divisor applied to the row sum in the smoothing term.
//
// Requires 2 * blockDim.x * sizeof(accscalar_t) bytes of dynamic shared
// memory (scratch for the two-value block reduction).
template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t, template <typename, typename, typename> class Epilogue>
__global__ void
cunn_SoftMaxXEntropyForward(
    accscalar_t *losses,
    outscalar_t *max_log_sum_exp,
    scalar_t *input,
    int64_t *labels,
    int64_t classes,
    const float smoothing,
    const int total_classes)
{
  extern __shared__ unsigned char smem[];
  accscalar_t* const scratch = reinterpret_cast<accscalar_t*>(smem);

  // forward pointers to batch[blockIdx.x]
  // each block handles a sample in the mini-batch
  const auto row = blockIdx.x;
  input += row * classes;

  // Number of elements by which this row is past the previous ALIGN_BYTES boundary.
  const int shift = ((uint64_t)input) % ALIGN_BYTES / sizeof(scalar_t);

  const int64_t label = labels[row];

  const accscalar_t lowest = -at::numeric_limits<accscalar_t>::max();
  const accscalar_t zero = static_cast<accscalar_t>(0);

  // Pass 1: row maximum and plain row sum, first per thread, then per block.
  accscalar_t threadMax, threadSum, max_k, sum_k;
  ilpReduce<MaxFloat, AddFloat, ILP, scalar_t, accscalar_t>(
      shift, input, classes,
      &threadMax, MaxFloat<scalar_t, accscalar_t>(), lowest,
      &threadSum, AddFloat<scalar_t, accscalar_t>(), zero);

  blockReduce<Max, Add, accscalar_t>(
      scratch,
      &max_k, threadMax, Max<accscalar_t>(), lowest,
      &sum_k, threadSum, Add<accscalar_t>(), zero);

  // Pass 2: sum of exp(x - max) over the row.
  const SumExpFloat<scalar_t, accscalar_t> sumExp(max_k);
  accscalar_t threadExp = ilpReduce<SumExpFloat, ILP, scalar_t, accscalar_t>(
      shift, input, classes, sumExp, zero);
  accscalar_t sumAll = blockReduce<Add, accscalar_t>(
      scratch, threadExp, Add<accscalar_t>(), zero);

  Epilogue<scalar_t, accscalar_t, outscalar_t> epilogue(max_k, sumAll);

  // calculate per element loss with label smoothing
  // reserve max + log_sum_exp for bprop
  if (threadIdx.x == 0) {
    accscalar_t lse = max_k + std::log(sumAll);
    // Labels outside [0, classes) contribute a log-probability of zero.
    accscalar_t log_prob = (label >= 0 && label < classes) ? epilogue(static_cast<accscalar_t>(input[label])) : 0.f;
    losses[row] = (lse - sum_k / total_classes) * smoothing - log_prob * (1 - smoothing);
    max_log_sum_exp[row] = lse;
  }
}
|
|
|
|
// Backward pass for one row using scalar loads and stores; used when the
// logits and gradInput rows do not share the same alignment.
//
// For every index i in [0, classes) this writes
//   gradInput[i] = gradOutput[row] *
//       (exp(logits[i] - max_log_sum_exp[row])
//        - (i == label ? 1 : 0) * (1 - smoothing)
//        - smoothing / total_classes)
// with row = blockIdx.x. `gradInput` and `logits` must already point at the
// start of this block's row.
template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t>
__device__ __forceinline__ void
apply(scalar_t *gradInput,
      scalar_t *logits,
      outscalar_t *max_log_sum_exp,
      outscalar_t *gradOutput,
      int64_t *labels,
      const float smoothing,
      int classes,
      const int total_classes)
{
  accscalar_t smooth_positives = 1.0 - smoothing;
  accscalar_t smooth_negatives = smoothing / total_classes;

  // Per-row scalars.
  accscalar_t rowGrad = gradOutput[blockIdx.x];
  int64_t label = labels[blockIdx.x];
  accscalar_t lse = max_log_sum_exp[blockIdx.x];

  int offset = threadIdx.x;
  // Elements that do not fill a whole ILP * blockDim.x tile.
  int last = classes % (ILP * blockDim.x);

  // Tiled body: each thread handles ILP elements spaced blockDim.x apart.
  while (offset < classes - last) {
    accscalar_t staged[ILP];

#pragma unroll
    for (int j = 0; j < ILP; ++j) {
      staged[j] = static_cast<accscalar_t>(logits[offset + j * blockDim.x]);
    }

#pragma unroll
    for (int j = 0; j < ILP; ++j) {
      const auto idx = offset + j * blockDim.x;
      gradInput[idx] = rowGrad * (
          std::exp(staged[j] - lse) - static_cast<accscalar_t>(
          (idx == label) ? 1 : 0) *
          smooth_positives - smooth_negatives);
    }

    offset += blockDim.x * ILP;
  }

  // Tail: one element per thread per step.
  while (offset < classes) {
    gradInput[offset] = rowGrad * (std::exp(
        static_cast<accscalar_t>(logits[offset]) - lse) -
        static_cast<accscalar_t>((offset == label) ? 1 : 0) *
        smooth_positives - smooth_negatives);
    offset += blockDim.x;
  }
}
|
|
|
|
|
|
// Backward pass for one row using ILP-wide vectorized loads and stores; used
// when the logits and gradInput rows share the same misalignment `shift`.
//
// For every index i in [0, classes) this writes
//   gradInput[i] = gradOutput[row] *
//       (exp(logits[i] - max_log_sum_exp[row])
//        - (i == label ? 1 : 0) * (1 - smoothing)
//        - smoothing / total_classes)
// with row = blockIdx.x. `gradInput` and `logits` must already point at the
// start of this block's row.
//
//   shift - number of elements by which both row pointers are past the
//           previous vector-aligned address (0 when aligned).
//
// Fix over the previous version: the scalar pass that handles a misaligned
// row start did not bound the slot from above. For rows shorter than
// blockDim.x - shift it wrote gradInput entries beyond the end of the row
// (which belong to other rows), and the then-negative `classes` made the tail
// loop write at negative offsets. The pass is now bounded and the function
// returns once the whole row has been handled. For rows at least
// blockDim.x - shift long the behaviour is unchanged.
template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t>
__device__ __forceinline__ void
aligned_apply(int shift,
              scalar_t *gradInput,
              scalar_t *logits,
              outscalar_t *max_log_sum_exp,
              outscalar_t *gradOutput,
              int64_t *labels,
              const float smoothing,
              int classes,
              const int total_classes)
{
  accscalar_t smooth_positives = 1.0 - smoothing;
  accscalar_t smooth_negatives = smoothing / total_classes;
  accscalar_t tmpGradOutput = gradOutput[blockIdx.x];
  int64_t label = labels[blockIdx.x];
  accscalar_t coeff = max_log_sum_exp[blockIdx.x];

  const int threads = static_cast<int>(blockDim.x);
  int offset = static_cast<int>(threadIdx.x);

  // shift and do 1
  if (shift > 0) {
    logits -= shift;
    gradInput -= shift;
    classes += shift;
    // Slots below `shift` precede the row; slots >= classes are past its end.
    if (offset >= shift && offset < classes) {
      gradInput[offset] = tmpGradOutput * (std::exp(
          static_cast<accscalar_t>(logits[offset]) - coeff) -
          static_cast<accscalar_t>(((offset - shift) == label) ? 1 : 0) *
          smooth_positives - smooth_negatives);
    }
    classes -= threads;
    gradInput += threads;
    logits += threads;
    // After advancing the pointers, (index - shift) is still the class index.
    shift -= threads;

    // The scalar pass above already covered the whole row.
    if (classes <= 0) {
      return;
    }
  }

  int last = classes % (ILP * threads);

  typedef typename std::aligned_storage<ILP*sizeof(scalar_t), ILP*alignof(scalar_t)>::type LoadT;
  // input
  scalar_t v[ILP];
  LoadT* value = reinterpret_cast<LoadT*>(&v);
  // output
  scalar_t r[ILP];
  LoadT* result = reinterpret_cast<LoadT*>(&r);

  // Vectorized body: `offset` counts ILP-wide vectors, not elements.
  for (; offset * ILP < (classes - last); offset += threads) {
    *value = reinterpret_cast<LoadT*>(logits)[offset];

#pragma unroll
    for (int j = 0; j < ILP; ++j) {
      r[j] = tmpGradOutput * (std::exp(
          static_cast<accscalar_t>(v[j]) - coeff) -
          static_cast<accscalar_t>(((ILP * offset + j - shift) == label) ? 1 : 0) *
          smooth_positives - smooth_negatives);
    }
    reinterpret_cast<LoadT*>(gradInput)[offset] = *result;
  }

  // Scalar tail for what does not fill a whole ILP * blockDim.x tile.
  offset = classes - last + static_cast<int>(threadIdx.x);
  for (; offset < classes; offset += threads) {
    gradInput[offset] = tmpGradOutput * (std::exp(
        static_cast<accscalar_t>(logits[offset]) - coeff) -
        static_cast<accscalar_t>(((offset - shift) == label) ? 1 : 0) *
        smooth_positives - smooth_negatives);
  }
}
|
|
|
|
// Backward kernel: one block per row.
//
//   gradInput       - out, gradient with respect to the logits (may alias
//                     `logits` for the in-place variant).
//   logits          - input logits; row `blockIdx.x` starts at
//                     blockIdx.x * classes.
//   max_log_sum_exp - per-row value saved by the forward kernel.
//   gradOutput      - per-row upstream gradient.
//   labels          - one target index per row.
//   smoothing       - label-smoothing factor.
//   classes         - number of entries in each row.
//   total_classes   - divisor for the smoothing term; values <= 0 mean "use
//                     `classes`".
//
// The row offset is computed in 64 bits. `blockIdx.x * classes` with an `int`
// `classes` is a 32-bit unsigned multiplication and wraps once
// rows * classes reaches 2^32, which sends the block to the wrong row.
template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t, template<typename, typename, typename> class Epilogue>
__global__ void
cunn_SoftMaxXEntropyBackward(
    scalar_t *gradInput,
    scalar_t *logits,
    outscalar_t *max_log_sum_exp,
    outscalar_t *gradOutput,
    int64_t *labels,
    const float smoothing,
    int classes,
    const int total_classes)
{
  const int64_t rowOffset = static_cast<int64_t>(blockIdx.x) * classes;
  gradInput += rowOffset;
  logits += rowOffset;

  const int smoothingClasses = total_classes <= 0 ? classes : total_classes;

  // Do vectorized load/store when input/output have same alignment
  const int shift = ((uint64_t)logits) % ALIGN_BYTES / sizeof(scalar_t);
  const int shift_ = ((uint64_t)gradInput) % ALIGN_BYTES / sizeof(scalar_t);
  if (shift == shift_) {
    aligned_apply<ILP, scalar_t, accscalar_t, outscalar_t>(
        shift, gradInput, logits, max_log_sum_exp, gradOutput, labels,
        smoothing, classes, smoothingClasses);
  }
  else {
    apply<ILP, scalar_t, accscalar_t, outscalar_t>(
        gradInput, logits, max_log_sum_exp, gradOutput, labels,
        smoothing, classes, smoothingClasses);
  }
}
|
|
|
|
// Host launcher for the fused softmax cross-entropy forward pass.
//
//   input_        - 2-D CUDA tensor of logits (float, half or bfloat16).
//   labels_       - 1-D CUDA int64 tensor with one target per row of `input_`.
//   smoothing     - label-smoothing factor.
//   total_classes - for tensor parallel cross entropy with smoothing, the total
//                   number of classes so that smoothing can be applied
//                   correctly. Values <= 0 mean "use the last dimension of the
//                   input tensor".
//
// Returns {losses, max_log_sum_exp}, both float tensors shaped like `labels_`.
// Raises c10::Error when an argument check fails or the launch reports an error.
//
// Changes relative to the previous version:
//   * argument checks use TORCH_CHECK (a user-facing error) instead of the
//     deprecated AT_ASSERTM, and run before any allocation;
//   * both tensors must be CUDA tensors, since raw pointers are handed to the
//     kernel;
//   * `labels_` is made contiguous before its data pointer is taken;
//   * the unreachable inner-size computation was removed (the input is
//     required to be 2-D, so the inner size is always 1).
template<template<typename, typename, typename> class Epilogue>
std::vector<Tensor> host_softmax_xentropy(
        const Tensor & input_,
        const Tensor & labels_,
        const float smoothing,
        const int total_classes) {
  TORCH_CHECK(input_.is_cuda() && labels_.is_cuda(), "Input and labels should be CUDA tensors");
  TORCH_CHECK(labels_.scalar_type() == ScalarType::Long, "Label type should be CUDA Long");
  TORCH_CHECK(input_.dim() == 2, "Currently only 2 dim input supported");
  TORCH_CHECK(labels_.dim() == 1, "Labels should be 1 dimensional");
  TORCH_CHECK(input_.size(0) == labels_.size(0), "Input and label should have same number of examples");
  TORCH_CHECK(input_.numel() > 0, "Number of classes in input should not be 0");

  // Otherwise the kernel will be launched from cuda:0 device
  at::cuda::CUDAGuard device_guard{input_.device()};

  // The kernel indexes both tensors as dense row-major buffers.
  auto input = input_.contiguous();
  auto labels = labels_.contiguous();
  Tensor max_log_sum_exp = at::empty_like(labels, input.options().dtype(ScalarType::Float));
  Tensor losses = at::empty_like(labels, input.options().dtype(ScalarType::Float));

  static_assert(std::is_same<acc_type<at::Half, true>, float>::value ||
    std::is_same<acc_type<at::Half, true>, double>::value,
    "accscalar_t for half should be float or double");

  const int64_t outer_size = input.size(0);
  const int64_t dim_size = input.size(1);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // This kernel spawns a block per each element in the batch.
  dim3 grid(outer_size);

  using namespace at;
  DISPATCH_FLOAT_AND_HALF_AND_BF16(input.scalar_type(), 0, "host_softmax_xentropy",
    using accscalar_t = at::acc_type<scalar_t_0, true>;
    const int ILP = sizeof(float4)/sizeof(scalar_t_0);
    dim3 block = SoftMax_getBlockSize(ILP, dim_size);
    // Shared memory: two accscalar_t slots per thread for the dual block reduction.
    cunn_SoftMaxXEntropyForward<ILP, scalar_t_0, accscalar_t, accscalar_t, Epilogue>
      <<<grid, block, 2 * block.x * sizeof(accscalar_t), stream>>>(
        losses.data_ptr<accscalar_t>(), max_log_sum_exp.data_ptr<accscalar_t>(),
        input.data_ptr<scalar_t_0>(), labels.data_ptr<int64_t>(),
        dim_size, smoothing, total_classes <= 0 ? dim_size : total_classes
    );
  );

  C10_CUDA_CHECK(cudaGetLastError());

  std::vector<at::Tensor> ret = {losses, max_log_sum_exp};
  return ret;
}
|
|
|
|
// Host launcher for the fused softmax cross-entropy backward pass.
//
//   grad_loss       - upstream gradient, one float per row (a 0-dim tensor is
//                     viewed as a single element).
//   logits_         - 2-D logits; overwritten with the gradient when `inplace`.
//   max_log_sum_exp - per-row value returned by the forward pass.
//   labels          - 1-D int64 targets, one per row.
//   smoothing       - label-smoothing factor.
//   inplace         - write the gradient into `logits_` instead of a new tensor.
//   total_classes   - divisor for the smoothing term; values <= 0 mean "use
//                     the last dimension of the logits" (resolved in the kernel).
//
// Returns the gradient with respect to the logits; when `inplace` is set the
// returned tensor is `logits_` itself. Raises c10::Error when an argument
// check fails or the launch reports an error.
//
// Changes relative to the previous version:
//   * the kernel always writes into a contiguous buffer. Before, a
//     non-contiguous `logits_` (in-place) or a strided `empty_like` result
//     was written through its raw data pointer as if it were row-major. For a
//     non-contiguous in-place call the result is now computed on a contiguous
//     copy and copied back into `logits_`;
//   * `labels` and `max_log_sum_exp` are made contiguous before their data
//     pointers are taken;
//   * argument checks use TORCH_CHECK instead of the deprecated AT_ASSERTM;
//   * the unreachable inner-size computation was removed (the logits are
//     required to be 2-D, so the inner size is always 1).
template<template<typename, typename, typename> class Epilogue>
Tensor host_softmax_xentropy_backward(
    const at::Tensor &grad_loss,
    at::Tensor &logits_,
    const at::Tensor &max_log_sum_exp,
    const at::Tensor &labels,
    const float smoothing,
    bool inplace,
    const int total_classes) {
  // Otherwise the kernel will be launched from cuda:0 device
  at::cuda::CUDAGuard device_guard{grad_loss.device()};

  // Nothing to differentiate: hand back the output tensor untouched.
  if (grad_loss.numel() == 0) {
    return inplace ? logits_ : at::empty_like(logits_);
  }

  auto grad = grad_loss.contiguous();

  static_assert(std::is_same<acc_type<at::Half, true>, float>::value ||
    std::is_same<acc_type<at::Half, true>, double>::value,
    "accscalar_t for half should be float or double");
  if (grad.dim() == 0) grad = grad.view(1);

  TORCH_CHECK(logits_.dim() == 2, "Currently only 2 dim input supported");
  TORCH_CHECK(labels.dim() == 1, "Labels should be 1 dimensional");
  TORCH_CHECK(logits_.numel() > 0, "Number of classes in input should not be 0");
  TORCH_CHECK(logits_.size(0) == labels.size(0), "Input and label should have same number of examples");
  TORCH_CHECK(labels.size(0) == grad.size(0), "Label and loss should have same number of examples");

  // The kernel indexes every tensor as a dense row-major buffer.
  auto logits = logits_.contiguous();
  auto labels_c = labels.contiguous();
  auto lse = max_log_sum_exp.contiguous();

  // In-place: overwrite the contiguous logits buffer, which is `logits_`
  // itself whenever `logits_` is already contiguous.
  Tensor gI = inplace ? logits : at::empty_like(logits);

  const int64_t dim = 1;
  const int64_t outer_size = logits.size(0);
  const int64_t dim_size = logits.size(dim);

  // See descriptions of kernels above.
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  dim3 grid(outer_size);

  DISPATCH_FLOAT_AND_HALF_AND_BF16(gI.scalar_type(), 0, "host_softmax_xentropy_backward",
    using accscalar_t = acc_type<scalar_t_0, true>;
    const int ILP = sizeof(float4)/sizeof(scalar_t_0);
    dim3 block = SoftMax_getBlockSize(ILP, dim_size);
    cunn_SoftMaxXEntropyBackward<ILP, scalar_t_0, accscalar_t, accscalar_t, Epilogue>
      <<<grid, block, block.x * sizeof(accscalar_t), stream>>>(
        gI.data_ptr<scalar_t_0>(), logits.data_ptr<scalar_t_0>(),
        lse.data_ptr<accscalar_t>(),
        grad.data_ptr<accscalar_t>(), labels_c.data_ptr<int64_t>(),
        smoothing, dim_size, total_classes
    );
  );

  C10_CUDA_CHECK(cudaGetLastError());

  if (inplace) {
    // A non-contiguous `logits_` was processed on a contiguous copy; write the
    // result back so the caller's tensor holds the gradient.
    if (!gI.is_same(logits_)) {
      logits_.copy_(gI);
    }
    return logits_;
  }
  return gI;
}
|
|
|
|
// Public entry point for the forward pass: forwards all arguments to
// host_softmax_xentropy instantiated with LogSoftMaxForwardEpilogue and
// returns its result.
std::vector<Tensor> softmax_xentropy_cuda(const Tensor &input, const Tensor &labels, const float smoothing, const int total_classes){
  std::vector<Tensor> outputs =
      host_softmax_xentropy<LogSoftMaxForwardEpilogue>(input, labels, smoothing, total_classes);
  return outputs;
}
|
|
|
|
// Public entry point for the backward pass.
//
// Verifies that `grad_loss` is a float tensor, then forwards all arguments to
// host_softmax_xentropy_backward instantiated with LogSoftMaxBackwardEpilogue
// and returns its result.
//
// The dtype check uses TORCH_CHECK rather than the deprecated AT_ASSERTM, so a
// wrong gradient dtype is reported as a regular argument error; the message
// text is unchanged.
at::Tensor softmax_xentropy_backward_cuda(
    const at::Tensor &grad_loss,
    at::Tensor &logits,
    const at::Tensor &max_log_sum_exp,
    const at::Tensor &labels,
    const float smoothing,
    const bool inplace,
    const int total_classes) {
  TORCH_CHECK((grad_loss.scalar_type() == ScalarType::Float), "expected grad types to be at::Float");
  return host_softmax_xentropy_backward<LogSoftMaxBackwardEpilogue>(
      grad_loss, logits, max_log_sum_exp, labels, smoothing, inplace, total_classes);
}
|