Handle MNK Sm90{Row, Col}Reduction problem shapes (#1803)

This commit is contained in:
Saagar Jha 2024-10-14 16:46:20 -07:00 committed by GitHub
parent cc3c29a81a
commit 53668799b2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -523,7 +523,8 @@ public:
CudaHostAdapter* cuda_adapter = nullptr) {
#if !defined(CUTLASS_SKIP_REDUCTION_INIT)
if constexpr (IsAtomic) {
auto [M, N, K, L] = problem_shape;
auto problem_shape_mnkl = append<4>(problem_shape, 1);
auto [M, N, K, L] = problem_shape_mnkl;
Layout mScalar_layout = make_layout(make_shape(M,N,L), args.dScalar);
if (args.ptr_scalar != nullptr) {
return fill_workspace(args.ptr_scalar, ElementOutput(args.reduction_identity), cosize(mScalar_layout), stream, cuda_adapter);
@ -700,7 +701,9 @@ public:
reduction_buffer = nullptr;
}
else if constexpr (FinalReduction) {
auto [M, N, K, L] = problem_shape;
auto problem_shape_mnkl = append<4>(problem_shape, 1);
auto [M, N, K, L] = problem_shape_mnkl;
auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{};
size_t tile_counters_offset = product(ceil_div(make_shape(size<>(M), size<>(N), L), make_shape(tile_M, tile_N))) * tile_N * sizeof(ElementCompute);
tile_counters_offset = round_nearest(tile_counters_offset, MinWorkspaceAlignment);
@ -735,7 +738,8 @@ public:
}
size_t workspace_size = 0;
auto [M, N, K, L] = problem_shape;
auto problem_shape_mnkl = append<4>(problem_shape, 1);
auto [M, N, K, L] = problem_shape_mnkl;
auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{};
// Increment by size of reduction buffer
workspace_size += product(ceil_div(make_shape(size<>(M),size<>(N),L), make_shape(tile_M, tile_N))) * tile_N * sizeof(ElementCompute);
@ -750,8 +754,9 @@ public:
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
CudaHostAdapter* cuda_adapter = nullptr) {
#if !defined(CUTLASS_SKIP_REDUCTION_INIT)
auto problem_shape_mnkl = append<4>(problem_shape, 1);
auto [M, N, K, L] = problem_shape_mnkl;
if constexpr (IsAtomic) {
auto [M, N, K, L] = problem_shape;
Layout mRow_layout = make_layout(make_shape(size<>(M),size<>(N),size<>(L)), args.dRow);
if (args.ptr_row != nullptr) {
return fill_workspace(args.ptr_row, ElementOutput(args.reduction_identity), cosize(mRow_layout), stream, cuda_adapter);
@ -761,7 +766,6 @@ public:
else
#endif
if constexpr (FinalReduction) {
auto [M, N, K, L] = problem_shape;
auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{};
size_t tile_counters_offset = product(ceil_div(make_shape(size<>(M),size<>(N),L), make_shape(tile_M, tile_N))) * tile_N * sizeof(ElementCompute);
tile_counters_offset = round_nearest(tile_counters_offset, MinWorkspaceAlignment);
@ -1290,7 +1294,9 @@ public:
reduction_buffer = nullptr;
}
else if constexpr (FinalReduction) {
auto [M, N, K, L] = problem_shape;
auto problem_shape_mnkl = append<4>(problem_shape, 1);
auto [M, N, K, L] = problem_shape_mnkl;
auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{};
size_t tile_counters_offset = product(ceil_div(make_shape(M,N,L), make_shape(tile_M, tile_N))) * tile_M * sizeof(ElementCompute);
tile_counters_offset = round_nearest(tile_counters_offset, MinWorkspaceAlignment);
@ -1325,7 +1331,8 @@ public:
}
size_t workspace_size = 0;
auto [M, N, K, L] = problem_shape;
auto problem_shape_mnkl = append<4>(problem_shape, 1);
auto [M, N, K, L] = problem_shape_mnkl;
auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{};
// Increment by size of reduction buffer
@ -1342,8 +1349,9 @@ public:
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
CudaHostAdapter* cuda_adapter = nullptr) {
#if !defined(CUTLASS_SKIP_REDUCTION_INIT)
auto problem_shape_mnkl = append<4>(problem_shape, 1);
auto [M, N, K, L] = problem_shape_mnkl;
if constexpr (IsAtomic) {
auto [M, N, K, L] = problem_shape;
Layout mCol_layout = make_layout(make_shape(size<>(M),size<>(N),size<>(L)), args.dCol);
if (args.ptr_col != nullptr) {
return fill_workspace(args.ptr_col, ElementOutput(args.reduction_identity), cosize(mCol_layout), stream, cuda_adapter);
@ -1353,7 +1361,6 @@ public:
else
#endif
if constexpr (FinalReduction) {
auto [M, N, K, L] = problem_shape;
auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{};
size_t tile_counters_offset = product(ceil_div(make_shape(M,N,L), make_shape(tile_M, tile_N))) * tile_M * sizeof(ElementCompute);
tile_counters_offset = round_nearest(tile_counters_offset, MinWorkspaceAlignment);