upstream internal updates (#616)
Co-authored-by: yuzhai <yuzhai@nvidia.com>
parent b72cbf957d
commit b1d3f9b2fd
@@ -218,6 +218,12 @@ struct Testbed {
   using OperatorClass = cutlass::arch::OpClassTensorOp;
   using ArchTag = cutlass::arch::Sm80;
 
+  // ApplyShape impacts the final Softmax performance a lot.
+  // Set ApplyShape::kColumn to the next multiple of 32 after
+  // (gemm_N / alignment).
+  // Set ApplyShape::kRow to max(1, 128 / ApplyShape::kColumn).
+  using ApplyShape = cutlass::MatrixShape<1, 1024>;
+
   static int const kStages = 3;
 
   /// Linear scaling operator
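As a rough worked instance of the sizing rule in the new comment (the helper, the values gemm_N = 4000 and alignment = 8, and the "at or after" reading of "next multiple of 32" are all assumptions of this sketch, not part of the commit):

    // Hypothetical helper encoding "next multiple of 32 after (gemm_N / alignment)",
    // read here as rounding up to a multiple of 32.
    constexpr int round_up_to_32(int x) { return ((x + 31) / 32) * 32; }

    // Example: gemm_N = 4000, alignment = 8  ->  4000 / 8 = 500
    constexpr int kApplyColumn = round_up_to_32(4000 / 8);                          // 512
    constexpr int kApplyRow    = (128 / kApplyColumn > 1) ? 128 / kApplyColumn : 1; // max(1, 0) = 1

    using ApplyShape = cutlass::MatrixShape<kApplyRow, kApplyColumn>;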
@@ -239,7 +245,8 @@ struct Testbed {
     WarpShape,
     InstructionShape,
     EpilogueFunctorOp,
-    kStages
+    kStages,
+    ApplyShape
   >;
 
   using ElementNorm = typename GemmSoftmax::ElementNorm;
@@ -710,6 +717,4 @@ int main(int argc, const char **argv) {
   return (disposition == Disposition::kPassed ? 0 : -1);
 }
 
-
 /////////////////////////////////////////////////////////////////////////////////////////////////
-
@@ -79,7 +79,7 @@ template <
   typename ElementSoft_,
   typename ElementSoftmaxCompute_,
   int Alignment,
-  typename Shape_ = MatrixShape<4, 16>
+  typename ApplyShape_ = MatrixShape<1, 1024>
 >
 class ApplySoftmax {
 public:
@@ -91,7 +91,7 @@ public:
   using ElementSoftmaxCompute = ElementSoftmaxCompute_;
 
   static int const kAlignment = Alignment;
-  using Shape = Shape_;
+  using ApplyShape = ApplyShape_;
 
   using Layout = cutlass::layout::RowMajor;
 
@@ -202,7 +202,7 @@ private:
     using AccessTypeD = AlignedArray<ElementD, kAlignment>;
 
     int block_batch = blockIdx.z;
-    int block_m = blockIdx.x * Shape::kRow;
+    int block_m = blockIdx.x * ApplyShape::kRow;
     int block_n = 0;
 
     int thread_m = threadIdx.y;
@@ -256,8 +256,8 @@ private:
       params.args.batch_stride_Soft * block_batch +
       params.args.ref_Soft.layout()({idx_m, idx_n}));
 
-    ElementSum inv_sum = (params.args.ref_S.data())[block_m + batch_offset_sum];
-    ElementNorm norm = (params.args.ref_N.data())[block_m + batch_offset_norm];
+    ElementSum inv_sum = (params.args.ref_S.data())[idx_m + batch_offset_sum];
+    ElementNorm norm = (params.args.ref_N.data())[idx_m + batch_offset_norm];
 
     //
     // Loop
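A note on the block_m -> idx_m change above: judging from the surrounding kernel code, idx_m = block_m + thread_m is a thread's global row, so once ApplyShape::kRow > 1 puts several rows in one threadblock, indexing the per-row sum/norm arrays with block_m would hand every row the statistics of the block's first row. A minimal sketch of the assumed indexing:

    int idx_m = block_m + thread_m;  // this thread's global row (assumed from context)

    // Per-row softmax statistics must be read at the thread's own row, not the
    // block's first row:
    ElementSum  inv_sum = (params.args.ref_S.data())[idx_m + batch_offset_sum];
    ElementNorm norm    = (params.args.ref_N.data())[idx_m + batch_offset_norm];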
@@ -266,10 +266,9 @@ private:
     for (
       int idx = 0;
       idx < params.args.extent.column();
-      idx += Shape::kColumn * kAlignment) {
-
+      idx += ApplyShape::kColumn * kAlignment) {
 
       if (idx_n < params.args.extent.column()) {
 
         AccessTypeD fetch;
         arch::global_load<AccessTypeD, sizeof(AccessTypeD)>(fetch, access_d, true);
@@ -279,14 +278,13 @@ private:
         arch::global_store<FragmentSoft, sizeof(FragmentSoft)>(soft, access_soft, true);
       }
 
-      access_d += Shape::kColumn;
-      access_soft += Shape::kColumn;
-      idx_n += Shape::kColumn * kAlignment;
+      access_d += ApplyShape::kColumn;
+      access_soft += ApplyShape::kColumn;
+      idx_n += ApplyShape::kColumn * kAlignment;
     }
   }
 };
 
-
 /////////////////////////////////////////////////////////////////////////////////////////////////
 
 } // namespace kernel
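Taken together, the two loop hunks above retile the apply loop by ApplyShape instead of the old Shape. A simplified scalar sketch of the access pattern (N, idx_n, and the pointer bookkeeping are condensed placeholders for the code shown):

    // Each iteration touches kAlignment contiguous elements per thread; a row of
    // ApplyShape::kColumn threads therefore covers ApplyShape::kColumn * kAlignment
    // columns per step and finishes a row of N columns in
    // ceil(N / (ApplyShape::kColumn * kAlignment)) iterations.
    for (int idx = 0; idx < N; idx += ApplyShape::kColumn * kAlignment) {
      if (idx_n < N) {
        // vectorized load of D, softmax scaling, vectorized store of Soft
      }
      access_d    += ApplyShape::kColumn;   // AccessTypeD spans kAlignment elements
      access_soft += ApplyShape::kColumn;
      idx_n       += ApplyShape::kColumn * kAlignment;
    }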
@@ -308,6 +306,7 @@ template <
   typename InstructionShape_,
   typename EpilogueFunctorOp_,
   int kStages_,
+  typename ApplyShape_ = MatrixShape<1, 1024>,
   int AlignmentA_ = 128 / cutlass::sizeof_bits<ElementA_>::value,
   int AlignmentB_ = 128 / cutlass::sizeof_bits<ElementB_>::value,
   int AlignmentSoftmax_ = 128 / cutlass::sizeof_bits<ElementC_>::value,
@@ -338,6 +337,8 @@ public:
   using EpilogueFunctorOp = EpilogueFunctorOp_;
   using ElementNorm = ElementNorm_;
 
+  using ApplyShape = ApplyShape_;
+
   // These are mandatory layouts.
   using LayoutC = cutlass::layout::RowMajor;
   using LayoutN = cutlass::layout::RowMajor;
@@ -427,9 +428,7 @@ public:
     ElementSoft,
     ElementSoftmaxCompute,
     AlignmentSoftmax,
-    MatrixShape<
-      1, 1024
-    >
+    ApplyShape
   >;
 
   using ApplyFinalReductionKernel = cutlass::reduction::kernel::ApplySoftmaxFinalReduction<
@@ -616,14 +615,14 @@ public:
     // Launch the SoftmaxApplyKernel
     //
 
-    dim3 apply_block(SoftmaxApplyKernel::Shape::kColumn, SoftmaxApplyKernel::Shape::kRow);
+    dim3 apply_block(SoftmaxApplyKernel::ApplyShape::kColumn, SoftmaxApplyKernel::ApplyShape::kRow);
 
-    int cta_rows = SoftmaxApplyKernel::Shape::kRow;
-    int cta_columns = SoftmaxApplyKernel::Shape::kColumn * SoftmaxApplyKernel::kAlignment;
+    int threadblock_rows = SoftmaxApplyKernel::ApplyShape::kRow;
+    int threadblock_columns = SoftmaxApplyKernel::ApplyShape::kColumn * SoftmaxApplyKernel::kAlignment;
 
     dim3 apply_grid(
-      (params_.softmax.args.extent.row() + cta_rows - 1) / cta_rows,
-      (params_.softmax.args.extent.column() + cta_columns - 1) / cta_columns,
+      (params_.softmax.args.extent.row() + threadblock_rows - 1) / threadblock_rows,
+      (params_.softmax.args.extent.column() + threadblock_columns - 1) / threadblock_columns,
       params_.softmax.args.batch_count);
 
     Kernel<SoftmaxApplyKernel><<<
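For concreteness, a hedged worked example of the renamed launch math, assuming the default ApplyShape = MatrixShape<1, 1024>, a hypothetical kAlignment of 8, and a 4096 x 8192 extent:

    dim3 apply_block(1024, 1);            // (ApplyShape::kColumn, ApplyShape::kRow)

    int threadblock_rows    = 1;          // ApplyShape::kRow
    int threadblock_columns = 1024 * 8;   // ApplyShape::kColumn * kAlignment = 8192

    dim3 apply_grid(
        (4096 + 1 - 1) / 1,               // 4096 threadblocks over rows
        (8192 + 8192 - 1) / 8192,         // 1 threadblock tile over columns
        batch_count);                     // batch count from the softmax arguments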