From 04a9777b87cc89f2f7056cfeff739ec1f4950da3 Mon Sep 17 00:00:00 2001
From: Yujia Zhai
Date: Fri, 1 Jul 2022 22:19:18 -0700
Subject: [PATCH] Softmax (#546)

* add test layernorm g-mem version

* Delete include/configure directory

* Delete examples/test_layernorm directory

* Update gemm_with_softmax.h

* Update gemm_softmax.cu

* Update linear_combination.h

* Update fast_math.h

* remove redundant vars

Co-authored-by: yujia.zhai
Co-authored-by: yuzhai
---
 examples/35_gemm_softmax/gemm_softmax.cu     |  62 +-
 examples/35_gemm_softmax/gemm_with_softmax.h | 733 +++++++++---------
 .../epilogue/thread/linear_combination.h     |   2 +-
 include/cutlass/fast_math.h                  |   7 +-
 4 files changed, 428 insertions(+), 376 deletions(-)

diff --git a/examples/35_gemm_softmax/gemm_softmax.cu b/examples/35_gemm_softmax/gemm_softmax.cu
index 7c32232d..0d18077e 100644
--- a/examples/35_gemm_softmax/gemm_softmax.cu
+++ b/examples/35_gemm_softmax/gemm_softmax.cu
@@ -55,6 +55,7 @@
 #include "cutlass/util/reference/host/error_metrics.h"
 #include "cutlass/util/tensor_view_io.h"
 
+#include "cutlass/epilogue/thread/linear_combination.h"
 
 /////////////////////////////////////////////////////////////////////////////////////////////////
 
 #include "gemm_with_softmax.h"
@@ -204,14 +205,24 @@ struct Testbed {
   using LayoutA = cutlass::layout::RowMajor;
   using LayoutB = cutlass::layout::ColumnMajor;
 
+  /// Linear scaling operator
+  using EpilogueFunctorOp = cutlass::epilogue::thread::LinearCombination<
+    ElementC,
+    128 / cutlass::sizeof_bits<ElementC>::value,
+    ElementCompute,
+    ElementCompute
+  >;
+
   using GemmSoftmax = cutlass::GemmSoftmax<
     ElementA, LayoutA,
     ElementB, LayoutB,
     ElementC,
-    ElementCompute
+    ElementCompute,
+    EpilogueFunctorOp
   >;
 
-  using ElementN = typename GemmSoftmax::ElementN;
+  using ElementNorm = typename GemmSoftmax::ElementNorm;
+  using ElementSum = typename GemmSoftmax::ElementSum;
   using LayoutC = typename GemmSoftmax::LayoutC;
 
   //
@@ -224,13 +235,16 @@ struct Testbed {
   cutlass::HostTensor<ElementB, LayoutB> tensor_B;
   cutlass::HostTensor<ElementC, LayoutC> tensor_C;
   cutlass::HostTensor<ElementC, LayoutC> tensor_D;
-  cutlass::HostTensor<ElementN, LayoutC> tensor_N;
+  cutlass::HostTensor<ElementNorm, LayoutC> tensor_N;
+  cutlass::HostTensor<ElementSum, LayoutC> tensor_S;
   cutlass::HostTensor<ElementSoftmax, LayoutC> tensor_Softmax;
 
   cutlass::HostTensor<ElementC, LayoutC> reference_D;
-  cutlass::HostTensor<ElementN, LayoutC> reference_N;
+  cutlass::HostTensor<ElementNorm, LayoutC> reference_N;
   cutlass::HostTensor<ElementSoftmax, LayoutC> reference_Softmax;
 
+  int block_num = (options.problem_size.n() + GemmSoftmax::ThreadblockShape::kN - 1) / GemmSoftmax::ThreadblockShape::kN;
+
   //
   // Methods
   //
@@ -247,7 +261,8 @@ struct Testbed {
     tensor_C.reset({options.problem_size.m(), options.problem_size.n()});
     tensor_D.reset({options.problem_size.m(), options.problem_size.n()});
 
-    tensor_N.reset({options.problem_size.m(), 1});
+    tensor_N.reset({block_num, options.problem_size.m()});
+    tensor_S.reset({block_num, options.problem_size.m()});
     tensor_Softmax.reset({options.problem_size.m(), options.problem_size.n()});
 
     reference_D.reset({options.problem_size.m(), options.problem_size.n()}, false);
@@ -342,7 +357,7 @@ struct Testbed {
 
     cutlass::reference::host::TensorFill(
       reference_N.host_view(),
-      ElementN()
+      ElementNorm()
     );
 
     cutlass::reference::host::TensorFill(
@@ -354,6 +369,7 @@ struct Testbed {
     tensor_B.sync_device();
     tensor_D.sync_device();
     tensor_N.sync_device();
+    tensor_S.sync_device();
     tensor_Softmax.sync_device();
   }
 
@@ -377,6 +393,7 @@ struct Testbed {
         ElementCompute(options.beta)
       },
       tensor_N.device_ref(),
+      tensor_S.device_ref(),
       tensor_Softmax.device_ref()
     );
 
@@ -420,7 +437,7 @@ struct Testbed {
 
     for (int m = 0; m < options.problem_size.m(); ++m) {
       reference_N.at({m, 0}) = reference_D.at({m, 0});
 
       for (int n = 1; n < options.problem_size.n(); ++n) {
-        reference_N.at({m, 0}) = std::max(reference_N.at({m, 0}), ElementN(reference_D.at({m, n})));
+        reference_N.at({m, 0}) = std::max(reference_N.at({m, 0}), ElementNorm(reference_D.at({m, n})));
       }
     }
 
@@ -454,6 +471,20 @@ struct Testbed {
       std::cout << "Reference Softmax = \n" << reference_Softmax.host_view() << "\n\n";
     }
 
+  bool verify_tensor_N(cutlass::HostTensor<ElementNorm, LayoutC> tensor_N, \
+                       cutlass::HostTensor<ElementNorm, LayoutC> reference_N) {
+
+    for (int m = 0; m < options.problem_size.m(); ++m) {
+      float diff = (float)(tensor_N.at({0, m}) - reference_N.at({m, 0}));
+      if (fabs(diff) > options.tolerance) {
+        return false;
+      }
+
+    }
+
+    return true;
+  }
+
   /// Verifies the reference matches
   bool verify() {
 
@@ -489,22 +520,7 @@ struct Testbed {
     }
 
     if (!verified_N) {
-
-      double norm_diff = cutlass::reference::host::TensorNormDiff(
-        tensor_N.host_view(),
-        reference_N.host_view());
-
-      double norm_reference = cutlass::reference::host::TensorNorm(
-        reference_N.host_view());
-
-      double rel_error = norm_diff / norm_reference;
-
-      if (rel_error > kThreshold) {
-        std::cerr << "\n\nTensor N Relative error: " << rel_error << std::endl;
-      }
-      else {
-        verified_N = true;
-      }
+      verified_N = verify_tensor_N(tensor_N, reference_N);
     }
 
     if (!verified_Softmax) {
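[Annotation, not part of the patch] The testbed changes above follow from the new two-pass scheme: tensor_N and tensor_S now hold one partial row maximum and one partial row sum per threadblock (shape {block_num, M}) instead of a single column ({M, 1}), and verify_tensor_N reads the final reduced value from block row 0. For reference, the quantity being verified is the standard numerically stable softmax; a minimal host-side sketch (hypothetical helper name):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // Reference softmax for one row: subtract the row max (the "N" tensor)
    // before exponentiating, then scale by the reciprocal row sum (the "S"
    // tensor), exactly the decomposition the kernels below produce.
    std::vector<float> softmax_row(std::vector<float> const &x) {
      float row_max = *std::max_element(x.begin(), x.end());
      float sum = 0.f;
      std::vector<float> y(x.size());
      for (size_t i = 0; i < x.size(); ++i) {
        y[i] = std::exp(x[i] - row_max);
        sum += y[i];
      }
      for (float &v : y) {
        v *= (1.f / sum);
      }
      return y;
    }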
diff --git a/examples/35_gemm_softmax/gemm_with_softmax.h b/examples/35_gemm_softmax/gemm_with_softmax.h
index 638017cc..213f8c5a 100644
--- a/examples/35_gemm_softmax/gemm_with_softmax.h
+++ b/examples/35_gemm_softmax/gemm_with_softmax.h
@@ -72,9 +72,10 @@ namespace kernel {
 //
 
 template <
   typename ElementD_,
-  typename ElementN_,
+  typename ElementNorm_,
   typename ElementSum_,
   typename ElementSoft_,
+  typename ElementSoftmaxCompute_,
   int Alignment,
   typename Shape_ = MatrixShape<4, 16>
 >
@@ -82,9 +83,10 @@ class ApplySoftmax {
 public:
 
   using ElementD = ElementD_;
-  using ElementN = ElementN_;
+  using ElementNorm = ElementNorm_;
   using ElementSum = ElementSum_;
   using ElementSoft = ElementSoft_;
+  using ElementSoftmaxCompute = ElementSoftmaxCompute_;
 
   static int const kAlignment = Alignment;
   using Shape = Shape_;
 
@@ -92,11 +94,11 @@ public:
   using Layout = cutlass::layout::RowMajor;
 
   using TensorRefD = TensorRef<ElementD, Layout>;
-  using TensorRefN = TensorRef<ElementN, Layout>;
+  using TensorRefN = TensorRef<ElementNorm, Layout>;
   using TensorRefSum = TensorRef<ElementSum, Layout>;
   using TensorRefSoft = TensorRef<ElementSoft, Layout>;
 
-  using FragmentSum = Array<ElementSum, kAlignment>;
+  using FragmentSoftmax = Array<ElementSoftmaxCompute, kAlignment>;
 
   //
   // Arguments
   //
@@ -108,9 +110,11 @@ public:
     int batch_count;                ///< Batch count
     TensorRefD ref_D;               ///< D matrix computed by GEMM+Max (input)
     TensorRefN ref_N;               ///< Norm tensor (input)
+    TensorRefSum ref_S;             ///< Sum tensor (input)
     TensorRefSoft ref_Soft;         ///< Softmax tensor (output)
     int64_t batch_stride_D;         ///< Batch stride for D tensor
     int64_t batch_stride_N;         ///< Batch stride for N tensor
+    int64_t batch_stride_S;         ///< Batch stride for S tensor
     int64_t batch_stride_Soft;      ///< Batch stride for softmax tensor
 
     //
@@ -120,6 +124,7 @@ public:
       batch_count(1),
       batch_stride_D(0),
       batch_stride_N(0),
+      batch_stride_S(0),
       batch_stride_Soft(0)
     { }
 
@@ -128,18 +133,22 @@ public:
       int batch_count_,             ///< Batch count
       TensorRefD ref_D_,            ///< D matrix computed by GEMM+PartialReduce
       TensorRefN ref_N_,            ///< Output parameter for N
+      TensorRefSum ref_S_,          ///< Output parameter for sum
       TensorRefSoft ref_Soft_,      ///< Softmax
       int64_t batch_stride_D_ = 0,
       int64_t batch_stride_N_ = 0,
+      int64_t batch_stride_S_ = 0,
       int64_t batch_stride_Soft_ = 0
     ):
       extent(extent_),
       batch_count(batch_count_),
       ref_D(ref_D_),
       ref_N(ref_N_),
+      ref_S(ref_S_),
       ref_Soft(ref_Soft_),
       batch_stride_D(batch_stride_D_),
       batch_stride_N(batch_stride_N_),
+      batch_stride_S(batch_stride_S_),
       batch_stride_Soft(batch_stride_Soft_)
     {
 
@@ -167,10 +176,6 @@ public:
 
   struct SharedStorage {
 
-    AlignedArray<ElementSum, Shape::kCount> exchange;
-    AlignedArray<ElementSum, Shape::kRow> inv_sum;
-    AlignedArray<ElementSum, Shape::kRow> norm;
-
   };
 
 private:
 
@@ -182,246 +187,11 @@ public:
 
   CUTLASS_DEVICE
   void operator()(Params const &params, SharedStorage &shared_storage) {
-
-    // Phase 1. Reduction over contiguous dimension
-    reduce_partial(params, shared_storage);
-
-    __syncthreads();
-
-    // Phase 2. Final reduction within SMEM - yields sum_n(exp(D - N))
-    reduce_final(params, shared_storage);
-
-    __syncthreads();
-
-    // Phase 3. Apply
     apply(params, shared_storage);
   }
 
 private:
 
-  /// Partial reduction
-  CUTLASS_DEVICE
-  void reduce_partial(Params const &params, SharedStorage &shared_storage) {
-
-    //
-    // Sum over the matrix
-    //
-    using AccessTypeD = AlignedArray<ElementD, kAlignment>;
-
-    int block_batch = blockIdx.z;
-    int block_m = blockIdx.x * Shape::kRow;
-    int block_n = 0;
-
-    int thread_m = threadIdx.y;
-    int thread_n = threadIdx.x * kAlignment;
-
-    int idx_m = block_m + thread_m;
-    int idx_n = block_n + thread_n;
-
-    AccessTypeD *access_d = reinterpret_cast<AccessTypeD *>(
-      params.args.ref_D.data() +
-      params.args.batch_stride_D * block_batch +
-      params.args.ref_D.layout()({idx_m, idx_n}));
-
-    using ConvertS = cutlass::NumericArrayConverter<ElementSum, ElementD, kAlignment>;
-
-    using Plus = cutlass::plus<FragmentSum>;
-    using Minus = cutlass::minus<FragmentSum>;
-    using Exp = cutlass::fast_exp_op<FragmentSum>;
-
-    ConvertS convert_s;
-    Minus minus;
-    Plus plus;
-    Exp exponential;
-
-    FragmentSum frag_Sum;
-    frag_Sum.clear();
-
-    if (idx_m < params.args.extent.row()) {
-
-      // Fetch the norm from GlobalMemory
-      ElementN norm = params.args.ref_N.data()[params.args.batch_stride_N * block_batch + idx_m];
-      ElementSum norm_cvt = ElementSum(norm);
-
-      FragmentSum norm_vec;
-
-      norm_vec.fill(norm_cvt);
-      shared_storage.norm[thread_m] = ElementSum(norm_cvt);
-
-      CUTLASS_PRAGMA_UNROLL
-      for (
-        int idx = 0;
-        idx < params.args.extent.column();
-        idx += Shape::kColumn * kAlignment) {
-
-        if (idx_n < params.args.extent.column()) {
-
-          AccessTypeD fetch;
-          arch::global_load<AccessTypeD, sizeof(AccessTypeD)>(fetch, access_d, true);
-
-          auto tmp = exponential(minus(convert_s(fetch), norm_vec));
-
-          frag_Sum = plus(frag_Sum, tmp);
-        }
-
-        access_d += Shape::kColumn;
-        idx_n += Shape::kColumn * kAlignment;
-      }
-
-      // Sum the elements owned by one thread
-      ElementSum sum = frag_Sum[0];
-
-      CUTLASS_PRAGMA_UNROLL
-      for (int i = 1; i < kAlignment; ++i) {
-        sum += frag_Sum[i];
-      }
-
-      shared_storage.exchange.data()[threadIdx.x + threadIdx.y * Shape::kColumn] = sum;
-    }
-  }
-
-  /// Compute the final summation from data in SMEM
-  CUTLASS_DEVICE
-  void reduce_final(Params const &params, SharedStorage &shared_storage) {
-
-    //
-    // SMEM has shape `Shape::Row`-by-`Shape::Column`
-    //
-    // This computes a reduction across the `Column` dimension yielding a `Row-by-1` vector.
-    //
-
-    #if true
-    //
-    // Tuning parameters tradeoff ILP with latency
-    //
-    // During each step of the reduction, each thread performs `kAccesses` of vector size `kReduceVector`
-
-    // Tune the number of accesses per reduction
-    int const kAccesses = 2;
-
-    // Tune the memory access size
-    int const kReduceVector = 4;
-
-    //
-    // Static asserts to ensure integrity
-    //
-
-    static_assert(kAccesses * kReduceVector,
-      "Zero-size steps would infinitely loop.");
-
-    static_assert(
-      is_pow2<Shape::kColumn>::value &&
-      is_pow2<kAccesses>::value &&
-      is_pow2<kReduceVector>::value,
-      "Powers of two only.");
-
-    static_assert(!(Shape::kColumn % (kAccesses * kReduceVector)),
-      "Divisibility not satisfied");
-
-    //
-    // Reduction operators
-    //
-
-    using FragmentSum = Array<ElementSum, kReduceVector>;
-    using Plus = cutlass::plus<FragmentSum>;
-
-    Plus plus;
-
-    // Tree reduction
-    ElementSum *smem_ptr = shared_storage.exchange.data() + threadIdx.y * Shape::kColumn;
-
-    ElementSum final = ElementSum();
-
-    CUTLASS_PRAGMA_UNROLL
-    for (
-      int tidx_limit = Shape::kColumn / (kAccesses * kReduceVector);
-      tidx_limit > 0;
-      tidx_limit /= (kAccesses * kReduceVector)) {
-
-      if (threadIdx.x < tidx_limit) {
-        FragmentSum fetch;
-
-        arch::shared_load(
-          &fetch,
-          arch::cutlass_get_smem_pointer(smem_ptr + threadIdx.x * kReduceVector));
-
-        CUTLASS_PRAGMA_UNROLL
-        for (int i = 1; i < kAccesses; ++i) {
-          FragmentSum extra;
-
-          arch::shared_load(
-            &extra,
-            arch::cutlass_get_smem_pointer(
-              smem_ptr + threadIdx.x * kReduceVector + tidx_limit * kReduceVector * i));
-
-          fetch = plus(fetch, extra);
-        }
-
-        // Reduce to one element
-        final = fetch[0];
-
-        CUTLASS_PRAGMA_UNROLL
-        for (int i = 1; i < kReduceVector; ++i) {
-          final += fetch[i];
-        }
-      }
-      __syncthreads();
-
-      if (threadIdx.x < tidx_limit) {
-        smem_ptr[threadIdx.x] = final;
-      }
-      __syncthreads();
-    }
-
-    if (threadIdx.x == 0) {
-
-      int const kLgResidual =
-        (log2_down<Shape::kColumn>::value % log2_down<kAccesses * kReduceVector>::value);
-
-      // Certain shape combinations require an additional reduction step
-      if (kLgResidual) {
-        final = ElementSum();
-
-        int const kResidualVector = (1 << kLgResidual);
-        Array<ElementSum, kResidualVector> fetch;
-
-        arch::shared_load(
-          &fetch,
-          arch::cutlass_get_smem_pointer(smem_ptr));
-
-        CUTLASS_PRAGMA_UNROLL
-        for (int i = 0; i < kResidualVector; ++i) {
-          final += fetch[i];
-        }
-      }
-
-      // compute inverse
-      ElementSum inv_sum = cutlass::constants::one<ElementSum>() / final;
-
-      // Store to shared memory
-      shared_storage.inv_sum[threadIdx.y] = inv_sum;
-    }
-
-    #else
-
-    // Primitive serial reduction
-    if (threadIdx.x < Shape::kRow && threadIdx.y == 0) {
-      ElementSum *smem_ptr = shared_storage.exchange.data() + threadIdx.x * Shape::kColumn;
-
-      ElementSum sum = smem_ptr[0];
-
-      CUTLASS_PRAGMA_UNROLL
-      for (int n = 1; n < Shape::kColumn; ++n) {
-        sum += smem_ptr[n];
-      }
-
-      // compute inverse
-      ElementSum inv_sum = cutlass::constants::one<ElementSum>() / sum;
-
-      // Store to shared memory
-      shared_storage.inv_sum[threadIdx.x] = inv_sum;
-    }
-    #endif
-  }
 
   /// Compute Softmax
   CUTLASS_DEVICE
   void apply(Params const &params, SharedStorage &shared_storage) {
@@ -451,20 +221,26 @@ private:
 
     using AccessTypeD = AlignedArray<ElementD, kAlignment>;
     using AccessTypeSoft = AlignedArray<ElementSoft, kAlignment>;
     using FragmentSoft = Array<ElementSoft, kAlignment>;
-    using ConvertSum = cutlass::NumericArrayConverter<ElementSum, ElementD, kAlignment>;
-    using ConvertSoft = cutlass::NumericArrayConverter<ElementSoft, ElementSum, kAlignment>;
+    using ConvertSoftCompute = cutlass::NumericArrayConverter<ElementSoftmaxCompute, ElementD, kAlignment>;
+    using ConvertSoftOutput = cutlass::NumericArrayConverter<ElementSoft, ElementSoftmaxCompute, kAlignment>;
 
-    using Mul = cutlass::multiplies<FragmentSum>;
-    using Minus = cutlass::minus<FragmentSum>;
-    using Exp = cutlass::fast_exp_op<FragmentSum>;
+    using Mul = cutlass::multiplies<FragmentSoftmax>;
+    using Minus = cutlass::minus<FragmentSoftmax>;
+    using Exp = cutlass::fast_exp_op<FragmentSoftmax>;
 
-    ConvertSum convert_sum;
-    ConvertSoft convert_soft;
+    ConvertSoftCompute convert_soft_compute;
+    ConvertSoftOutput convert_soft_output;
 
     Minus minus;
     Mul mul;
     Exp exponential;
 
+    using ConvertSum = cutlass::NumericConverter<ElementSoftmaxCompute, ElementSum>;
+    using ConvertNorm = cutlass::NumericConverter<ElementSoftmaxCompute, ElementNorm>;
+
+    ConvertSum convert_sum;
+    ConvertNorm convert_norm;
+
     AccessTypeD *access_d = reinterpret_cast<AccessTypeD *>(
       params.args.ref_D.data() +
       params.args.batch_stride_D * block_batch +
       params.args.ref_D.layout()({idx_m, idx_n}));
 
@@ -475,11 +251,8 @@ private:
       params.args.batch_stride_Soft * block_batch +
       params.args.ref_Soft.layout()({idx_m, idx_n}));
 
-    // Fetch inv_sum from SharedMemory
-    ElementSum inv_sum = shared_storage.inv_sum[thread_m];
-
-    // Fetch the norm from SharedMemory
-    ElementSum norm = shared_storage.norm[thread_m];
+    ElementSum inv_sum = (params.args.ref_S.data())[block_m];
+    ElementNorm norm = (params.args.ref_N.data())[block_m];
 
     //
     // Loop
@@ -495,8 +268,8 @@ private:
 
         AccessTypeD fetch;
         arch::global_load<AccessTypeD, sizeof(AccessTypeD)>(fetch, access_d, true);
 
-        FragmentSum result = mul(exponential(minus(convert_sum(fetch), norm)), inv_sum);
-        FragmentSoft soft = convert_soft(result);
+        FragmentSoftmax result = mul(exponential(minus(convert_soft_compute(fetch), convert_norm(norm))), convert_sum(inv_sum));
+        FragmentSoft soft = convert_soft_output(result);
 
         arch::global_store<FragmentSoft, sizeof(FragmentSoft)>(soft, access_soft, true);
       }
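[Annotation, not part of the patch] ApplySoftmax now reads a finished row maximum and a reciprocal row sum from global memory; producing them from the per-threadblock partials is the job of the ApplyFinalReduction kernel added below. The merge rule it implements: the row max is the max of the partial maxima, and each partial sum is rescaled by exp(partial_max - final_max) before summing. A minimal host-model sketch (hypothetical names):

    #include <cmath>

    // k partials per row: part_max[i] = max over the columns owned by
    // threadblock i, part_sum[i] = sum of exp(d - part_max[i]) over them.
    void merge_partials(float const *part_max, float const *part_sum, int k,
                        float &row_max, float &row_inv_sum) {
      row_max = part_max[0];
      for (int i = 1; i < k; ++i) {
        row_max = std::fmax(row_max, part_max[i]);
      }
      float sum = 0.f;
      for (int i = 0; i < k; ++i) {
        sum += part_sum[i] * std::exp(part_max[i] - row_max);  // rescale
      }
      row_inv_sum = 1.f / sum;  // the kernel stores the reciprocal
    }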
@@ -508,6 +281,173 @@ private:
   }
 };
 
+template <
+  typename ElementNorm_,
+  typename ElementSum_,
+  typename ElementSoftmaxCompute_,
+  typename ThreadblockShape_
+>
+class ApplyFinalReduction {
+public:
+
+  using ElementNorm = ElementNorm_;
+  using ElementSum = ElementSum_;
+  using ElementSoftmaxCompute = ElementSoftmaxCompute_;
+  using ThreadblockShape = ThreadblockShape_;
+
+  using Layout = cutlass::layout::RowMajor;
+
+  using TensorRefN = TensorRef<ElementNorm, Layout>;
+  using TensorRefSum = TensorRef<ElementSum, Layout>;
+
+  //
+  // Arguments
+  //
+
+  struct Arguments {
+
+    MatrixCoord extent;             ///< Extent of D and Softmax matrices
+    int batch_count;                ///< Batch count
+    TensorRefN ref_N;               ///< Norm tensor (input / output)
+    TensorRefSum ref_Sum;           ///< Sum tensor (input / output)
+    int64_t batch_stride_N;         ///< Batch stride for N tensor
+    int64_t batch_stride_Sum;       ///< Batch stride for sum tensor
+
+    //
+    // Methods
+    //
+    Arguments():
+      batch_count(1),
+      batch_stride_N(0),
+      batch_stride_Sum(0)
+    { }
+
+    Arguments(
+      MatrixCoord extent_,          ///< Extent of D and Softmax matrices
+      int batch_count_,             ///< Batch count
+      TensorRefN ref_N_,            ///< Output parameter for N
+      TensorRefSum ref_Sum_,        ///< Output parameter for sum
+      int64_t batch_stride_N_ = 0,
+      int64_t batch_stride_Sum_ = 0
+    ):
+      extent(extent_),
+      batch_count(batch_count_),
+      ref_N(ref_N_),
+      ref_Sum(ref_Sum_),
+      batch_stride_N(batch_stride_N_),
+      batch_stride_Sum(batch_stride_Sum_)
+    {
+
+    }
+  };
+
+  struct SharedStorage {
+
+  };
+
+  //
+  // Params struct
+  //
+
+  struct Params {
+    Arguments args;
+
+    //
+    // Methods
+    //
+    Params() { }
+
+    Params(Arguments const &args_): args(args_) { }
+  };
+
+private:
+
+public:
+
+  CUTLASS_DEVICE
+  ApplyFinalReduction() { }
+
+  CUTLASS_DEVICE
+  void operator()(Params const &params, SharedStorage &shared_storage) {
+
+    apply(params, shared_storage);
+  }
+
+private:
+
+  /// Partial reduction
+  CUTLASS_DEVICE
+  void apply(Params const &params, SharedStorage &shared_storage) {
+
+    int threadblock_num = (params.args.extent.column() + ThreadblockShape::kN - 1) / ThreadblockShape::kN;
+
+    int block_batch = blockIdx.z;
+
+    int block_n = blockIdx.x * blockDim.x;
+
+    int thread_n = threadIdx.x;
+
+    int idx_n = block_n + thread_n;
+
+    if (idx_n >= params.args.extent.row()) {
+      return;
+    }
+
+    using ConvertSumOutput = cutlass::NumericConverter<ElementSum, ElementSoftmaxCompute>;
+    using ConvertNormOutput = cutlass::NumericConverter<ElementNorm, ElementSoftmaxCompute>;
+
+    using ConvertSum = cutlass::NumericConverter<ElementSoftmaxCompute, ElementSum>;
+    using ConvertNorm = cutlass::NumericConverter<ElementSoftmaxCompute, ElementNorm>;
+
+    ConvertSum convert_sum;
+    ConvertNorm convert_norm;
+
+    ConvertSumOutput convert_sum_output;
+    ConvertNormOutput convert_norm_output;
+
+    ElementNorm *access_n = params.args.ref_N.data() + params.args.batch_stride_N * block_batch + idx_n;
+    ElementSum *access_s = params.args.ref_Sum.data() + params.args.batch_stride_Sum * block_batch + idx_n;
+
+    ElementNorm *access_n_bak = access_n;
+    ElementSum *access_s_bak = access_s;
+
+    uint32_t float_max_bits = 0xff7fffff;
+    float min_float = reinterpret_cast<float const &>(float_max_bits);
+
+    ElementSoftmaxCompute max_val = ElementSoftmaxCompute(min_float);
+    ElementSoftmaxCompute sum_val = ElementSoftmaxCompute(0);
+    ElementNorm fetch_n;
+    ElementSum fetch_s;
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int idx_m = 0; idx_m < threadblock_num; idx_m++) {
+      arch::global_load<ElementNorm, sizeof(ElementNorm)>(fetch_n, access_n, true);
+      max_val = fast_max(max_val, convert_norm(fetch_n));
+      access_n += params.args.extent.row();
+    }
+
+    access_n = access_n_bak;
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int idx_m = 0; idx_m < threadblock_num; idx_m++) {
+      arch::global_load<ElementNorm, sizeof(ElementNorm)>(fetch_n, access_n, true);
+      arch::global_load<ElementSum, sizeof(ElementSum)>(fetch_s, access_s, true);
+      sum_val += convert_sum(fetch_s) * fast_exp(convert_norm(fetch_n) - max_val);
+      access_n += params.args.extent.row();
+      access_s += params.args.extent.row();
+    }
+
+    ElementSoftmaxCompute inv_sum = cutlass::constants::one<ElementSoftmaxCompute>() / sum_val;
+
+    access_n = access_n_bak;
+    access_s = access_s_bak;
+
+    access_n[0] = convert_norm_output(max_val);
+    access_s[0] = convert_sum_output(inv_sum);
+  }
+};
 
 /////////////////////////////////////////////////////////////////////////////////////////////////
 
@@ -516,6 +456,9 @@ template <
   int ThreadCount,
   typename OutputTileIterator_,
   typename ElementAccumulator_,
+  typename ElementNorm_,
+  typename ElementSum_,
+  typename ElementSoftmaxCompute_,
   typename ElementwiseFunctor_
 >
 class EpilogueVisitorBiasMax {
@@ -532,10 +475,14 @@ public:
   using ElementOutput = typename OutputTileIterator::Element;
   using LayoutOutput = cutlass::layout::RowMajor;
-
   using ElementAccumulator = ElementAccumulator_;
+  using ElementNorm = ElementNorm_;
+  using ElementSum = ElementSum_;
+  using ElementSoftmaxCompute = ElementSoftmaxCompute_;
 
   using AccumulatorFragment = Array<ElementAccumulator, kElementsPerAccess>;
+  using SoftmaxFragment = Array<ElementSoftmaxCompute, kElementsPerAccess>;
   using OutputVector = Array<ElementOutput, kElementsPerAccess>;
   using TensorRefD = TensorRef<ElementOutput, LayoutOutput>;
 
@@ -545,19 +492,23 @@ public:
     typename ElementwiseFunctor::Params elementwise;
     TensorRefD ref_C;
     TensorRefD ref_D;
-    float *ptr_Max;
+    ElementNorm *ptr_Max;
+    ElementSum *ptr_Sum;
     int64_t batch_stride_C;
     int64_t batch_stride_D;
     int64_t batch_stride_Max;
+    int64_t batch_stride_Sum;
 
     //
     // Methods
     //
     Arguments():
       ptr_Max(nullptr),
+      ptr_Sum(nullptr),
      batch_stride_C(0),
       batch_stride_D(0),
-      batch_stride_Max(0)
+      batch_stride_Max(0),
+      batch_stride_Sum(0)
     {
 
     }
 
@@ -566,18 +517,22 @@ public:
       typename ElementwiseFunctor::Params elementwise_,
       TensorRefD ref_C_,
       TensorRefD ref_D_,
-      float *ptr_Max_,
+      ElementNorm *ptr_Max_,
+      ElementSum *ptr_Sum_,
       int64_t batch_stride_C_,
       int64_t batch_stride_D_,
-      int64_t batch_stride_Max_
+      int64_t batch_stride_Max_,
+      int64_t batch_stride_Sum_
     ):
       elementwise(elementwise_),
       ref_C(ref_C_),
       ref_D(ref_D_),
       ptr_Max(ptr_Max_),
+      ptr_Sum(ptr_Sum_),
       batch_stride_C(batch_stride_C_),
       batch_stride_D(batch_stride_D_),
-      batch_stride_Max(batch_stride_Max_)
+      batch_stride_Max(batch_stride_Max_),
+      batch_stride_Sum(batch_stride_Sum_)
     {
 
     }
   };
 
@@ -590,10 +545,12 @@ public:
 
     typename OutputTileIterator::Params params_D;
     typename OutputTileIterator::Element *ptr_C;
     typename OutputTileIterator::Element *ptr_D;
-    float *ptr_Max;
+    ElementNorm *ptr_Max;
+    ElementSum *ptr_Sum;
     int64_t batch_stride_C;
     int64_t batch_stride_D;
     int64_t batch_stride_Max;
+    int64_t batch_stride_Sum;
 
     //
     // Methods
@@ -601,7 +558,8 @@ public:
     CUTLASS_HOST_DEVICE
     Params():
       ptr_D(nullptr),
-      ptr_Max(nullptr)
+      ptr_Max(nullptr),
+      ptr_Sum(nullptr)
     {
 
     }
 
@@ -614,9 +572,11 @@ public:
       ptr_C(args.ref_C.data()),
       ptr_D(args.ref_D.data()),
       ptr_Max(args.ptr_Max),
+      ptr_Sum(args.ptr_Sum),
       batch_stride_C(args.batch_stride_C),
       batch_stride_D(args.batch_stride_D),
-      batch_stride_Max(args.batch_stride_Max)
+      batch_stride_Max(args.batch_stride_Max),
+      batch_stride_Sum(args.batch_stride_Sum)
     {
 
     }
@@ -624,7 +584,7 @@ public:
 
   /// Shared storage
   struct SharedStorage {
-    float reduction[ThreadblockShape::kM];
+
   };
 
 private:
 
@@ -642,7 +602,7 @@ private:
 
   ElementAccumulator alpha_;
   ElementAccumulator beta_;
-  ElementAccumulator accum_max_;
+  ElementSoftmaxCompute accum_max_;
   int threadblock_row_;
 
 public:
 
@@ -692,14 +652,6 @@ public:
   CUTLASS_DEVICE
   void begin_epilogue() {
 
-    int const kStoreCount = (ThreadblockShape::kM + kThreadCount - 1) / kThreadCount;
-
-    clear_accum_max_();
-
-    CUTLASS_PRAGMA_UNROLL
-    for (int i = 0; i < kStoreCount; ++i) {
-      shared_storage_.reduction[i * kThreadCount + threadIdx.x] = accum_max_;
-    }
   }
 
   /// Called at the start of one step before starting accumulator exchange
@@ -708,8 +660,11 @@ public:
     fragment_D_.clear();
     fragment_C_.clear();
 
-    iterator_C_.load(fragment_C_);
-    ++iterator_C_;
+    if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) {
+      iterator_C_.load(fragment_C_);
+      ++iterator_C_;
+    }
+
   }
 
   /// Called at the start of a row
@@ -726,52 +681,87 @@ public:
     int frag_idx,
     AccumulatorFragment const &accum) {
 
-    NumericArrayConverter<ElementAccumulator, ElementOutput, kElementsPerAccess> source_converter;
+    using Mul = cutlass::multiplies<SoftmaxFragment>;
+    using Minus = cutlass::minus<SoftmaxFragment>;
+    using Exp = cutlass::fast_exp_op<SoftmaxFragment>;
+
+    Minus minus;
+    Exp exponential;
+
+    SoftmaxFragment result;
+
+    using ConvertSumOutput = cutlass::NumericConverter<ElementSum, ElementSoftmaxCompute>;
+    using ConvertNormOutput = cutlass::NumericConverter<ElementNorm, ElementSoftmaxCompute>;
+
+    ConvertSumOutput convert_sum_output;
+    ConvertNormOutput convert_norm_output;
+
+    NumericArrayConverter<ElementSoftmaxCompute, ElementOutput, kElementsPerAccess> source_converter;
     OutputVector &source_vector = reinterpret_cast<OutputVector *>(&fragment_C_)[frag_idx];
 
-    AccumulatorFragment source = source_converter(source_vector);
-
-    AccumulatorFragment result = alpha_ * accum + beta_ * source;
+    if (elementwise_.kScale == cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) {
+      result = source_converter(elementwise_(accum));
+    }else{
+      result = source_converter(elementwise_(accum, source_vector));
+    }
 
     MatrixCoord thread_offset =
       iterator_D_.thread_start() +
       OutputTileIterator::ThreadMap::iteration_offset(frag_idx);
 
+    int thread_in_row = OutputTileIterator::ThreadMap::Detail::RowArrangement::Detail::kShapeWidth;
+    int half_thread_in_row = (thread_in_row >> 1);
+
     bool column_guard = (thread_offset.column() < extent_.column());
 
     // Compute the maximum within one row
     if (!column_idx) {
-
       // This is the first fragment in a new row
       if (column_guard) {
-        accum_max_ = maximum_accumulator_(accum);
+        accum_max_ = maximum_accumulator_(result);
       }
     }
    else {
-
       // This is an additional fragment in the same row
       if (column_guard) {
-        accum_max_ = maximum_accumulator_(accum, accum_max_);
+        accum_max_ = maximum_accumulator_(result, accum_max_);
      }
     }
 
-    // If this is the last vector in the row, compute the final max and store it out
-    if (column_idx + 1 == OutputTileIterator::ThreadMap::Iterations::kColumn) {
-
-      float float_max_element = float(accum_max_);
-
-      int thread_row = thread_offset.row() - threadblock_row_;
-
-      // Shared memory atomic operation to partially reduce the maximum element
-      atomicMax(
-        reinterpret_cast<int *>(shared_storage_.reduction + thread_row),
-        reinterpret_cast<int const &>(float_max_element)
-      );
-
-      clear_accum_max_();
+    CUTLASS_PRAGMA_UNROLL
+    for (int i = half_thread_in_row; i > 0; i >>= 1) {
+      ElementSoftmaxCompute tmp = __shfl_xor_sync(0xFFFFFFFF, accum_max_, i);
+      accum_max_ = fast_max(accum_max_, tmp);
     }
 
+    SoftmaxFragment sum_frag = exponential(minus(result, accum_max_));
+
+    ElementSoftmaxCompute reduction_sum = sum_accumulator_(sum_frag);
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int i = half_thread_in_row; i > 0; i >>= 1) {
+      ElementSoftmaxCompute tmp = __shfl_xor_sync(0xFFFFFFFF, reduction_sum, i);
+      reduction_sum += tmp;
+    }
+
+    bool is_write_thread = (thread_offset.row() < extent_.row() && (threadIdx.x % thread_in_row) == 0);
+
+    ElementNorm *curr_ptr_max = params_.ptr_Max + thread_offset.row() + blockIdx.y * extent_.row();
+    ElementSum *curr_ptr_sum = params_.ptr_Sum + thread_offset.row() + blockIdx.y * extent_.row();
+
+    arch::global_store<ElementNorm, sizeof(ElementNorm)>(
+              convert_norm_output(accum_max_),
+              (void *)curr_ptr_max,
+              is_write_thread);
+
+    arch::global_store<ElementSum, sizeof(ElementSum)>(
+              convert_sum_output(reduction_sum),
+              (void *)curr_ptr_sum,
+              is_write_thread);
+
+    clear_accum_max_();
+
     // Convert to the output
-    NumericArrayConverter<ElementOutput, ElementAccumulator, kElementsPerAccess> output_converter;
+    NumericArrayConverter<ElementOutput, ElementSoftmaxCompute, kElementsPerAccess> output_converter;
     OutputVector &output = reinterpret_cast<OutputVector *>(&fragment_D_)[frag_idx];
     output = output_converter(result);
   }
@@ -794,23 +784,6 @@ public:
   CUTLASS_DEVICE
   void end_epilogue() {
 
-    __syncthreads();
-
-    int block_batch = blockIdx.z;
-
-    int tidx_m = threadblock_row_ + threadIdx.x;
-
-    float float_max_element = shared_storage_.reduction[threadIdx.x];
-
-    if (tidx_m < extent_.row()) {
-
-      atomicMax(
-        reinterpret_cast<int *>(
-          params_.ptr_Max +
-          params_.batch_stride_Max * block_batch +
-          tidx_m),
-        reinterpret_cast<int const &>(float_max_element)
-      );
-    }
   }
 
 private:
 
@@ -819,28 +792,40 @@ private:
   CUTLASS_DEVICE
   void clear_accum_max_() {
 
     uint32_t float_max_bits = 0xff7fffff;      // -FLT_MAX
-
-    accum_max_ = reinterpret_cast<float const &>(float_max_bits);
+    float min_float = reinterpret_cast<float const &>(float_max_bits);
+    accum_max_ = ElementSoftmaxCompute(min_float);
   }
 
   CUTLASS_DEVICE
-  float maximum_accumulator_(AccumulatorFragment const &accum) {
-    ElementAccumulator max_ = accum[0];
+  ElementSoftmaxCompute sum_accumulator_(SoftmaxFragment const &accum) {
+    ElementSoftmaxCompute sum_ = ElementSoftmaxCompute(0);
 
     CUTLASS_PRAGMA_UNROLL
-    for (int i = 1; i < AccumulatorFragment::kElements; ++i) {
-      max_ = fast_max(max_, ElementAccumulator(accum[i]));
+    for (int i = 0; i < SoftmaxFragment::kElements; ++i) {
+      sum_ += ElementSoftmaxCompute(accum[i]);
+    }
+
+    return sum_;
+  }
+
+  CUTLASS_DEVICE
+  ElementSoftmaxCompute maximum_accumulator_(SoftmaxFragment const &accum) {
+    ElementSoftmaxCompute max_ = accum[0];
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int i = 1; i < SoftmaxFragment::kElements; ++i) {
+      max_ = fast_max(max_, ElementSoftmaxCompute(accum[i]));
     }
 
     return max_;
   }
 
   CUTLASS_DEVICE
-  ElementAccumulator maximum_accumulator_(AccumulatorFragment const &accum, ElementAccumulator max_) {
+  ElementSoftmaxCompute maximum_accumulator_(SoftmaxFragment const &accum, ElementSoftmaxCompute max_) {
 
     CUTLASS_PRAGMA_UNROLL
-    for (int i = 0; i < AccumulatorFragment::kElements; ++i) {
-      max_ = fast_max(max_, ElementAccumulator(accum[i]));
+    for (int i = 0; i < SoftmaxFragment::kElements; ++i) {
+      max_ = fast_max(max_, ElementSoftmaxCompute(accum[i]));
     }
 
     return max_;
   }
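[Annotation, not part of the patch] The visit() rewrite above replaces the shared-memory atomicMax reduction with intra-row warp shuffles: the lanes covering one output row exchange values in a butterfly pattern until every lane holds the row-wide result, and one lane per row writes it out. A standalone CUDA sketch of that pattern (hypothetical helper, assuming a power-of-two group of width lanes per row, as the kShapeWidth-based loop does):

    __device__ float group_reduce_max(float v, int width) {
      // Butterfly reduction over the lanes of one row group.
      for (int i = (width >> 1); i > 0; i >>= 1) {
        v = fmaxf(v, __shfl_xor_sync(0xFFFFFFFF, v, i));
      }
      return v;  // every lane in the group now holds the maximum
    }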
@@ -861,8 +846,10 @@ template <
   typename LayoutB_,
   typename ElementC_,
   typename ElementCompute_,
+  typename EpilogueFunctorOp_,
+  typename ElementNorm_ = float,
+  typename ElementSum_ = float,
   int Alignment = 128 / cutlass::sizeof_bits<ElementC_>::value,
-  typename ElementSum_ = ElementCompute_,
   typename ElementSoftmax_ = ElementC_
 >
 class GemmSoftmax {
@@ -880,36 +867,27 @@ public:
   using ElementCompute = ElementCompute_;
   using ElementSum = ElementSum_;
   using ElementSoft = ElementSoftmax_;
+  using ElementSoftmaxCompute = float;
 
   using LayoutA = LayoutA_;
   using LayoutB = LayoutB_;
 
   static int const kAlignment = Alignment;
 
-  ///////////////////////////////////////////////////////////////////////////////////////////////
-
-  /// Linear scaling operator
-  using EpilogueFunctorOp = cutlass::epilogue::thread::LinearCombination<
-    ElementC,
-    128 / cutlass::sizeof_bits<ElementC>::value,
-    ElementCompute,
-    ElementCompute
-  >;
-
-  ///////////////////////////////////////////////////////////////////////////////////////////////
-
-  // This is a mandatory data type for the atomic reduction in the GEMM epilogue to function.
-  using ElementN = float;
+  using EpilogueFunctorOp = EpilogueFunctorOp_;
+  using ElementNorm = ElementNorm_;
 
   // These are mandatory layouts.
   using LayoutC = cutlass::layout::RowMajor;
   using LayoutN = cutlass::layout::RowMajor;
+  using LayoutS = cutlass::layout::RowMajor;
   using LayoutSoft = cutlass::layout::RowMajor;
 
   using TensorRefA = TensorRef<ElementA, LayoutA>;
   using TensorRefB = TensorRef<ElementB, LayoutB>;
   using TensorRefC = TensorRef<ElementC, LayoutC>;
-  using TensorRefN = TensorRef<ElementN, LayoutN>;
+  using TensorRefN = TensorRef<ElementNorm, LayoutN>;
+  using TensorRefSum = TensorRef<ElementSum, LayoutS>;
   using TensorRefSoft = TensorRef<ElementSoft, LayoutSoft>;
 
   using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>;
@@ -957,6 +935,9 @@ public:
     DefaultGemmKernel::kThreadCount,
     typename DefaultGemmKernel::Epilogue::OutputTileIterator,
     ElementCompute,
+    ElementNorm,
+    ElementSum,
+    ElementSoftmaxCompute,
     EpilogueFunctorOp
   >;
 
@@ -976,15 +957,23 @@ public:
   // Softmax kernel
   using SoftmaxApplyKernel = kernel::ApplySoftmax<
     ElementC,
-    ElementN,
+    ElementNorm,
     ElementSum,
     ElementSoft,
+    ElementSoftmaxCompute,
     kAlignment,
     MatrixShape<
       1, 1024
     >
   >;
 
+  using ApplyFinalReductionKernel = kernel::ApplyFinalReduction<
+    ElementNorm,
+    ElementSum,
+    ElementSoftmaxCompute,
+    ThreadblockShape
+  >;
+
 public:
 
   /// Arguments class
@@ -992,7 +981,8 @@ public:
 
     typename GemmKernel::Arguments gemm;
     typename SoftmaxApplyKernel::Arguments softmax;
-
+    typename ApplyFinalReductionKernel::Arguments reduction;
+    cutlass::gemm::GemmCoord extend;
 
     //
     // Methods
     //
@@ -1007,12 +997,14 @@ public:
      TensorRefC ref_D_,
       typename EpilogueFunctorOp::Params linear_scaling,
       TensorRefN ref_N_,
+      TensorRefSum ref_S_,
       TensorRefSoft ref_Softmax_,
       int64_t batch_stride_A_ = 0,
       int64_t batch_stride_B_ = 0,
       int64_t batch_stride_C_ = 0,
       int64_t batch_stride_D_ = 0,
       int64_t batch_stride_Max_ = 0,
+      int64_t batch_stride_Sum_ = 0,
       int64_t batch_stride_Softmax_ = 0
     ):
       gemm(
@@ -1028,21 +1020,34 @@ public:
           ref_C_,
          ref_D_,
           ref_N_.data(),
+          ref_S_.data(),
           batch_stride_C_,
           batch_stride_D_,
-          batch_stride_Max_
+          batch_stride_Max_,
+          batch_stride_Sum_
        )
       ),
+      reduction(
+        MatrixCoord(problem_size.m(), problem_size.n()),
+        batch_count_,
+        ref_N_,
+        ref_S_,
+        batch_stride_Max_,
+        batch_stride_Sum_
+      ),
       softmax(
         MatrixCoord(problem_size.m(), problem_size.n()),
        batch_count_,
         ref_D_,
         ref_N_,
+        ref_S_,
         ref_Softmax_,
         batch_stride_D_,
         batch_stride_Max_,
+        batch_stride_Sum_,
         batch_stride_Softmax_
-      )
+      ),
+      extend(problem_size)
     {
 
     }
@@ -1052,7 +1057,8 @@ public:
 
     typename GemmKernel::Params gemm;
     typename SoftmaxApplyKernel::Params softmax;
-
+    typename ApplyFinalReductionKernel::Params reduction;
+    MatrixCoord extend;
 
     //
     // Methods
     //
@@ -1060,7 +1066,9 @@ public:
 
     Params(Arguments const &args):
       gemm(args.gemm),
-      softmax(args.softmax)
+      reduction(args.reduction),
+      softmax(args.softmax),
+      extend(MatrixCoord(args.extend.m(), args.extend.n()))
     {
 
     }
@@ -1114,6 +1122,35 @@ public:
       return cutlass::Status::kErrorInternal;
     }
 
+    //
+    // Launch the ApplyFinalReductionKernel
+    //
+
+    int threadblock_num_in_column = (params_.extend.column() + ThreadblockShape::kN - 1) / ThreadblockShape::kN;
+
+    if (threadblock_num_in_column > 1) {
+      int thread_per_block = 128;
+      int block_per_row = (params_.extend.row() + thread_per_block - 1) / thread_per_block;
+      if (block_per_row < 4) {
+        thread_per_block = 32;
+        block_per_row = (params_.extend.row() + thread_per_block - 1) / thread_per_block;
+      }
+
+      dim3 final_reduction_grid(block_per_row);
+      dim3 final_reduction_block(thread_per_block);
+
+      Kernel<ApplyFinalReductionKernel><<<
+        final_reduction_grid, final_reduction_block, sizeof(typename ApplyFinalReductionKernel::SharedStorage), stream
+      >>>(params_.reduction);
+
+      result = cudaGetLastError();
+
+      if (result != cudaSuccess) {
+        return cutlass::Status::kErrorInternal;
+      }
+    }
+
     //
     // Launch the SoftmaxApplyKernel
     //
diff --git a/include/cutlass/epilogue/thread/linear_combination.h b/include/cutlass/epilogue/thread/linear_combination.h
index 004b3a7c..2b083d71 100644
--- a/include/cutlass/epilogue/thread/linear_combination.h
+++ b/include/cutlass/epilogue/thread/linear_combination.h
@@ -71,7 +71,7 @@ public:
   using ElementCompute = ElementCompute_;
 
   static int const kCount = Count;
-
+  static const ScaleType::Kind kScale = Scale;
   using FragmentOutput = Array<ElementOutput, kCount>;
   using FragmentAccumulator = Array<ElementAccumulator, kCount>;
   using ComputeFragment = Array<ElementCompute, kCount>;
diff --git a/include/cutlass/fast_math.h b/include/cutlass/fast_math.h
index 14b6522c..dd6c0406 100644
--- a/include/cutlass/fast_math.h
+++ b/include/cutlass/fast_math.h
@@ -718,11 +718,11 @@ double fast_exp(double x) {
 }
 
 CUTLASS_HOST_DEVICE
-float fast_exp(half_t x) {
+half_t fast_exp(half_t x) {
   #if defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 10) && (__CUDA_ARCH__ >= 750)
-  return ::hexp(x.to_half());
+  return (half_t)(::hexp(x.to_half()));
   #else
-  return fast_exp(float(x));
+  return (half_t)(fast_exp(float(x)));
   #endif
 }
 
@@ -908,4 +908,3 @@ T absolute_value(T x) {
 } // namespace cutlass
 
 /////////////////////////////////////////////////////////////////////////////////////////////////
-
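[Annotation, not part of the patch] End to end, the patched example instantiates and runs the fused operator roughly as follows. This is a condensed sketch of the testbed code in this patch; the tensor allocations, problem_size, and alpha/beta are assumed, and the initialize()/run() call sequence follows the example's usage:

    using EpilogueFunctorOp = cutlass::epilogue::thread::LinearCombination<
      cutlass::half_t,
      128 / cutlass::sizeof_bits<cutlass::half_t>::value,
      float,
      float
    >;

    using GemmSoftmax = cutlass::GemmSoftmax<
      cutlass::half_t, cutlass::layout::RowMajor,     // A
      cutlass::half_t, cutlass::layout::ColumnMajor,  // B
      cutlass::half_t,                                // C / D
      float,                                          // ElementCompute
      EpilogueFunctorOp
    >;

    GemmSoftmax::Arguments args(
      problem_size,
      /*batch_count=*/1,
      tensor_A.device_ref(),
      tensor_B.device_ref(),
      tensor_C.device_ref(),
      tensor_D.device_ref(),
      {alpha, beta},
      tensor_N.device_ref(),        // per-block row maxima, shape {block_num, M}
      tensor_S.device_ref(),        // per-block row sums,   shape {block_num, M}
      tensor_Softmax.device_ref()
    );

    GemmSoftmax gemm_softmax;
    gemm_softmax.initialize(args);
    // Running the operator launches, in order: the GEMM with the max/sum
    // epilogue visitor, ApplyFinalReduction (only when the row spans more
    // than one threadblock column), and ApplySoftmax.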