/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#pragma once

#include <cmath>

#include <cute/tensor.hpp>

#include <cutlass/numeric_types.h>

#include "philox.cuh"
#include "utils.h"

namespace flash {

using namespace cute;

////////////////////////////////////////////////////////////////////////////////////////////////////
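
// Note on the helpers below (descriptive, not from the original header): they operate on
// per-thread register fragments. A rank-2 (rows, cols) fp32 tensor is first reduced along
// its columns within the thread, then combined across the 4 threads of a quad via
// Allreduce<4>, so every thread that shares a row ends up holding the full row max / row sum.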
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ inline void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
    CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
    #pragma unroll
    for (int mi = 0; mi < size<0>(tensor); mi++) {
        summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
        #pragma unroll
        for (int ni = 1; ni < size<1>(tensor); ni++) {
            summary(mi) = op(summary(mi), tensor(mi, ni));
        }
    }
}

template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ inline void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
    CUTE_STATIC_ASSERT_V(size(dst) == size(src));
    #pragma unroll
    for (int i = 0; i < size(dst); i++) {
        dst(i) = Allreduce<4>::run(src(i), op);
    }
}

template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ inline void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
    thread_reduce_<zero_init>(tensor, summary, op);
    quad_allreduce_(summary, summary, op);
}

template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ inline void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max) {
    MaxOp<float> max_op;
    reduce_<zero_init>(tensor, max, max_op);
}

template<typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ inline void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum) {
    SumOp<float> sum_op;
    reduce_(tensor, sum, sum_op);
}
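
// Usage sketch (illustrative only; `acc_s`, `row_max`, `row_sum`, and `kNRows` are
// hypothetical names, not part of this header): given a per-thread (rows, cols) fp32
// fragment of attention scores, the helpers above produce one quad-reduced value per row.
//
//     Tensor row_max = make_tensor<float>(Shape<Int<kNRows>>{});
//     Tensor row_sum = make_tensor<float>(Shape<Int<kNRows>>{});
//     flash::reduce_max</*zero_init=*/true>(acc_s, row_max);  // row_max(mi) = max over ni of acc_s(mi, ni)
//     flash::reduce_sum(acc_s, row_sum);                      // row_sum(mi) = sum over ni of acc_s(mi, ni)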

// Apply the exp to all the elements.
template <bool Scale_max=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
inline __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
    CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
    #pragma unroll
    for (int mi = 0; mi < size<0>(tensor); ++mi) {
        // If max is -inf, then all elements must have been -inf (possibly due to masking).
        // We don't want (-inf - (-inf)) since that would give NaN.
        // If we don't have float around M_LOG2E the multiplication is done in fp64.
        const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E));
        #pragma unroll
        for (int ni = 0; ni < size<1>(tensor); ++ni) {
            // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - max * log_2(e)).
            // This allows the compiler to use the ffma instruction instead of fadd and fmul separately.
            tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
        }
    }
}
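
// Worked form of the exp2 rewrite above (editorial note): with row max m and softmax scale s,
//     exp(s * x - s * m) = exp2((s * log2(e)) * x - (s * log2(e)) * m),
// so when Scale_max is true the `scale` argument is expected to already carry the log2(e)
// factor (i.e. s * M_LOG2E); when Scale_max is false, `max` is assumed to be pre-scaled and
// only the M_LOG2E factor is applied to it here. Each element then costs a single ffma
// feeding exp2f.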

// Compute the row max, apply the scaled exp2 to all the elements, and compute the row sum.
template <bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
inline __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, Tensor<Engine1, Layout1> &sum, const float scale) {
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
    CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
    #pragma unroll
    for (int mi = 0; mi < size<0>(tensor); ++mi) {
        MaxOp<float> max_op;
        max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0));
        #pragma unroll
        for (int ni = 1; ni < size<1>(tensor); ni++) {
            max(mi) = max_op(max(mi), tensor(mi, ni));
        }
        max(mi) = Allreduce<4>::run(max(mi), max_op);
        // If max is -inf, then all elements must have been -inf (possibly due to masking).
        // We don't want (-inf - (-inf)) since that would give NaN.
        const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale;
        sum(mi) = 0;
        #pragma unroll
        for (int ni = 0; ni < size<1>(tensor); ++ni) {
            // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - max * log_2(e)).
            // This allows the compiler to use the ffma instruction instead of fadd and fmul separately.
            tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
            sum(mi) += tensor(mi, ni);
        }
        SumOp<float> sum_op;
        sum(mi) = Allreduce<4>::run(sum(mi), sum_op);
    }
}
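
// max_scale_exp2_sum fuses reduce_max, scale_apply_exp2, and reduce_sum into a single pass
// over the fragment. Illustrative call (the names `acc_s`, `row_max`, `row_sum`, and
// `scale_log2` are hypothetical, with `scale_log2` carrying the log2(e) factor):
//
//     flash::max_scale_exp2_sum</*zero_init=*/true>(acc_s, row_max, row_sum, scale_log2);
//     // Afterwards acc_s holds the unnormalized exp2-ed scores, row_max(mi) the
//     // quad-reduced row max, and row_sum(mi) the quad-reduced row sum.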
} // namespace flash