cutlass/include/cutlass/gemm/kernel/params_universal_base.h
2024-03-19 17:51:04 -04:00

265 lines
8.3 KiB
C++

/***************************************************************************************************
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Base functionality for common types of universal GEMM kernel parameters
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/trace.h"
#include "cutlass/gemm/gemm.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace util {

/// Reports whether the GEMM K extent is a multiple of the given alignment for an
/// operand whose layout stores K contiguously.
///
/// Operand A is treated as K-contiguous when LayoutA is RowMajor; operand B is
/// treated as K-contiguous when LayoutB is ColumnMajor. The A check is evaluated
/// first, and B is only examined when A does not already satisfy the condition.
///
/// \param problem_size  GEMM extent; only k() is read
/// \param alignmentA    required K multiple (in elements) for operand A
/// \param alignmentB    required K multiple (in elements) for operand B
/// \return true if either K-contiguous operand has K divisible by its alignment
template <class LayoutA, class LayoutB>
CUTLASS_HOST_DEVICE
static bool
is_continous_k_aligned(GemmCoord problem_size, size_t alignmentA, size_t alignmentB) {
  if (platform::is_same<LayoutA, layout::RowMajor>::value &&
      (problem_size.k() % alignmentA) == 0) {
    return true;
  }
  return platform::is_same<LayoutB, layout::ColumnMajor>::value &&
         (problem_size.k() % alignmentB) == 0;
}

} // namespace util
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Argument structure
///
/// Arguments shared by universal GEMM kernels: the operating mode, the problem
/// extent, the batch count, and the stride between consecutive D batches.
struct UniversalArgumentsBase
{
  //
  // Data members
  //

  /// GEMM operating mode
  GemmUniversalMode mode{GemmUniversalMode::kGemm};

  /// GEMM problem extent (M, N, K)
  GemmCoord problem_size{};

  /// Batch count (UniversalParamsBase also uses it as the requested number of
  /// K partitions in kGemm / kGemmSplitKParallel modes)
  int batch_count{1};

  /// Stride between consecutive D batches
  int64_t batch_stride_D{0};

  //
  // Methods
  //

  /// Default constructor: kGemm mode, default problem_size, batch_count of 1
  UniversalArgumentsBase() = default;

  /// Constructs an arguments structure from explicit values
  ///
  /// \param mode            GEMM operating mode
  /// \param problem_size    GEMM problem extent
  /// \param batch_count     batch count / requested K partition count
  /// \param batch_stride_D  stride between consecutive D batches
  UniversalArgumentsBase(
    GemmUniversalMode mode,
    GemmCoord problem_size,
    int batch_count,
    int64_t batch_stride_D)
  : mode(mode), problem_size(problem_size), batch_count(batch_count), batch_stride_D(batch_stride_D)
  {
    CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << this->problem_size);
  }
};
/// Parameters structure
///
/// Common state for universal GEMM kernel parameters:
/// - the problem geometry;
/// - the threadblock-tiled grid shape produced by ThreadblockSwizzle;
/// - the K extent assigned to each K partition;
/// - the workspace pointer.
template <
  typename ThreadblockSwizzle,
  typename ThreadblockShape,
  typename ElementA,
  typename ElementB,
  typename ElementC,
  typename LayoutA,
  typename LayoutB>
struct UniversalParamsBase
{
  //
  // Data members
  //

  /// GEMM problem extent (M, N, K)
  GemmCoord problem_size{};

  /// Problem extent in threadblock tiles; k() holds the number of K partitions
  GemmCoord grid_tiled_shape{};

  /// Result of ThreadblockSwizzle::get_log_tile() for grid_tiled_shape
  int swizzle_log_tile{0};

  /// GEMM operating mode
  GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm;

  /// Batch count; in kGemm / kGemmSplitKParallel mode it is the requested number of K partitions
  int batch_count {0};

  /// Extent of the K dimension assigned to each K partition
  int gemm_k_size {0};

  /// Stride between consecutive D batches (workspace sizing treats it as a count of ElementC)
  int64_t batch_stride_D {0};

  /// Start of the zero-initialized workspace assigned by init_workspace()
  int *semaphore = nullptr;

  //
  // Host dispatch API
  //

  /// Default constructor
  UniversalParamsBase() = default;

  /// Constructor
  ///
  /// \param args          GEMM application arguments
  /// \param device_sms    Number of SMs on the device (not used by this base class)
  /// \param sm_occupancy  Kernel SM occupancy in thread blocks (not used by this base class)
  UniversalParamsBase(
    UniversalArgumentsBase const &args,
    int device_sms,
    int sm_occupancy)
  :
    problem_size(args.problem_size),
    mode(args.mode),
    batch_count(args.batch_count),
    batch_stride_D(args.batch_stride_D),
    semaphore(nullptr)
  {
    init_grid_tiled_shape();
  }

  /// Returns the workspace size (in bytes) needed for this problem geometry
  size_t get_workspace_size() const
  {
    size_t workspace_bytes = 0;
    if (mode == GemmUniversalMode::kGemmSplitKParallel)
    {
      // Split-K parallel always requires a temporary workspace: one D-sized
      // slab of ElementC per K partition.
      workspace_bytes =
        sizeof(ElementC) *
        size_t(batch_stride_D) *
        size_t(grid_tiled_shape.k());
    }
    else if (mode == GemmUniversalMode::kGemm && grid_tiled_shape.k() > 1)
    {
      // Serial split-K only requires a temporary workspace if the number of partitions along the
      // GEMM K dimension is greater than one: one int per (m, n) tile.
      workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n());
    }
    return workspace_bytes;
  }

  /// Assigns and zero-initializes the specified workspace buffer.
  ///
  /// Assumes the memory allocated to workspace is at least as large as
  /// get_workspace_size().
  ///
  /// \param workspace  device workspace, or nullptr when none is required
  /// \param stream     stream on which the zero-fill is enqueued
  /// \return kErrorWorkspaceNull if serial split-K needs workspace but none was
  ///         given; kErrorInternal if cudaMemsetAsync fails; kSuccess otherwise
  Status init_workspace(
    void *workspace,
    cudaStream_t stream = nullptr)
  {
    semaphore = static_cast<int *>(workspace);

    size_t workspace_bytes = get_workspace_size();

    // Serial split-K (kGemm with more than one K partition) needs this buffer.
    // Reporting success here would leave `semaphore` null even though
    // get_workspace_size() says workspace is required.
    if (!semaphore && workspace_bytes > 0 && mode == GemmUniversalMode::kGemm)
    {
      CUTLASS_TRACE_HOST("  error: " << workspace_bytes << " workspace bytes required, but workspace is null");
      return Status::kErrorWorkspaceNull;
    }

    // Zero-initialize the entire workspace (no device call needed when it is empty)
    if (semaphore && workspace_bytes > 0)
    {
      CUTLASS_TRACE_HOST("  Initialize " << workspace_bytes << " workspace bytes");

      cudaError_t result = cudaMemsetAsync(
        semaphore,
        0,
        workspace_bytes,
        stream);

      if (result != cudaSuccess) {
        CUTLASS_TRACE_HOST("  cudaMemsetAsync() returned error " << cudaGetErrorString(result));
        return Status::kErrorInternal;
      }
    }

    return Status::kSuccess;
  }

  /// Returns the GEMM volume in thread block tiles
  GemmCoord get_tiled_shape() const
  {
    return grid_tiled_shape;
  }

  /// Returns the total number of thread blocks to launch
  int get_grid_blocks() const
  {
    dim3 grid_dims = get_grid_dims();
    return grid_dims.x * grid_dims.y * grid_dims.z;
  }

  /// Returns the grid extents in thread blocks to launch
  dim3 get_grid_dims() const
  {
    return ThreadblockSwizzle().get_grid_shape(grid_tiled_shape);
  }

private:

  /// Computes grid_tiled_shape, swizzle_log_tile and gemm_k_size from
  /// problem_size, mode and batch_count.
  CUTLASS_HOST_DEVICE
  void init_grid_tiled_shape() {
    // Get GEMM volume in thread block tiles
    grid_tiled_shape = ThreadblockSwizzle::get_tiled_shape(
      problem_size,
      {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
      batch_count);

    swizzle_log_tile = ThreadblockSwizzle::get_log_tile(grid_tiled_shape);

    // Determine extent of K-dimension assigned to each block
    gemm_k_size = problem_size.k();

    if (mode == GemmUniversalMode::kGemm || mode == GemmUniversalMode::kGemmSplitKParallel)
    {
      // Compile-time constants: constexpr avoids function-local static storage
      // in this host/device function.
      constexpr uint32_t CACHELINE_BYTES = 128;
      constexpr size_t element_bytes_a = sizeof(ElementA);
      constexpr size_t element_bytes_b = sizeof(ElementB);
      constexpr size_t cacheline_elements_a = CACHELINE_BYTES / element_bytes_a;
      constexpr size_t cacheline_elements_b = CACHELINE_BYTES / element_bytes_b;

      // When the K-contiguous operand's K extent is a whole number of
      // cachelines, align partition boundaries to a cacheline as well.
      const bool cacheline_alignment_needed =
        util::is_continous_k_aligned<LayoutA, LayoutB>(problem_size, cacheline_elements_a, cacheline_elements_b);

      // Each partition's K extent is a multiple of at least 128 bits of either
      // operand, and of a full cacheline of elements when required above.
      int const kAlignK = const_max(
          const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value),
          cacheline_alignment_needed ? const_max(cacheline_elements_a, cacheline_elements_b) : 1);

      // NOTE(review): assumes batch_count >= 1 here (it divides K); confirm callers guarantee this.
      gemm_k_size = round_up(ceil_div(problem_size.k(), batch_count), kAlignK);

      // Rounding up may reduce the number of partitions actually needed.
      if (gemm_k_size) {
        grid_tiled_shape.k() = ceil_div(problem_size.k(), gemm_k_size);
      }
    }
  }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////