upstream internal updates (#616)

Co-authored-by: yuzhai <yuzhai@nvidia.com>
This commit is contained in:
Yujia Zhai 2022-09-04 20:05:09 -07:00 committed by GitHub
parent b72cbf957d
commit b1d3f9b2fd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 26 additions and 22 deletions

View File

@ -218,6 +218,12 @@ struct Testbed {
using OperatorClass = cutlass::arch::OpClassTensorOp; using OperatorClass = cutlass::arch::OpClassTensorOp;
using ArchTag = cutlass::arch::Sm80; using ArchTag = cutlass::arch::Sm80;
// ApplyShape impacts the final Softmax performance a lot.
// Set ApplyShape::kColumn to the next multiple of 32 after
// (gemm_N / alignment).
// Set ApplyShape::kRow to max(1, 128 / ApplyShape::kColumn).
using ApplyShape = cutlass::MatrixShape<1, 1024>;
static int const kStages = 3; static int const kStages = 3;
/// Linear scaling operator /// Linear scaling operator
@ -239,7 +245,8 @@ struct Testbed {
WarpShape, WarpShape,
InstructionShape, InstructionShape,
EpilogueFunctorOp, EpilogueFunctorOp,
kStages kStages,
ApplyShape
>; >;
using ElementNorm = typename GemmSoftmax::ElementNorm; using ElementNorm = typename GemmSoftmax::ElementNorm;
@ -710,6 +717,4 @@ int main(int argc, const char **argv) {
return (disposition == Disposition::kPassed ? 0 : -1); return (disposition == Disposition::kPassed ? 0 : -1);
} }
///////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -79,7 +79,7 @@ template <
typename ElementSoft_, typename ElementSoft_,
typename ElementSoftmaxCompute_, typename ElementSoftmaxCompute_,
int Alignment, int Alignment,
typename Shape_ = MatrixShape<4, 16> typename ApplyShape_ = MatrixShape<1, 1024>
> >
class ApplySoftmax { class ApplySoftmax {
public: public:
@ -91,7 +91,7 @@ public:
using ElementSoftmaxCompute = ElementSoftmaxCompute_; using ElementSoftmaxCompute = ElementSoftmaxCompute_;
static int const kAlignment = Alignment; static int const kAlignment = Alignment;
using Shape = Shape_; using ApplyShape = ApplyShape_;
using Layout = cutlass::layout::RowMajor; using Layout = cutlass::layout::RowMajor;
@ -202,7 +202,7 @@ private:
using AccessTypeD = AlignedArray<ElementD, kAlignment>; using AccessTypeD = AlignedArray<ElementD, kAlignment>;
int block_batch = blockIdx.z; int block_batch = blockIdx.z;
int block_m = blockIdx.x * Shape::kRow; int block_m = blockIdx.x * ApplyShape::kRow;
int block_n = 0; int block_n = 0;
int thread_m = threadIdx.y; int thread_m = threadIdx.y;
@ -256,8 +256,8 @@ private:
params.args.batch_stride_Soft * block_batch + params.args.batch_stride_Soft * block_batch +
params.args.ref_Soft.layout()({idx_m, idx_n})); params.args.ref_Soft.layout()({idx_m, idx_n}));
ElementSum inv_sum = (params.args.ref_S.data())[block_m + batch_offset_sum]; ElementSum inv_sum = (params.args.ref_S.data())[idx_m + batch_offset_sum];
ElementNorm norm = (params.args.ref_N.data())[block_m + batch_offset_norm]; ElementNorm norm = (params.args.ref_N.data())[idx_m + batch_offset_norm];
// //
// Loop // Loop
@ -266,10 +266,9 @@ private:
for ( for (
int idx = 0; int idx = 0;
idx < params.args.extent.column(); idx < params.args.extent.column();
idx += Shape::kColumn * kAlignment) { idx += ApplyShape::kColumn * kAlignment) {
if (idx_n < params.args.extent.column()) { if (idx_n < params.args.extent.column()) {
AccessTypeD fetch; AccessTypeD fetch;
arch::global_load<AccessTypeD, sizeof(AccessTypeD)>(fetch, access_d, true); arch::global_load<AccessTypeD, sizeof(AccessTypeD)>(fetch, access_d, true);
@ -279,14 +278,13 @@ private:
arch::global_store<FragmentSoft, sizeof(FragmentSoft)>(soft, access_soft, true); arch::global_store<FragmentSoft, sizeof(FragmentSoft)>(soft, access_soft, true);
} }
access_d += Shape::kColumn; access_d += ApplyShape::kColumn;
access_soft += Shape::kColumn; access_soft += ApplyShape::kColumn;
idx_n += Shape::kColumn * kAlignment; idx_n += ApplyShape::kColumn * kAlignment;
} }
} }
}; };
///////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel } // namespace kernel
@ -308,6 +306,7 @@ template <
typename InstructionShape_, typename InstructionShape_,
typename EpilogueFunctorOp_, typename EpilogueFunctorOp_,
int kStages_, int kStages_,
typename ApplyShape_ = MatrixShape<1, 1024>,
int AlignmentA_ = 128 / cutlass::sizeof_bits<ElementA_>::value, int AlignmentA_ = 128 / cutlass::sizeof_bits<ElementA_>::value,
int AlignmentB_ = 128 / cutlass::sizeof_bits<ElementB_>::value, int AlignmentB_ = 128 / cutlass::sizeof_bits<ElementB_>::value,
int AlignmentSoftmax_ = 128 / cutlass::sizeof_bits<ElementC_>::value, int AlignmentSoftmax_ = 128 / cutlass::sizeof_bits<ElementC_>::value,
@ -338,6 +337,8 @@ public:
using EpilogueFunctorOp = EpilogueFunctorOp_; using EpilogueFunctorOp = EpilogueFunctorOp_;
using ElementNorm = ElementNorm_; using ElementNorm = ElementNorm_;
using ApplyShape = ApplyShape_;
// These are mandatory layouts. // These are mandatory layouts.
using LayoutC = cutlass::layout::RowMajor; using LayoutC = cutlass::layout::RowMajor;
using LayoutN = cutlass::layout::RowMajor; using LayoutN = cutlass::layout::RowMajor;
@ -427,9 +428,7 @@ public:
ElementSoft, ElementSoft,
ElementSoftmaxCompute, ElementSoftmaxCompute,
AlignmentSoftmax, AlignmentSoftmax,
MatrixShape< ApplyShape
1, 1024
>
>; >;
using ApplyFinalReductionKernel = cutlass::reduction::kernel::ApplySoftmaxFinalReduction< using ApplyFinalReductionKernel = cutlass::reduction::kernel::ApplySoftmaxFinalReduction<
@ -616,14 +615,14 @@ public:
// Launch the SoftmaxApplyKernel // Launch the SoftmaxApplyKernel
// //
dim3 apply_block(SoftmaxApplyKernel::Shape::kColumn, SoftmaxApplyKernel::Shape::kRow); dim3 apply_block(SoftmaxApplyKernel::ApplyShape::kColumn, SoftmaxApplyKernel::ApplyShape::kRow);
int cta_rows = SoftmaxApplyKernel::Shape::kRow; int threadblock_rows = SoftmaxApplyKernel::ApplyShape::kRow;
int cta_columns = SoftmaxApplyKernel::Shape::kColumn * SoftmaxApplyKernel::kAlignment; int threadblock_columns = SoftmaxApplyKernel::ApplyShape::kColumn * SoftmaxApplyKernel::kAlignment;
dim3 apply_grid( dim3 apply_grid(
(params_.softmax.args.extent.row() + cta_rows - 1) / cta_rows, (params_.softmax.args.extent.row() + threadblock_rows - 1) / threadblock_rows,
(params_.softmax.args.extent.column() + cta_columns - 1) / cta_columns, (params_.softmax.args.extent.column() + threadblock_columns - 1) / threadblock_columns,
params_.softmax.args.batch_count); params_.softmax.args.batch_count);
Kernel<SoftmaxApplyKernel><<< Kernel<SoftmaxApplyKernel><<<