| 
									
										
										
										
											2022-09-04 06:48:46 +08:00
										 |  |  | /***************************************************************************************************
 | 
					
						
							|  |  |  |  * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | 
					
						
							|  |  |  |  * SPDX-License-Identifier: BSD-3-Clause | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * Redistribution and use in source and binary forms, with or without | 
					
						
							|  |  |  |  * modification, are permitted provided that the following conditions are met: | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * 1. Redistributions of source code must retain the above copyright notice, this | 
					
						
							|  |  |  |  * list of conditions and the following disclaimer. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * 2. Redistributions in binary form must reproduce the above copyright notice, | 
					
						
							|  |  |  |  * this list of conditions and the following disclaimer in the documentation | 
					
						
							|  |  |  |  * and/or other materials provided with the distribution. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * 3. Neither the name of the copyright holder nor the names of its | 
					
						
							|  |  |  |  * contributors may be used to endorse or promote products derived from | 
					
						
							|  |  |  |  * this software without specific prior written permission. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | 
					
						
							|  |  |  |  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | 
					
						
							|  |  |  |  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | 
					
						
							|  |  |  |  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE | 
					
						
							|  |  |  |  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | 
					
						
							|  |  |  |  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | 
					
						
							|  |  |  |  * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | 
					
						
							|  |  |  |  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | 
					
						
							|  |  |  |  * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | 
					
						
							|  |  |  |  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  **************************************************************************************************/ | 
					
						
							|  |  |  | /*! \file
 | 
					
						
							|  |  |  |     \brief Tests for grouped GEMM problem visitors | 
					
						
							|  |  |  | */ | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #pragma once
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include <iostream>
 | 
					
						
							|  |  |  | #include <numeric>
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include "../../common/cutlass_unit_test.h"
 | 
					
						
							|  |  |  | #include "cutlass/cutlass.h"
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include "cutlass/gemm/gemm.h"
 | 
					
						
							|  |  |  | #include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h"
 | 
					
						
							|  |  |  | #include "cutlass/gemm/kernel/grouped_problem_visitor.h"
 | 
					
						
							|  |  |  | #include "cutlass/util/device_memory.h"
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /////////////////////////////////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | namespace test { | 
					
						
							|  |  |  | namespace gemm { | 
					
						
							|  |  |  | namespace device { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /////////////////////////////////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // Use simple problem visitor as a baseline
 | 
					
						
							|  |  |  | template <typename ProblemSizeHelper, | 
					
						
							|  |  |  |           typename ThreadblockShape, | 
					
						
							|  |  |  |           int PrefetchTileCount, | 
					
						
							|  |  |  |           int ThreadCount> | 
					
						
							|  |  |  | struct BaselineProblemVisitor : public cutlass::gemm::kernel::BaseGroupedProblemVisitor<ProblemSizeHelper, ThreadblockShape> { | 
					
						
							|  |  |  |   using Base = cutlass::gemm::kernel::BaseGroupedProblemVisitor<ProblemSizeHelper, ThreadblockShape>; | 
					
						
							|  |  |  |   using Params = typename Base::Params; | 
					
						
							|  |  |  |   static int const kThreadCount = ThreadCount; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   struct SharedStorage {}; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   int32_t tile_count_sum; | 
					
						
							|  |  |  |   SharedStorage &shared_storage; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   //
 | 
					
						
							|  |  |  |   // Methods
 | 
					
						
							|  |  |  |   //
 | 
					
						
							|  |  |  |   CUTLASS_DEVICE | 
					
						
							|  |  |  |   BaselineProblemVisitor( | 
					
						
							|  |  |  |     Params const ¶ms_, | 
					
						
							|  |  |  |     SharedStorage &shared_storage_, | 
					
						
							|  |  |  |     int32_t block_idx | 
					
						
							|  |  |  |   ): Base(params_, block_idx), | 
					
						
							|  |  |  |   shared_storage(shared_storage_) | 
					
						
							|  |  |  |   { | 
					
						
							|  |  |  |     cutlass::gemm::GemmCoord problem = this->problem_size(); | 
					
						
							|  |  |  |     cutlass::gemm::GemmCoord  grid = this->grid_shape(problem); | 
					
						
							|  |  |  |     tile_count_sum = this->tile_count(grid); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   CUTLASS_DEVICE | 
					
						
							|  |  |  |   bool next_tile() { | 
					
						
							|  |  |  |     if (this->tile_idx < tile_count_sum) { | 
					
						
							|  |  |  |       return true; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     do { | 
					
						
							|  |  |  |       ++this->problem_idx; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       if (this->problem_idx >= this->params.problem_count) { | 
					
						
							|  |  |  |         return false; | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       cutlass::gemm::GemmCoord problem = this->problem_size(); | 
					
						
							|  |  |  |       cutlass::gemm::GemmCoord  grid = this->grid_shape(problem); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       this->problem_tile_start = tile_count_sum; | 
					
						
							|  |  |  |       tile_count_sum += this->tile_count(grid); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     } while (tile_count_sum <= this->tile_idx); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return true; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   static size_t get_workspace_size(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, | 
					
						
							|  |  |  |                                    int32_t problem_count, | 
					
						
							|  |  |  |                                    int32_t block_count) { | 
					
						
							|  |  |  |     return 0; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   static void host_precompute(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, | 
					
						
							|  |  |  |                               int32_t problem_count, | 
					
						
							|  |  |  |                               int32_t block_count, | 
					
						
							|  |  |  |                               void* host_workspace_ptr) {} | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /////////////////////////////////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <typename ProblemVisitor> | 
					
						
							|  |  |  | struct ProblemVisitorKernel { | 
					
						
							|  |  |  |   struct SharedStorage { | 
					
						
							|  |  |  |     typename ProblemVisitor::SharedStorage problem_visitor; | 
					
						
							|  |  |  |   }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   struct Params { | 
					
						
							|  |  |  |     typename ProblemVisitor::Params problem_visitor_params; | 
					
						
							|  |  |  |     int32_t* visited_problems_ptr; | 
					
						
							|  |  |  |     int32_t* visited_tiles_ptr; | 
					
						
							|  |  |  |     int32_t visits_per_block; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     Params(): | 
					
						
							|  |  |  |       visited_problems_ptr(nullptr), | 
					
						
							|  |  |  |       visited_tiles_ptr(nullptr), | 
					
						
							|  |  |  |       visits_per_block(0) {} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     Params(typename ProblemVisitor::Params problem_visitor_params_, | 
					
						
							|  |  |  |            int32_t* visited_problems_ptr_, | 
					
						
							|  |  |  |            int32_t* visited_tiles_ptr_, | 
					
						
							|  |  |  |            int32_t visits_per_block_): | 
					
						
							|  |  |  |       problem_visitor_params(problem_visitor_params_), | 
					
						
							|  |  |  |       visited_problems_ptr(visited_problems_ptr_), | 
					
						
							|  |  |  |       visited_tiles_ptr(visited_tiles_ptr_), | 
					
						
							|  |  |  |       visits_per_block(visits_per_block_) {} | 
					
						
							|  |  |  |   }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   CUTLASS_DEVICE | 
					
						
							|  |  |  |   void operator()(const Params& params, SharedStorage &shared_storage) { | 
					
						
							|  |  |  |     int32_t store_offset = params.visits_per_block * blockIdx.x; | 
					
						
							|  |  |  |     ProblemVisitor problem_visitor(params.problem_visitor_params, | 
					
						
							|  |  |  |                                    shared_storage.problem_visitor, | 
					
						
							|  |  |  |                                    blockIdx.x); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     while (problem_visitor.next_tile()) { | 
					
						
							|  |  |  |       int32_t problem_idx = problem_visitor.problem_index(); | 
					
						
							|  |  |  |       int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       if (threadIdx.x == 0) { | 
					
						
							|  |  |  |         params.visited_problems_ptr[store_offset] = problem_idx; | 
					
						
							|  |  |  |         params.visited_tiles_ptr[store_offset] = threadblock_idx; | 
					
						
							|  |  |  |         ++store_offset; | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  |       problem_visitor.advance(gridDim.x); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <typename ProblemVisitor> | 
					
						
							|  |  |  | struct ProblemVisitorRunner { | 
					
						
							|  |  |  |   using BaseKernel = ProblemVisitorKernel<ProblemVisitor>; | 
					
						
							|  |  |  |   using Params = typename BaseKernel::Params; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   Params params; | 
					
						
							|  |  |  |   std::vector<cutlass::gemm::GemmCoord> host_problem_sizes; | 
					
						
							|  |  |  |   int32_t problem_count; | 
					
						
							|  |  |  |   int32_t threadblock_count; | 
					
						
							|  |  |  |   int32_t visits_per_block; | 
					
						
							|  |  |  |   cutlass::DeviceAllocation<int32_t> visited_problems; | 
					
						
							|  |  |  |   cutlass::DeviceAllocation<int32_t> visited_tiles; | 
					
						
							|  |  |  |   cutlass::DeviceAllocation<cutlass::gemm::GemmCoord> device_problem_sizes; | 
					
						
							|  |  |  |   cutlass::DeviceAllocation<uint8_t> workspace; | 
					
						
							|  |  |  |   std::vector<int32_t> host_visited_problems; | 
					
						
							|  |  |  |   std::vector<int32_t> host_visited_tiles; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   ProblemVisitorRunner(const std::vector<cutlass::gemm::GemmCoord>& host_problem_sizes_, | 
					
						
							|  |  |  |                        int32_t threadblock_count_): | 
					
						
							|  |  |  |       host_problem_sizes(host_problem_sizes_), | 
					
						
							|  |  |  |       problem_count(int32_t(host_problem_sizes_.size())), | 
					
						
							|  |  |  |       threadblock_count(threadblock_count_) {} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Initializes GEMM state from arguments.
 | 
					
						
							|  |  |  |   cutlass::Status initialize() { | 
					
						
							|  |  |  |     size_t workspace_bytes = ProblemVisitor::get_workspace_size( | 
					
						
							|  |  |  |                                 host_problem_sizes.data(), | 
					
						
							|  |  |  |                                 problem_count, | 
					
						
							|  |  |  |                                 threadblock_count); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     workspace.reset(workspace_bytes); | 
					
						
							|  |  |  |     std::vector<uint8_t> host_workspace(workspace_bytes); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     int32_t tile_count = ProblemVisitor::group_tile_count(host_problem_sizes.data(), problem_count); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     ProblemVisitor::host_precompute(host_problem_sizes.data(), problem_count, | 
					
						
							|  |  |  |                                     threadblock_count, host_workspace.data()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     workspace.copy_from_host(host_workspace.data(), workspace_bytes); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     device_problem_sizes.reset(problem_count); | 
					
						
							|  |  |  |     device_problem_sizes.copy_from_host(host_problem_sizes.data(), problem_count); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     visits_per_block = (tile_count - 1 + threadblock_count) / threadblock_count; | 
					
						
							|  |  |  |     int32_t total_visits = visits_per_block * threadblock_count; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     visited_problems.reset(total_visits); | 
					
						
							|  |  |  |     visited_tiles.reset(total_visits); | 
					
						
							|  |  |  |     host_visited_problems.resize(total_visits); | 
					
						
							|  |  |  |     host_visited_tiles.resize(total_visits); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     cudaError_t result = cudaMemset(visited_problems.get(), -1, sizeof(int32_t) * total_visits); | 
					
						
							|  |  |  |     if (result != cudaSuccess) { | 
					
						
							|  |  |  |       return cutlass::Status::kErrorInternal; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     result = cudaMemset(visited_tiles.get(), -1, sizeof(int32_t) * total_visits); | 
					
						
							|  |  |  |     if (result != cudaSuccess) { | 
					
						
							|  |  |  |       return cutlass::Status::kErrorInternal; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     typename ProblemVisitor::Params pv_params(device_problem_sizes.get(), problem_count, workspace.get(), tile_count); | 
					
						
							|  |  |  |     params = Params(pv_params, visited_problems.get(), visited_tiles.get(), visits_per_block); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return cutlass::Status::kSuccess; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   bool verify() { | 
					
						
							|  |  |  |     // Sort by problem size and then by threadblock_idx
 | 
					
						
							|  |  |  |     std::vector<int32_t> indices(host_visited_problems.size()); | 
					
						
							|  |  |  |     std::iota(indices.begin(), indices.end(), 0); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     std::stable_sort(indices.begin(), indices.end(), | 
					
						
							|  |  |  |       [&](int32_t i1, int32_t i2) { | 
					
						
							|  |  |  |         if (host_visited_problems[i1] == host_visited_problems[i2]) { | 
					
						
							|  |  |  |           return host_visited_tiles[i1] < host_visited_tiles[i2]; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         return host_visited_problems[i1] < host_visited_problems[i2]; | 
					
						
							|  |  |  |       }); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     int32_t idx = 0; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Skip any entries that were not visited
 | 
					
						
							|  |  |  |     while (host_visited_problems[indices[idx]] == -1) { | 
					
						
							|  |  |  |       ++idx; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Check that each problem visited has the tiles we expect
 | 
					
						
							|  |  |  |     for (int32_t problem_idx = 0; problem_idx < problem_count; ++problem_idx) { | 
					
						
							|  |  |  |       auto problem = host_problem_sizes[problem_idx]; | 
					
						
							|  |  |  |       ProblemVisitor::possibly_transpose_problem(problem); | 
					
						
							|  |  |  |       int32_t problem_tiles = ProblemVisitor::tile_count(ProblemVisitor::grid_shape(problem)); | 
					
						
							|  |  |  |       for (int i = 0; i < problem_tiles; ++i) { | 
					
						
							|  |  |  |         EXPECT_EQ(problem_idx, host_visited_problems[indices[idx]]); | 
					
						
							|  |  |  |         EXPECT_EQ(i, host_visited_tiles[indices[idx]]); | 
					
						
							|  |  |  |         ++idx; | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return true; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   bool run(cudaStream_t stream = nullptr) { | 
					
						
							|  |  |  |     cutlass::Status status = initialize(); | 
					
						
							|  |  |  |     if (status != cutlass::Status::kSuccess) { | 
					
						
							|  |  |  |       std::cerr << "Initialization failed" << std::endl; | 
					
						
							|  |  |  |       return false; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     dim3 grid(threadblock_count, 1, 1); | 
					
						
							|  |  |  |     dim3 block(ProblemVisitor::kThreadCount, 1, 1); | 
					
						
							|  |  |  |     int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     cutlass::Kernel<BaseKernel><<<grid, block, smem_size, stream>>>(params); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     cudaError_t result = cudaGetLastError(); | 
					
						
							|  |  |  |     if (result != cudaSuccess) { | 
					
						
							|  |  |  |       std::cerr << "grid launch failed with error " << cudaGetErrorString(result) << std::endl; | 
					
						
							|  |  |  |       return false; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     result = cudaDeviceSynchronize(); | 
					
						
							|  |  |  |     if (result != cudaSuccess) { | 
					
						
							|  |  |  |       std::cerr << "cudaDeviceSynchronize failed with error " << cudaGetErrorString(result) << std::endl; | 
					
						
							|  |  |  |       return false; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     visited_problems.copy_to_host(host_visited_problems.data()); | 
					
						
							|  |  |  |     visited_tiles.copy_to_host(host_visited_tiles.data()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return verify(); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <typename ThreadblockShape, | 
					
						
							|  |  |  |           int PrefetchTileCount, | 
					
						
							|  |  |  |           int ThreadCount, | 
					
						
							|  |  |  |           bool Transpose, | 
					
						
							|  |  |  |           cutlass::gemm::kernel::GroupScheduleMode GroupScheduleMode0, | 
					
						
							|  |  |  |           cutlass::gemm::kernel::GroupScheduleMode... Args> | 
					
						
							|  |  |  | struct TestbedGroupedGemmScheduler { | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-19 22:02:15 +08:00
										 |  |  |   using PSHelper = cutlass::gemm::kernel::detail::GemmGroupedProblemSizeHelper<ThreadblockShape, Transpose>; | 
					
						
							|  |  |  |   using BaselinePV = BaselineProblemVisitor<PSHelper, | 
					
						
							| 
									
										
										
										
											2022-09-04 06:48:46 +08:00
										 |  |  |                                             ThreadblockShape, | 
					
						
							|  |  |  |                                             PrefetchTileCount, | 
					
						
							|  |  |  |                                             ThreadCount>; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   //
 | 
					
						
							|  |  |  |   // Data members
 | 
					
						
							|  |  |  |   //
 | 
					
						
							|  |  |  |   uint32_t seed; | 
					
						
							|  |  |  |   int problem_count; | 
					
						
							|  |  |  |   int threadblock_count; | 
					
						
							|  |  |  |   std::vector<cutlass::gemm::GemmCoord> problem_sizes_host; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   //
 | 
					
						
							|  |  |  |   // Methods
 | 
					
						
							|  |  |  |   //
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   TestbedGroupedGemmScheduler(uint32_t seed_ = 3080): | 
					
						
							|  |  |  |     seed(seed_) { srand(seed); } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Initializes data structures
 | 
					
						
							|  |  |  |   void initialize(int32_t scale_factor) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     //
 | 
					
						
							|  |  |  |     // Choose random problem sizes
 | 
					
						
							|  |  |  |     //
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     problem_sizes_host.clear(); | 
					
						
							|  |  |  |     problem_sizes_host.resize(problem_count); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     for (int32_t i = 0; i < problem_count; ++i) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       cutlass::gemm::GemmCoord problem( | 
					
						
							|  |  |  |         scale_factor * (rand() % 64) + 24, | 
					
						
							|  |  |  |         scale_factor * (rand() % 64) + 24, | 
					
						
							|  |  |  |         scale_factor * (rand() % 64) + 24); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       problem_sizes_host.at(i) = problem; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   template <cutlass::gemm::kernel::GroupScheduleMode GroupScheduleMode_> | 
					
						
							|  |  |  |   void compare_visitors(const ProblemVisitorRunner<BaselinePV>& baseline_runner) { | 
					
						
							|  |  |  |     using PV = cutlass::gemm::kernel::GemmGroupedProblemVisitor< | 
					
						
							|  |  |  |                                          ThreadblockShape, | 
					
						
							|  |  |  |                                          GroupScheduleMode_, | 
					
						
							|  |  |  |                                          PrefetchTileCount, | 
					
						
							|  |  |  |                                          ThreadCount, | 
					
						
							|  |  |  |                                          Transpose>; | 
					
						
							|  |  |  |     ProblemVisitorRunner<PV> runner(problem_sizes_host, threadblock_count); | 
					
						
							|  |  |  |     EXPECT_TRUE(runner.run()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Check that this problem visitor visits the same problems and tiles as the baseline
 | 
					
						
							|  |  |  |     EXPECT_EQ(baseline_runner.host_visited_problems, runner.host_visited_problems); | 
					
						
							|  |  |  |     EXPECT_EQ(baseline_runner.host_visited_tiles, runner.host_visited_tiles); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   template <cutlass::gemm::kernel::GroupScheduleMode GroupScheduleMode1_, | 
					
						
							|  |  |  |             cutlass::gemm::kernel::GroupScheduleMode GroupScheduleMode2_, | 
					
						
							|  |  |  |             cutlass::gemm::kernel::GroupScheduleMode... Rest> | 
					
						
							|  |  |  |   void compare_visitors(const ProblemVisitorRunner<BaselinePV>& baseline_runner) { | 
					
						
							|  |  |  |     // Compare the next visitor with the baseline visitor
 | 
					
						
							|  |  |  |     compare_visitors<GroupScheduleMode1_>(baseline_runner); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Recurse to compare the next visitors
 | 
					
						
							|  |  |  |     compare_visitors<GroupScheduleMode2_, Rest...>(baseline_runner); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Executes the test on all scheduler modes
 | 
					
						
							|  |  |  |   void run(int problem_count, int threadblock_count, int scale_factor=8) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     this->problem_count = problem_count; | 
					
						
							|  |  |  |     this->threadblock_count = threadblock_count; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Initialize the problem
 | 
					
						
							|  |  |  |     initialize(scale_factor); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Run the baseline visitor to which we will compare all other visitors
 | 
					
						
							|  |  |  |     ProblemVisitorRunner<BaselinePV> baseline_runner(problem_sizes_host, threadblock_count); | 
					
						
							|  |  |  |     EXPECT_TRUE(baseline_runner.run()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     compare_visitors<Args...>(baseline_runner); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /////////////////////////////////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | } // device
 | 
					
						
							|  |  |  | } // gemm
 | 
					
						
							|  |  |  | } // test
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /////////////////////////////////////////////////////////////////////////////////////////////////
 |