upstream internal updates (#616)
Co-authored-by: yuzhai <yuzhai@nvidia.com>
This commit is contained in:
parent
b72cbf957d
commit
b1d3f9b2fd
@ -218,6 +218,12 @@ struct Testbed {
|
|||||||
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||||
using ArchTag = cutlass::arch::Sm80;
|
using ArchTag = cutlass::arch::Sm80;
|
||||||
|
|
||||||
|
// ApplyShape impacts the final Softmax performance a lot.
|
||||||
|
// Set ApplyShape::kColumn to the next multiple of 32 that comes after
|
||||||
|
// (gemm_N / alignment).
|
||||||
|
// Set ApplyShape::kRow to max(1, 128 / ApplyShape::kColumn).
|
||||||
|
using ApplyShape = cutlass::MatrixShape<1, 1024>;
|
||||||
|
|
||||||
static int const kStages = 3;
|
static int const kStages = 3;
|
||||||
|
|
||||||
/// Linear scaling operator
|
/// Linear scaling operator
|
||||||
@ -239,7 +245,8 @@ struct Testbed {
|
|||||||
WarpShape,
|
WarpShape,
|
||||||
InstructionShape,
|
InstructionShape,
|
||||||
EpilogueFunctorOp,
|
EpilogueFunctorOp,
|
||||||
kStages
|
kStages,
|
||||||
|
ApplyShape
|
||||||
>;
|
>;
|
||||||
|
|
||||||
using ElementNorm = typename GemmSoftmax::ElementNorm;
|
using ElementNorm = typename GemmSoftmax::ElementNorm;
|
||||||
@ -710,6 +717,4 @@ int main(int argc, const char **argv) {
|
|||||||
return (disposition == Disposition::kPassed ? 0 : -1);
|
return (disposition == Disposition::kPassed ? 0 : -1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
@ -79,7 +79,7 @@ template <
|
|||||||
typename ElementSoft_,
|
typename ElementSoft_,
|
||||||
typename ElementSoftmaxCompute_,
|
typename ElementSoftmaxCompute_,
|
||||||
int Alignment,
|
int Alignment,
|
||||||
typename Shape_ = MatrixShape<4, 16>
|
typename ApplyShape_ = MatrixShape<1, 1024>
|
||||||
>
|
>
|
||||||
class ApplySoftmax {
|
class ApplySoftmax {
|
||||||
public:
|
public:
|
||||||
@ -91,7 +91,7 @@ public:
|
|||||||
using ElementSoftmaxCompute = ElementSoftmaxCompute_;
|
using ElementSoftmaxCompute = ElementSoftmaxCompute_;
|
||||||
|
|
||||||
static int const kAlignment = Alignment;
|
static int const kAlignment = Alignment;
|
||||||
using Shape = Shape_;
|
using ApplyShape = ApplyShape_;
|
||||||
|
|
||||||
using Layout = cutlass::layout::RowMajor;
|
using Layout = cutlass::layout::RowMajor;
|
||||||
|
|
||||||
@ -202,7 +202,7 @@ private:
|
|||||||
using AccessTypeD = AlignedArray<ElementD, kAlignment>;
|
using AccessTypeD = AlignedArray<ElementD, kAlignment>;
|
||||||
|
|
||||||
int block_batch = blockIdx.z;
|
int block_batch = blockIdx.z;
|
||||||
int block_m = blockIdx.x * Shape::kRow;
|
int block_m = blockIdx.x * ApplyShape::kRow;
|
||||||
int block_n = 0;
|
int block_n = 0;
|
||||||
|
|
||||||
int thread_m = threadIdx.y;
|
int thread_m = threadIdx.y;
|
||||||
@ -256,8 +256,8 @@ private:
|
|||||||
params.args.batch_stride_Soft * block_batch +
|
params.args.batch_stride_Soft * block_batch +
|
||||||
params.args.ref_Soft.layout()({idx_m, idx_n}));
|
params.args.ref_Soft.layout()({idx_m, idx_n}));
|
||||||
|
|
||||||
ElementSum inv_sum = (params.args.ref_S.data())[block_m + batch_offset_sum];
|
ElementSum inv_sum = (params.args.ref_S.data())[idx_m + batch_offset_sum];
|
||||||
ElementNorm norm = (params.args.ref_N.data())[block_m + batch_offset_norm];
|
ElementNorm norm = (params.args.ref_N.data())[idx_m + batch_offset_norm];
|
||||||
|
|
||||||
//
|
//
|
||||||
// Loop
|
// Loop
|
||||||
@ -266,10 +266,9 @@ private:
|
|||||||
for (
|
for (
|
||||||
int idx = 0;
|
int idx = 0;
|
||||||
idx < params.args.extent.column();
|
idx < params.args.extent.column();
|
||||||
idx += Shape::kColumn * kAlignment) {
|
idx += ApplyShape::kColumn * kAlignment) {
|
||||||
|
|
||||||
if (idx_n < params.args.extent.column()) {
|
if (idx_n < params.args.extent.column()) {
|
||||||
|
|
||||||
AccessTypeD fetch;
|
AccessTypeD fetch;
|
||||||
arch::global_load<AccessTypeD, sizeof(AccessTypeD)>(fetch, access_d, true);
|
arch::global_load<AccessTypeD, sizeof(AccessTypeD)>(fetch, access_d, true);
|
||||||
|
|
||||||
@ -279,14 +278,13 @@ private:
|
|||||||
arch::global_store<FragmentSoft, sizeof(FragmentSoft)>(soft, access_soft, true);
|
arch::global_store<FragmentSoft, sizeof(FragmentSoft)>(soft, access_soft, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
access_d += Shape::kColumn;
|
access_d += ApplyShape::kColumn;
|
||||||
access_soft += Shape::kColumn;
|
access_soft += ApplyShape::kColumn;
|
||||||
idx_n += Shape::kColumn * kAlignment;
|
idx_n += ApplyShape::kColumn * kAlignment;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
@ -308,6 +306,7 @@ template <
|
|||||||
typename InstructionShape_,
|
typename InstructionShape_,
|
||||||
typename EpilogueFunctorOp_,
|
typename EpilogueFunctorOp_,
|
||||||
int kStages_,
|
int kStages_,
|
||||||
|
typename ApplyShape_ = MatrixShape<1, 1024>,
|
||||||
int AlignmentA_ = 128 / cutlass::sizeof_bits<ElementA_>::value,
|
int AlignmentA_ = 128 / cutlass::sizeof_bits<ElementA_>::value,
|
||||||
int AlignmentB_ = 128 / cutlass::sizeof_bits<ElementB_>::value,
|
int AlignmentB_ = 128 / cutlass::sizeof_bits<ElementB_>::value,
|
||||||
int AlignmentSoftmax_ = 128 / cutlass::sizeof_bits<ElementC_>::value,
|
int AlignmentSoftmax_ = 128 / cutlass::sizeof_bits<ElementC_>::value,
|
||||||
@ -338,6 +337,8 @@ public:
|
|||||||
using EpilogueFunctorOp = EpilogueFunctorOp_;
|
using EpilogueFunctorOp = EpilogueFunctorOp_;
|
||||||
using ElementNorm = ElementNorm_;
|
using ElementNorm = ElementNorm_;
|
||||||
|
|
||||||
|
using ApplyShape = ApplyShape_;
|
||||||
|
|
||||||
// These are mandatory layouts.
|
// These are mandatory layouts.
|
||||||
using LayoutC = cutlass::layout::RowMajor;
|
using LayoutC = cutlass::layout::RowMajor;
|
||||||
using LayoutN = cutlass::layout::RowMajor;
|
using LayoutN = cutlass::layout::RowMajor;
|
||||||
@ -427,9 +428,7 @@ public:
|
|||||||
ElementSoft,
|
ElementSoft,
|
||||||
ElementSoftmaxCompute,
|
ElementSoftmaxCompute,
|
||||||
AlignmentSoftmax,
|
AlignmentSoftmax,
|
||||||
MatrixShape<
|
ApplyShape
|
||||||
1, 1024
|
|
||||||
>
|
|
||||||
>;
|
>;
|
||||||
|
|
||||||
using ApplyFinalReductionKernel = cutlass::reduction::kernel::ApplySoftmaxFinalReduction<
|
using ApplyFinalReductionKernel = cutlass::reduction::kernel::ApplySoftmaxFinalReduction<
|
||||||
@ -616,14 +615,14 @@ public:
|
|||||||
// Launch the SoftmaxApplyKernel
|
// Launch the SoftmaxApplyKernel
|
||||||
//
|
//
|
||||||
|
|
||||||
dim3 apply_block(SoftmaxApplyKernel::Shape::kColumn, SoftmaxApplyKernel::Shape::kRow);
|
dim3 apply_block(SoftmaxApplyKernel::ApplyShape::kColumn, SoftmaxApplyKernel::ApplyShape::kRow);
|
||||||
|
|
||||||
int cta_rows = SoftmaxApplyKernel::Shape::kRow;
|
int threadblock_rows = SoftmaxApplyKernel::ApplyShape::kRow;
|
||||||
int cta_columns = SoftmaxApplyKernel::Shape::kColumn * SoftmaxApplyKernel::kAlignment;
|
int threadblock_columns = SoftmaxApplyKernel::ApplyShape::kColumn * SoftmaxApplyKernel::kAlignment;
|
||||||
|
|
||||||
dim3 apply_grid(
|
dim3 apply_grid(
|
||||||
(params_.softmax.args.extent.row() + cta_rows - 1) / cta_rows,
|
(params_.softmax.args.extent.row() + threadblock_rows - 1) / threadblock_rows,
|
||||||
(params_.softmax.args.extent.column() + cta_columns - 1) / cta_columns,
|
(params_.softmax.args.extent.column() + threadblock_columns - 1) / threadblock_columns,
|
||||||
params_.softmax.args.batch_count);
|
params_.softmax.args.batch_count);
|
||||||
|
|
||||||
Kernel<SoftmaxApplyKernel><<<
|
Kernel<SoftmaxApplyKernel><<<
|
||||||
|
Loading…
Reference in New Issue
Block a user