diff --git a/examples/35_gemm_softmax/gemm_softmax.cu b/examples/35_gemm_softmax/gemm_softmax.cu
index 55b87b3c..afa8c329 100644
--- a/examples/35_gemm_softmax/gemm_softmax.cu
+++ b/examples/35_gemm_softmax/gemm_softmax.cu
@@ -218,6 +218,12 @@ struct Testbed {
   using OperatorClass = cutlass::arch::OpClassTensorOp;
   using ArchTag = cutlass::arch::Sm80;
 
+  // ApplyShape strongly impacts the performance of the final softmax kernel.
+  // Set ApplyShape::kColumn to the next multiple of 32 after
+  // (gemm_N / alignment).
+  // Set ApplyShape::kRow to max(1, 128 / ApplyShape::kColumn).
+  using ApplyShape = cutlass::MatrixShape<1, 1024>;
+
   static int const kStages = 3;
 
   /// Linear scaling operator
@@ -239,7 +245,8 @@ struct Testbed {
     WarpShape,
     InstructionShape,
     EpilogueFunctorOp,
-    kStages
+    kStages,
+    ApplyShape
   >;
 
   using ElementNorm = typename GemmSoftmax::ElementNorm;
@@ -710,6 +717,4 @@ int main(int argc, const char **argv) {
 
   return (disposition == Disposition::kPassed ? 0 : -1);
 }
 
-/////////////////////////////////////////////////////////////////////////////////////////////////
-
diff --git a/examples/35_gemm_softmax/gemm_with_softmax.h b/examples/35_gemm_softmax/gemm_with_softmax.h
index 37f7a746..ed3caadd 100644
--- a/examples/35_gemm_softmax/gemm_with_softmax.h
+++ b/examples/35_gemm_softmax/gemm_with_softmax.h
@@ -79,7 +79,7 @@ template <
   typename ElementSoft_,
   typename ElementSoftmaxCompute_,
   int Alignment,
-  typename Shape_ = MatrixShape<4, 16>
+  typename ApplyShape_ = MatrixShape<1, 1024>
 >
 class ApplySoftmax {
 public:
@@ -91,7 +91,7 @@
   using ElementSoftmaxCompute = ElementSoftmaxCompute_;
 
   static int const kAlignment = Alignment;
-  using Shape = Shape_;
+  using ApplyShape = ApplyShape_;
 
   using Layout = cutlass::layout::RowMajor;
 
@@ -202,7 +202,7 @@ private:
     using AccessTypeD = AlignedArray<ElementD, kAlignment>;
 
     int block_batch = blockIdx.z;
-    int block_m = blockIdx.x * Shape::kRow;
+    int block_m = blockIdx.x * ApplyShape::kRow;
     int block_n = 0;
 
     int thread_m = threadIdx.y;
@@ -256,8 +256,8 @@ private:
       params.args.batch_stride_Soft * block_batch +
       params.args.ref_Soft.layout()({idx_m, idx_n}));
 
-    ElementSum inv_sum = (params.args.ref_S.data())[block_m + batch_offset_sum];
-    ElementNorm norm = (params.args.ref_N.data())[block_m + batch_offset_norm];
+    ElementSum inv_sum = (params.args.ref_S.data())[idx_m + batch_offset_sum];
+    ElementNorm norm = (params.args.ref_N.data())[idx_m + batch_offset_norm];
 
     //
     // Loop
@@ -266,10 +266,9 @@ private:
     for (
      int idx = 0;
      idx < params.args.extent.column();
-     idx += Shape::kColumn * kAlignment) {
+     idx += ApplyShape::kColumn * kAlignment) {
 
       if (idx_n < params.args.extent.column()) {
-
         AccessTypeD fetch;
 
         arch::global_load<AccessTypeD, sizeof(AccessTypeD)>(fetch, access_d, true);
@@ -279,14 +278,13 @@ private:
 
         arch::global_store<AccessTypeSoft, sizeof(AccessTypeSoft)>(soft, access_soft, true);
       }
 
-      access_d += Shape::kColumn;
-      access_soft += Shape::kColumn;
-      idx_n += Shape::kColumn * kAlignment;
+      access_d += ApplyShape::kColumn;
+      access_soft += ApplyShape::kColumn;
+      idx_n += ApplyShape::kColumn * kAlignment;
     }
   }
 };
 
-/////////////////////////////////////////////////////////////////////////////////////////////////
 
 } // namespace kernel
@@ -308,6 +306,7 @@ template <
   typename InstructionShape_,
   typename EpilogueFunctorOp_,
   int kStages_,
+  typename ApplyShape_ = MatrixShape<1, 1024>,
   int AlignmentA_ = 128 / cutlass::sizeof_bits<ElementA_>::value,
   int AlignmentB_ = 128 / cutlass::sizeof_bits<ElementB_>::value,
   int AlignmentSoftmax_ = 128 / cutlass::sizeof_bits<ElementC_>::value,
@@ -338,6 +337,8 @@ public:
   using EpilogueFunctorOp = EpilogueFunctorOp_;
   using ElementNorm = ElementNorm_;
 
+  using ApplyShape = ApplyShape_;
+
   // These are mandatory layouts.
   using LayoutC = cutlass::layout::RowMajor;
   using LayoutN = cutlass::layout::RowMajor;
@@ -427,9 +428,7 @@ public:
     ElementSoft,
     ElementSoftmaxCompute,
     AlignmentSoftmax,
-    MatrixShape<
-      1, 1024
-    >
+    ApplyShape
   >;
 
   using ApplyFinalReductionKernel = cutlass::reduction::kernel::ApplySoftmaxFinalReduction<
@@ -616,14 +615,14 @@ public:
     // Launch the SoftmaxApplyKernel
     //
 
-    dim3 apply_block(SoftmaxApplyKernel::Shape::kColumn, SoftmaxApplyKernel::Shape::kRow);
+    dim3 apply_block(SoftmaxApplyKernel::ApplyShape::kColumn, SoftmaxApplyKernel::ApplyShape::kRow);
 
-    int cta_rows = SoftmaxApplyKernel::Shape::kRow;
-    int cta_columns = SoftmaxApplyKernel::Shape::kColumn * SoftmaxApplyKernel::kAlignment;
+    int threadblock_rows = SoftmaxApplyKernel::ApplyShape::kRow;
+    int threadblock_columns = SoftmaxApplyKernel::ApplyShape::kColumn * SoftmaxApplyKernel::kAlignment;
 
     dim3 apply_grid(
-      (params_.softmax.args.extent.row() + cta_rows - 1) / cta_rows,
-      (params_.softmax.args.extent.column() + cta_columns - 1) / cta_columns,
+      (params_.softmax.args.extent.row() + threadblock_rows - 1) / threadblock_rows,
+      (params_.softmax.args.extent.column() + threadblock_columns - 1) / threadblock_columns,
       params_.softmax.args.batch_count);
 
     Kernel<<<
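
The ApplyShape heuristic from the first hunk can be followed mechanically. The sketch below is not part of the patch: kGemmN and the half_t output element are illustrative assumptions, chosen to show the arithmetic. For a very wide row (e.g. gemm_N = 8192 at alignment 8) the same rule yields the patch's default MatrixShape<1, 1024>.

#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"

// Hypothetical problem size and output element type -- assumptions for
// illustration, not values taken from the patch.
static int const kGemmN     = 256;
static int const kAlignment = 128 / cutlass::sizeof_bits<cutlass::half_t>::value;  // = 8

// kColumn: round (gemm_N / alignment) up to the next multiple of 32.
static int const kElemsPerRow = (kGemmN + kAlignment - 1) / kAlignment;            // = 32
static int const kApplyColumn = ((kElemsPerRow + 31) / 32) * 32;                   // = 32

// kRow: max(1, 128 / kColumn), keeping roughly 128 threads per threadblock.
static int const kApplyRow = (128 / kApplyColumn > 1) ? (128 / kApplyColumn) : 1;  // = 4

// Per the launch code above, apply_block becomes dim3(32, 4) = 128 threads,
// and each threadblock covers kApplyColumn * kAlignment = 256 columns.
using ApplyShape = cutlass::MatrixShape<kApplyRow, kApplyColumn>;                  // MatrixShape<4, 32>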
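
For completeness, a hedged usage sketch of the new template parameter: the hunks above show only that ApplyShape slots in immediately after kStages, so the leading element/layout/tile aliases below are assumed to match the example's Testbed and should be read as placeholders.

// Usage sketch -- the aliases before kStages are assumptions, not shown in
// this patch.
using GemmSoftmax = cutlass::GemmSoftmax<
  ElementA, LayoutA,
  ElementB, LayoutB,
  ElementC,
  ElementCompute,
  OperatorClass,        // cutlass::arch::OpClassTensorOp
  ArchTag,              // cutlass::arch::Sm80
  ThreadblockShape,
  WarpShape,
  InstructionShape,
  EpilogueFunctorOp,
  kStages,              // 3 in the Testbed
  ApplyShape            // new parameter; defaults to MatrixShape<1, 1024>
>;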