Handle MNK Sm90{Row, Col}Reduction problem shapes (#1803)
This commit is contained in:
parent
cc3c29a81a
commit
53668799b2
@ -523,7 +523,8 @@ public:
|
|||||||
CudaHostAdapter* cuda_adapter = nullptr) {
|
CudaHostAdapter* cuda_adapter = nullptr) {
|
||||||
#if !defined(CUTLASS_SKIP_REDUCTION_INIT)
|
#if !defined(CUTLASS_SKIP_REDUCTION_INIT)
|
||||||
if constexpr (IsAtomic) {
|
if constexpr (IsAtomic) {
|
||||||
auto [M, N, K, L] = problem_shape;
|
auto problem_shape_mnkl = append<4>(problem_shape, 1);
|
||||||
|
auto [M, N, K, L] = problem_shape_mnkl;
|
||||||
Layout mScalar_layout = make_layout(make_shape(M,N,L), args.dScalar);
|
Layout mScalar_layout = make_layout(make_shape(M,N,L), args.dScalar);
|
||||||
if (args.ptr_scalar != nullptr) {
|
if (args.ptr_scalar != nullptr) {
|
||||||
return fill_workspace(args.ptr_scalar, ElementOutput(args.reduction_identity), cosize(mScalar_layout), stream, cuda_adapter);
|
return fill_workspace(args.ptr_scalar, ElementOutput(args.reduction_identity), cosize(mScalar_layout), stream, cuda_adapter);
|
||||||
@ -700,7 +701,9 @@ public:
|
|||||||
reduction_buffer = nullptr;
|
reduction_buffer = nullptr;
|
||||||
}
|
}
|
||||||
else if constexpr (FinalReduction) {
|
else if constexpr (FinalReduction) {
|
||||||
auto [M, N, K, L] = problem_shape;
|
auto problem_shape_mnkl = append<4>(problem_shape, 1);
|
||||||
|
auto [M, N, K, L] = problem_shape_mnkl;
|
||||||
|
|
||||||
auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{};
|
auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{};
|
||||||
size_t tile_counters_offset = product(ceil_div(make_shape(size<>(M), size<>(N), L), make_shape(tile_M, tile_N))) * tile_N * sizeof(ElementCompute);
|
size_t tile_counters_offset = product(ceil_div(make_shape(size<>(M), size<>(N), L), make_shape(tile_M, tile_N))) * tile_N * sizeof(ElementCompute);
|
||||||
tile_counters_offset = round_nearest(tile_counters_offset, MinWorkspaceAlignment);
|
tile_counters_offset = round_nearest(tile_counters_offset, MinWorkspaceAlignment);
|
||||||
@ -735,7 +738,8 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
size_t workspace_size = 0;
|
size_t workspace_size = 0;
|
||||||
auto [M, N, K, L] = problem_shape;
|
auto problem_shape_mnkl = append<4>(problem_shape, 1);
|
||||||
|
auto [M, N, K, L] = problem_shape_mnkl;
|
||||||
auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{};
|
auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{};
|
||||||
// Increment by size of reduction buffer
|
// Increment by size of reduction buffer
|
||||||
workspace_size += product(ceil_div(make_shape(size<>(M),size<>(N),L), make_shape(tile_M, tile_N))) * tile_N * sizeof(ElementCompute);
|
workspace_size += product(ceil_div(make_shape(size<>(M),size<>(N),L), make_shape(tile_M, tile_N))) * tile_N * sizeof(ElementCompute);
|
||||||
@ -750,8 +754,9 @@ public:
|
|||||||
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
|
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
|
||||||
CudaHostAdapter* cuda_adapter = nullptr) {
|
CudaHostAdapter* cuda_adapter = nullptr) {
|
||||||
#if !defined(CUTLASS_SKIP_REDUCTION_INIT)
|
#if !defined(CUTLASS_SKIP_REDUCTION_INIT)
|
||||||
|
auto problem_shape_mnkl = append<4>(problem_shape, 1);
|
||||||
|
auto [M, N, K, L] = problem_shape_mnkl;
|
||||||
if constexpr (IsAtomic) {
|
if constexpr (IsAtomic) {
|
||||||
auto [M, N, K, L] = problem_shape;
|
|
||||||
Layout mRow_layout = make_layout(make_shape(size<>(M),size<>(N),size<>(L)), args.dRow);
|
Layout mRow_layout = make_layout(make_shape(size<>(M),size<>(N),size<>(L)), args.dRow);
|
||||||
if (args.ptr_row != nullptr) {
|
if (args.ptr_row != nullptr) {
|
||||||
return fill_workspace(args.ptr_row, ElementOutput(args.reduction_identity), cosize(mRow_layout), stream, cuda_adapter);
|
return fill_workspace(args.ptr_row, ElementOutput(args.reduction_identity), cosize(mRow_layout), stream, cuda_adapter);
|
||||||
@ -761,7 +766,6 @@ public:
|
|||||||
else
|
else
|
||||||
#endif
|
#endif
|
||||||
if constexpr (FinalReduction) {
|
if constexpr (FinalReduction) {
|
||||||
auto [M, N, K, L] = problem_shape;
|
|
||||||
auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{};
|
auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{};
|
||||||
size_t tile_counters_offset = product(ceil_div(make_shape(size<>(M),size<>(N),L), make_shape(tile_M, tile_N))) * tile_N * sizeof(ElementCompute);
|
size_t tile_counters_offset = product(ceil_div(make_shape(size<>(M),size<>(N),L), make_shape(tile_M, tile_N))) * tile_N * sizeof(ElementCompute);
|
||||||
tile_counters_offset = round_nearest(tile_counters_offset, MinWorkspaceAlignment);
|
tile_counters_offset = round_nearest(tile_counters_offset, MinWorkspaceAlignment);
|
||||||
@ -1290,7 +1294,9 @@ public:
|
|||||||
reduction_buffer = nullptr;
|
reduction_buffer = nullptr;
|
||||||
}
|
}
|
||||||
else if constexpr (FinalReduction) {
|
else if constexpr (FinalReduction) {
|
||||||
auto [M, N, K, L] = problem_shape;
|
auto problem_shape_mnkl = append<4>(problem_shape, 1);
|
||||||
|
auto [M, N, K, L] = problem_shape_mnkl;
|
||||||
|
|
||||||
auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{};
|
auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{};
|
||||||
size_t tile_counters_offset = product(ceil_div(make_shape(M,N,L), make_shape(tile_M, tile_N))) * tile_M * sizeof(ElementCompute);
|
size_t tile_counters_offset = product(ceil_div(make_shape(M,N,L), make_shape(tile_M, tile_N))) * tile_M * sizeof(ElementCompute);
|
||||||
tile_counters_offset = round_nearest(tile_counters_offset, MinWorkspaceAlignment);
|
tile_counters_offset = round_nearest(tile_counters_offset, MinWorkspaceAlignment);
|
||||||
@ -1325,7 +1331,8 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
size_t workspace_size = 0;
|
size_t workspace_size = 0;
|
||||||
auto [M, N, K, L] = problem_shape;
|
auto problem_shape_mnkl = append<4>(problem_shape, 1);
|
||||||
|
auto [M, N, K, L] = problem_shape_mnkl;
|
||||||
auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{};
|
auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{};
|
||||||
|
|
||||||
// Increment by size of reduction buffer
|
// Increment by size of reduction buffer
|
||||||
@ -1342,8 +1349,9 @@ public:
|
|||||||
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
|
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
|
||||||
CudaHostAdapter* cuda_adapter = nullptr) {
|
CudaHostAdapter* cuda_adapter = nullptr) {
|
||||||
#if !defined(CUTLASS_SKIP_REDUCTION_INIT)
|
#if !defined(CUTLASS_SKIP_REDUCTION_INIT)
|
||||||
|
auto problem_shape_mnkl = append<4>(problem_shape, 1);
|
||||||
|
auto [M, N, K, L] = problem_shape_mnkl;
|
||||||
if constexpr (IsAtomic) {
|
if constexpr (IsAtomic) {
|
||||||
auto [M, N, K, L] = problem_shape;
|
|
||||||
Layout mCol_layout = make_layout(make_shape(size<>(M),size<>(N),size<>(L)), args.dCol);
|
Layout mCol_layout = make_layout(make_shape(size<>(M),size<>(N),size<>(L)), args.dCol);
|
||||||
if (args.ptr_col != nullptr) {
|
if (args.ptr_col != nullptr) {
|
||||||
return fill_workspace(args.ptr_col, ElementOutput(args.reduction_identity), cosize(mCol_layout), stream, cuda_adapter);
|
return fill_workspace(args.ptr_col, ElementOutput(args.reduction_identity), cosize(mCol_layout), stream, cuda_adapter);
|
||||||
@ -1353,7 +1361,6 @@ public:
|
|||||||
else
|
else
|
||||||
#endif
|
#endif
|
||||||
if constexpr (FinalReduction) {
|
if constexpr (FinalReduction) {
|
||||||
auto [M, N, K, L] = problem_shape;
|
|
||||||
auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{};
|
auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{};
|
||||||
size_t tile_counters_offset = product(ceil_div(make_shape(M,N,L), make_shape(tile_M, tile_N))) * tile_M * sizeof(ElementCompute);
|
size_t tile_counters_offset = product(ceil_div(make_shape(M,N,L), make_shape(tile_M, tile_N))) * tile_M * sizeof(ElementCompute);
|
||||||
tile_counters_offset = round_nearest(tile_counters_offset, MinWorkspaceAlignment);
|
tile_counters_offset = round_nearest(tile_counters_offset, MinWorkspaceAlignment);
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user