upstream internal updates (#616)

Co-authored-by: yuzhai <yuzhai@nvidia.com>
This commit is contained in:
Yujia Zhai 2022-09-04 20:05:09 -07:00 committed by GitHub
parent b72cbf957d
commit b1d3f9b2fd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 26 additions and 22 deletions

View File

@@ -218,6 +218,12 @@ struct Testbed {
using OperatorClass = cutlass::arch::OpClassTensorOp;
using ArchTag = cutlass::arch::Sm80;
// ApplyShape impacts the final Softmax performance a lot.
// Set ApplyShape::kColumn to the first multiple of 32 greater than
// (gemm_N / alignment).
// Set ApplyShape::kRow to max(1, 128 / ApplyShape::kColumn).
using ApplyShape = cutlass::MatrixShape<1, 1024>;
static int const kStages = 3;
/// Linear scaling operator
@@ -239,7 +245,8 @@ struct Testbed {
WarpShape,
InstructionShape,
EpilogueFunctorOp,
kStages
kStages,
ApplyShape
>;
using ElementNorm = typename GemmSoftmax::ElementNorm;
@@ -710,6 +717,4 @@ int main(int argc, const char **argv) {
return (disposition == Disposition::kPassed ? 0 : -1);
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -79,7 +79,7 @@ template <
typename ElementSoft_,
typename ElementSoftmaxCompute_,
int Alignment,
typename Shape_ = MatrixShape<4, 16>
typename ApplyShape_ = MatrixShape<1, 1024>
>
class ApplySoftmax {
public:
@@ -91,7 +91,7 @@ public:
using ElementSoftmaxCompute = ElementSoftmaxCompute_;
static int const kAlignment = Alignment;
using Shape = Shape_;
using ApplyShape = ApplyShape_;
using Layout = cutlass::layout::RowMajor;
@@ -202,7 +202,7 @@ private:
using AccessTypeD = AlignedArray<ElementD, kAlignment>;
int block_batch = blockIdx.z;
int block_m = blockIdx.x * Shape::kRow;
int block_m = blockIdx.x * ApplyShape::kRow;
int block_n = 0;
int thread_m = threadIdx.y;
@@ -256,8 +256,8 @@ private:
params.args.batch_stride_Soft * block_batch +
params.args.ref_Soft.layout()({idx_m, idx_n}));
ElementSum inv_sum = (params.args.ref_S.data())[block_m + batch_offset_sum];
ElementNorm norm = (params.args.ref_N.data())[block_m + batch_offset_norm];
ElementSum inv_sum = (params.args.ref_S.data())[idx_m + batch_offset_sum];
ElementNorm norm = (params.args.ref_N.data())[idx_m + batch_offset_norm];
//
// Loop
@@ -266,10 +266,9 @@ private:
for (
int idx = 0;
idx < params.args.extent.column();
idx += Shape::kColumn * kAlignment) {
idx += ApplyShape::kColumn * kAlignment) {
if (idx_n < params.args.extent.column()) {
AccessTypeD fetch;
arch::global_load<AccessTypeD, sizeof(AccessTypeD)>(fetch, access_d, true);
@@ -279,14 +278,13 @@ private:
arch::global_store<FragmentSoft, sizeof(FragmentSoft)>(soft, access_soft, true);
}
access_d += Shape::kColumn;
access_soft += Shape::kColumn;
idx_n += Shape::kColumn * kAlignment;
access_d += ApplyShape::kColumn;
access_soft += ApplyShape::kColumn;
idx_n += ApplyShape::kColumn * kAlignment;
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
@@ -308,6 +306,7 @@ template <
typename InstructionShape_,
typename EpilogueFunctorOp_,
int kStages_,
typename ApplyShape_ = MatrixShape<1, 1024>,
int AlignmentA_ = 128 / cutlass::sizeof_bits<ElementA_>::value,
int AlignmentB_ = 128 / cutlass::sizeof_bits<ElementB_>::value,
int AlignmentSoftmax_ = 128 / cutlass::sizeof_bits<ElementC_>::value,
@@ -338,6 +337,8 @@ public:
using EpilogueFunctorOp = EpilogueFunctorOp_;
using ElementNorm = ElementNorm_;
using ApplyShape = ApplyShape_;
// These are mandatory layouts.
using LayoutC = cutlass::layout::RowMajor;
using LayoutN = cutlass::layout::RowMajor;
@@ -427,9 +428,7 @@ public:
ElementSoft,
ElementSoftmaxCompute,
AlignmentSoftmax,
MatrixShape<
1, 1024
>
ApplyShape
>;
using ApplyFinalReductionKernel = cutlass::reduction::kernel::ApplySoftmaxFinalReduction<
@@ -616,14 +615,14 @@ public:
// Launch the SoftmaxApplyKernel
//
dim3 apply_block(SoftmaxApplyKernel::Shape::kColumn, SoftmaxApplyKernel::Shape::kRow);
dim3 apply_block(SoftmaxApplyKernel::ApplyShape::kColumn, SoftmaxApplyKernel::ApplyShape::kRow);
int cta_rows = SoftmaxApplyKernel::Shape::kRow;
int cta_columns = SoftmaxApplyKernel::Shape::kColumn * SoftmaxApplyKernel::kAlignment;
int threadblock_rows = SoftmaxApplyKernel::ApplyShape::kRow;
int threadblock_columns = SoftmaxApplyKernel::ApplyShape::kColumn * SoftmaxApplyKernel::kAlignment;
dim3 apply_grid(
(params_.softmax.args.extent.row() + cta_rows - 1) / cta_rows,
(params_.softmax.args.extent.column() + cta_columns - 1) / cta_columns,
(params_.softmax.args.extent.row() + threadblock_rows - 1) / threadblock_rows,
(params_.softmax.args.extent.column() + threadblock_columns - 1) / threadblock_columns,
params_.softmax.args.batch_count);
Kernel<SoftmaxApplyKernel><<<