Handle MNK Sm90{Row, Col}Reduction problem shapes (#1803)
This commit is contained in:
parent
cc3c29a81a
commit
53668799b2
@ -523,7 +523,8 @@ public:
|
||||
CudaHostAdapter* cuda_adapter = nullptr) {
|
||||
#if !defined(CUTLASS_SKIP_REDUCTION_INIT)
|
||||
if constexpr (IsAtomic) {
|
||||
auto [M, N, K, L] = problem_shape;
|
||||
auto problem_shape_mnkl = append<4>(problem_shape, 1);
|
||||
auto [M, N, K, L] = problem_shape_mnkl;
|
||||
Layout mScalar_layout = make_layout(make_shape(M,N,L), args.dScalar);
|
||||
if (args.ptr_scalar != nullptr) {
|
||||
return fill_workspace(args.ptr_scalar, ElementOutput(args.reduction_identity), cosize(mScalar_layout), stream, cuda_adapter);
|
||||
@ -700,7 +701,9 @@ public:
|
||||
reduction_buffer = nullptr;
|
||||
}
|
||||
else if constexpr (FinalReduction) {
|
||||
auto [M, N, K, L] = problem_shape;
|
||||
auto problem_shape_mnkl = append<4>(problem_shape, 1);
|
||||
auto [M, N, K, L] = problem_shape_mnkl;
|
||||
|
||||
auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{};
|
||||
size_t tile_counters_offset = product(ceil_div(make_shape(size<>(M), size<>(N), L), make_shape(tile_M, tile_N))) * tile_N * sizeof(ElementCompute);
|
||||
tile_counters_offset = round_nearest(tile_counters_offset, MinWorkspaceAlignment);
|
||||
@ -735,7 +738,8 @@ public:
|
||||
}
|
||||
|
||||
size_t workspace_size = 0;
|
||||
auto [M, N, K, L] = problem_shape;
|
||||
auto problem_shape_mnkl = append<4>(problem_shape, 1);
|
||||
auto [M, N, K, L] = problem_shape_mnkl;
|
||||
auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{};
|
||||
// Increment by size of reduction buffer
|
||||
workspace_size += product(ceil_div(make_shape(size<>(M),size<>(N),L), make_shape(tile_M, tile_N))) * tile_N * sizeof(ElementCompute);
|
||||
@ -750,8 +754,9 @@ public:
|
||||
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
|
||||
CudaHostAdapter* cuda_adapter = nullptr) {
|
||||
#if !defined(CUTLASS_SKIP_REDUCTION_INIT)
|
||||
auto problem_shape_mnkl = append<4>(problem_shape, 1);
|
||||
auto [M, N, K, L] = problem_shape_mnkl;
|
||||
if constexpr (IsAtomic) {
|
||||
auto [M, N, K, L] = problem_shape;
|
||||
Layout mRow_layout = make_layout(make_shape(size<>(M),size<>(N),size<>(L)), args.dRow);
|
||||
if (args.ptr_row != nullptr) {
|
||||
return fill_workspace(args.ptr_row, ElementOutput(args.reduction_identity), cosize(mRow_layout), stream, cuda_adapter);
|
||||
@ -761,7 +766,6 @@ public:
|
||||
else
|
||||
#endif
|
||||
if constexpr (FinalReduction) {
|
||||
auto [M, N, K, L] = problem_shape;
|
||||
auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{};
|
||||
size_t tile_counters_offset = product(ceil_div(make_shape(size<>(M),size<>(N),L), make_shape(tile_M, tile_N))) * tile_N * sizeof(ElementCompute);
|
||||
tile_counters_offset = round_nearest(tile_counters_offset, MinWorkspaceAlignment);
|
||||
@ -1290,7 +1294,9 @@ public:
|
||||
reduction_buffer = nullptr;
|
||||
}
|
||||
else if constexpr (FinalReduction) {
|
||||
auto [M, N, K, L] = problem_shape;
|
||||
auto problem_shape_mnkl = append<4>(problem_shape, 1);
|
||||
auto [M, N, K, L] = problem_shape_mnkl;
|
||||
|
||||
auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{};
|
||||
size_t tile_counters_offset = product(ceil_div(make_shape(M,N,L), make_shape(tile_M, tile_N))) * tile_M * sizeof(ElementCompute);
|
||||
tile_counters_offset = round_nearest(tile_counters_offset, MinWorkspaceAlignment);
|
||||
@ -1325,7 +1331,8 @@ public:
|
||||
}
|
||||
|
||||
size_t workspace_size = 0;
|
||||
auto [M, N, K, L] = problem_shape;
|
||||
auto problem_shape_mnkl = append<4>(problem_shape, 1);
|
||||
auto [M, N, K, L] = problem_shape_mnkl;
|
||||
auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{};
|
||||
|
||||
// Increment by size of reduction buffer
|
||||
@ -1342,8 +1349,9 @@ public:
|
||||
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
|
||||
CudaHostAdapter* cuda_adapter = nullptr) {
|
||||
#if !defined(CUTLASS_SKIP_REDUCTION_INIT)
|
||||
auto problem_shape_mnkl = append<4>(problem_shape, 1);
|
||||
auto [M, N, K, L] = problem_shape_mnkl;
|
||||
if constexpr (IsAtomic) {
|
||||
auto [M, N, K, L] = problem_shape;
|
||||
Layout mCol_layout = make_layout(make_shape(size<>(M),size<>(N),size<>(L)), args.dCol);
|
||||
if (args.ptr_col != nullptr) {
|
||||
return fill_workspace(args.ptr_col, ElementOutput(args.reduction_identity), cosize(mCol_layout), stream, cuda_adapter);
|
||||
@ -1353,7 +1361,6 @@ public:
|
||||
else
|
||||
#endif
|
||||
if constexpr (FinalReduction) {
|
||||
auto [M, N, K, L] = problem_shape;
|
||||
auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{};
|
||||
size_t tile_counters_offset = product(ceil_div(make_shape(M,N,L), make_shape(tile_M, tile_N))) * tile_M * sizeof(ElementCompute);
|
||||
tile_counters_offset = round_nearest(tile_counters_offset, MinWorkspaceAlignment);
|
||||
|
||||
Loading…
Reference in New Issue
Block a user