CUTLASS 2.6.1 - functional and performance enhancements to strided DGRAD, fixes, and tuning
* CUTLASS 2.6 update
* Remove debug prints
* CUTLASS 2.6.1 (minor update)
* Updated CHANGELOG
* Minor edits to the README to indicate the patch version

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
Co-authored-by: Andrew Kerr <akerr@nvidia.com>
parent a01feb93d9
commit 6c2f8f2fb8
@ -2,6 +2,12 @@
# CUTLASS 2.x

## [2.6.1](https://github.com/NVIDIA/cutlass/releases/tag/v2.6.1) (2021-09-03)
* Arbitrary padding and striding for CUTLASS Strided DGRAD Convolution operator (Analytic Iterators)
* Tuning for GEMMs fused with partial reductions
* Corrections and bug fixes reported by the CUTLASS community
  * Thank you for filing these issues!

## [2.6.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.6.0) (2021-07-22)
* Optimal performance when compiled with the [CUDA 11.4 Toolkit](https://developer.nvidia.com/cuda-toolkit)
  * Adopt the new L2 prefetch feature in [cp.async](/include/cutlass/arch/memory.h) and [global load](/include/cutlass/arch/memory_sm80.h)

@ -23,7 +29,8 @@
* Many improvements to the epilogue.
  * Provide an [option](/include/cutlass/epilogue/threadblock/epilogue.h) to not fully unroll the epilogue to reduce the code size and improve the performance when using complicated elementwise operations
  * Performance improvement for FP16 tensor core kernels
  * Bug fixes
* Bug fixes
* Enhanced Clang support and the combination of Clang 13 and CUDA 11.4 can build and run kernels from Pascal and Ampere.
* Updated minimum CUDA Toolkit requirement to 10.2
  * [CUDA 11.4 Toolkit](https://developer.nvidia.com/cuda-toolkit) recommended
* Corrections and bug fixes reported by the CUTLASS community
@ -168,6 +168,11 @@ if (${CUTLASS_NVCC_VERBOSE})
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -v)
|
||||
endif()
|
||||
|
||||
#
|
||||
# CUTLASS NAMESPACE
|
||||
#
|
||||
set(CUTLASS_NAMESPACE "cutlass" CACHE STRING "Top level namespace of CUTLASS")
|
||||
|
||||
set(CUTLASS_NVCC_EMBED_CUBIN ON CACHE BOOL "Embed compiled CUDA kernel binaries into executables.")
|
||||
set(CUTLASS_NVCC_EMBED_PTX ON CACHE BOOL "Embed compiled PTX into executables.")
|
||||
set(CUTLASS_NVCC_KEEP OFF CACHE BOOL "Keep intermediate files generated by NVCC.")
|
||||
@ -383,6 +388,8 @@ function(cutlass_apply_standard_compile_options TARGET)
|
||||
set(_FLAGS_DEBUG ${__CUTLASS_CUDA_FLAGS_DEBUG} ${__CUTLASS_CUDA_NVCC_FLAGS_DEBUG})
|
||||
endif()
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE CUTLASS)
|
||||
|
||||
target_compile_options(
|
||||
${TARGET}
|
||||
PRIVATE
|
||||
@ -425,6 +432,7 @@ set(CUTLASS_TOOLS_UTIL_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/tools/util/includ
|
||||
include_directories(${CUTLASS_INCLUDE_DIR})
|
||||
|
||||
target_compile_features(CUTLASS INTERFACE cxx_std_11)
|
||||
target_compile_definitions(CUTLASS INTERFACE CUTLASS_NAMESPACE=${CUTLASS_NAMESPACE})
|
||||
|
||||
if (NOT DEFINED CUTLASS_REVISION)
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@
# CUTLASS 2.6

_CUTLASS 2.6 - July 2021_
_CUTLASS 2.6.1 - September 2021_

CUTLASS is a collection of CUDA C++ template abstractions for implementing
high-performance matrix-multiplication (GEMM) at all levels and scales within CUDA.

@ -34,6 +34,8 @@ See the [Quick Start Guide](/media/docs/quickstart.md) to get started quickly.
See the [functionality listing](/media/docs/functionality.md) for the list of operations
supported at each level of the execution model hierarchy.

See the [CHANGELOG](CHANGELOG.md) for descriptions of recent updates.

# What's New in CUTLASS 2.6
CUTLASS 2.6 is a minor update to CUTLASS adding:
- Fused [broadcast](test/unit/gemm/device/gemm_with_broadcast_f16n_f16n_f16n_tensorop_f32_sm75.cu) and [reductions](/test/unit/gemm/device/gemm_with_reduction_f16n_f16n_f16n_tensorop_f32_sm75.cu) in the epilogues of GEMM and Convolution

@ -41,11 +43,12 @@ CUTLASS 2.6 is a minor update to CUTLASS adding:
- [New strided Dgrad](test/unit/conv/device/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu) implementation offers up to 4x performance improvements over previous strided Dgrad
- 64-bit strides for large tensor allocations
- [General affine layouts](/examples/18_ampere_fp64_tensorop_affine2_gemm/ampere_fp64_tensorop_affine2_gemm.cu) fp64 tensor core and simt GEMM
- [Batched GEMV](/test/unit/gemm/device/gemv.cu) preview implementation
- Enhanced functionality, boosted performance, and bug fixes in the epilogue.
- Optimal performance when compiled with the [CUDA 11.4 Toolkit](https://developer.nvidia.com/cuda-toolkit)
- Adopt new L2 prefetch feature in [ptx instruction](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#ptx-isa-version-7-4).
- Enhanced Clang support and the combination of Clang 13 and CUDA 11.4 can build and run kernels from Pascal and Ampere.
- Numerous updates from the community (thanks!)
- See the [CHANGELOG](CHANGELOG.md) for more details

# What's New in CUTLASS 2.5
CUTLASS 2.5 is a minor update to CUTLASS adding:
@ -390,14 +390,6 @@ public:
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
|
||||
result = cudaFuncSetAttribute(
|
||||
Kernel<B2bGemmKernel>,
|
||||
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
cutlass::Kernel<B2bGemmKernel><<<grid, block, smem_size, stream>>>(params_);
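The hunk above (and the matching hunks in the other device-level GEMM and convolution headers below) removes the optional `cudaFuncAttributePreferredSharedMemoryCarveout` hint and keeps only the mandatory opt-in for large dynamic shared memory. A minimal standalone sketch of that pattern, with a hypothetical kernel standing in for `cutlass::Kernel<...>` and an assumed 96 KB requirement (only valid on GPUs that allow opting in above 48 KB):

```cpp
#include <cuda_runtime.h>

// Hypothetical kernel standing in for cutlass::Kernel<GemmKernel>.
__global__ void my_kernel(int *out) {
  extern __shared__ int smem[];
  smem[threadIdx.x] = int(threadIdx.x);
  out[threadIdx.x] = smem[threadIdx.x];
}

cudaError_t launch(int *out, cudaStream_t stream) {
  int smem_size = 96 * 1024;  // assumed requirement above the default 48 KB limit

  if (smem_size >= (48 << 10)) {
    // Required: raise the per-kernel dynamic shared memory limit.
    cudaError_t result = cudaFuncSetAttribute(
        my_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
    if (result != cudaSuccess) {
      return result;
    }
    // The PreferredSharedMemoryCarveout hint removed in this commit is optional;
    // without it the runtime picks an L1/shared split on its own.
  }

  my_kernel<<<dim3(1), dim3(32), smem_size, stream>>>(out);
  return cudaGetLastError();
}
```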
|
||||
|
||||
@ -197,14 +197,6 @@ public:
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
|
||||
result = cudaFuncSetAttribute(
|
||||
cutlass::Kernel<B2bImplicitGemmKernel>,
|
||||
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
|
||||
@ -395,10 +395,8 @@ public:
|
||||
for (int stage = 0; stage < Base::kStages - 1;
|
||||
++stage, --gemm_k_iterations_0) {
|
||||
|
||||
if (gemm_k_iterations_0 == 0) {
|
||||
iterator_A0.clear_mask();
|
||||
iterator_B0.clear_mask();
|
||||
}
|
||||
iterator_A0.clear_mask(gemm_k_iterations_0 == 0);
|
||||
iterator_B0.clear_mask(gemm_k_iterations_0 == 0);
|
||||
|
||||
iterator_A0.set_iteration_index(0);
|
||||
this->smem_iterator_A0_.set_iteration_index(0);
|
||||
@ -490,10 +488,8 @@ public:
|
||||
++this->warp_tile_iterator_A0_;
|
||||
++this->warp_tile_iterator_B0_;
|
||||
|
||||
if (gemm_k_iterations_0 == 0) {
|
||||
iterator_A0.clear_mask();
|
||||
iterator_B0.clear_mask();
|
||||
}
|
||||
iterator_A0.clear_mask(gemm_k_iterations_0 == 0);
|
||||
iterator_B0.clear_mask(gemm_k_iterations_0 == 0);
|
||||
|
||||
int smem_write_stage_idx = Base::kStages - 1;
|
||||
int smem_read_stage_idx = 0;
|
||||
@ -601,10 +597,8 @@ public:
|
||||
}
|
||||
|
||||
--gemm_k_iterations_0;
|
||||
if (gemm_k_iterations_0 == 0) {
|
||||
iterator_A0.clear_mask();
|
||||
iterator_B0.clear_mask();
|
||||
}
|
||||
iterator_A0.clear_mask(gemm_k_iterations_0 == 0);
|
||||
iterator_B0.clear_mask(gemm_k_iterations_0 == 0);
|
||||
}
|
||||
|
||||
// Do any conversions feeding the first stage at the end of the loop so
|
||||
@ -634,9 +628,7 @@ public:
|
||||
for (int stage = 0; stage < Base::kStages - 1;
|
||||
++stage, --gemm_k_iterations_1) {
|
||||
|
||||
if (gemm_k_iterations_1 == 0) {
|
||||
iterator_B1.clear_mask();
|
||||
}
|
||||
iterator_B1.clear_mask(gemm_k_iterations_1 == 0);
|
||||
|
||||
iterator_B1.set_iteration_index(0);
|
||||
this->smem_iterator_B1_.set_iteration_index(0);
|
||||
@ -694,9 +686,7 @@ public:
|
||||
++warp_tile_iterator_A1_;
|
||||
++this->warp_tile_iterator_B1_;
|
||||
|
||||
if (gemm_k_iterations_1 == 0) {
|
||||
iterator_B1.clear_mask();
|
||||
}
|
||||
iterator_B1.clear_mask(gemm_k_iterations_1 == 0);
|
||||
|
||||
smem_write_stage_idx = Base::kStages - 1;
|
||||
smem_read_stage_idx = 0;
|
||||
@ -793,9 +783,7 @@ public:
|
||||
++smem_read_stage_idx;
|
||||
}
|
||||
|
||||
if (gemm_k_iterations_1 == 1) {
|
||||
iterator_B1.clear_mask();
|
||||
}
|
||||
iterator_B1.clear_mask(gemm_k_iterations_1 == 1);
|
||||
}
|
||||
|
||||
// Do any conversions feeding the first stage at the end of the loop so
|
||||
|
||||
@ -47,6 +47,7 @@
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_coord.h"
|
||||
#include "cutlass/conv/convolution.h"
|
||||
#include "cutlass/functional.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace conv {
|
||||
@ -485,6 +486,27 @@ int strided_dgrad_tile_m_per_filter(
|
||||
return tile_m_per_filter;
|
||||
}
|
||||
|
||||
// Computes the starting Dx coordinate (h, w) for a given starting filter position
|
||||
CUTLASS_HOST_DEVICE
|
||||
void strided_dgrad_starting_coords(
|
||||
Conv2dProblemSize const &problem_size,
|
||||
FastDivmod const &stride_h_divmod, FastDivmod const &stride_w_divmod,
|
||||
int r, int s,
|
||||
int &start_h, int &start_w) {
|
||||
|
||||
// function locals for remainder by fast divmod
|
||||
int pad_h_rem_, pad_w_rem_;
|
||||
|
||||
// start_h = std::abs(problem_size.stride_h - ((problem_size.pad_h % problem_size.stride_h) - r)) % problem_size.stride_h;
|
||||
stride_h_divmod.divmod(pad_h_rem_, problem_size.pad_h);
|
||||
int r_ = std::abs(problem_size.stride_h - (pad_h_rem_ - r));
|
||||
stride_h_divmod.divmod(start_h, r_);
|
||||
|
||||
//start_w = std::abs(problem_size.stride_w - ((problem_size.pad_w % problem_size.stride_w) - s)) % problem_size.stride_w;
|
||||
stride_w_divmod.divmod(pad_w_rem_, problem_size.pad_w);
|
||||
int s_ = std::abs(problem_size.stride_w - (pad_w_rem_ - s));
|
||||
stride_w_divmod.divmod(start_w, s_);
|
||||
}
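The commented-out expressions above are the reference formulas; the helper reproduces them with `FastDivmod` so each remainder by the convolution stride costs a multiply-and-shift instead of an integer division in the kernel. A small host-side sketch (values made up) checking that the two-step remainder computation mirrors the plain modulo formula:

```cpp
// Sketch only: verify the divmod-style computation of start_h against the
// reference formula in the comment above, for one hypothetical filter row r.
#include <cassert>
#include <cstdlib>

int main() {
  int stride_h = 3, pad_h = 5, r = 1;

  // Reference: start_h = abs(stride_h - ((pad_h % stride_h) - r)) % stride_h
  int expected = std::abs(stride_h - ((pad_h % stride_h) - r)) % stride_h;

  // Same computation split into the two remainder steps the helper performs
  // via stride_h_divmod.divmod(remainder, dividend).
  int pad_h_rem = pad_h % stride_h;               // first divmod: remainder of pad_h
  int r_ = std::abs(stride_h - (pad_h_rem - r));
  int start_h = r_ % stride_h;                    // second divmod: remainder of r_

  assert(start_h == expected);
  return 0;
}
```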
|
||||
|
||||
} // namespace conv
|
||||
} // namespace cutlass
|
||||
|
||||
@ -217,14 +217,6 @@ public:
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
|
||||
result = cudaFuncSetAttribute(
|
||||
cutlass::Kernel<ImplicitGemmKernel>,
|
||||
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
|
||||
@ -199,7 +199,8 @@ struct ImplicitGemmConvolutionStridedDgrad {
|
||||
struct Params {
|
||||
ConvProblemSize problem_size;
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
FastDivmod filter_s_divmod;
|
||||
FastDivmod stride_h_divmod;
|
||||
FastDivmod stride_w_divmod;
|
||||
int gemm_k_iterations;
|
||||
typename Mma::IteratorA::Params iterator_A;
|
||||
typename Mma::IteratorA::Element const *ptr_A;
|
||||
@ -227,7 +228,8 @@ struct ImplicitGemmConvolutionStridedDgrad {
|
||||
int *semaphore = nullptr
|
||||
):
|
||||
problem_size(args.problem_size),
|
||||
filter_s_divmod(args.problem_size.stride_w),
|
||||
stride_h_divmod(args.problem_size.stride_h),
|
||||
stride_w_divmod(args.problem_size.stride_w),
|
||||
iterator_A(Mma::IteratorA::getParams(args.problem_size, args.ref_A.layout())),
|
||||
ptr_A(args.ref_A.data()),
|
||||
iterator_B(args.problem_size, args.ref_B.layout()),
|
||||
@ -297,7 +299,7 @@ struct ImplicitGemmConvolutionStridedDgrad {
|
||||
// int start_s = filter_tile_m % (params.problem_size.stride_w);
|
||||
|
||||
int start_r, start_s;
|
||||
params.filter_s_divmod(start_r, start_s, filter_tile_m);
|
||||
params.stride_w_divmod(start_r, start_s, filter_tile_m);
|
||||
|
||||
typename Mma::FragmentC accumulators;
|
||||
|
||||
@ -320,6 +322,7 @@ struct ImplicitGemmConvolutionStridedDgrad {
|
||||
params.problem_size,
|
||||
params.ptr_A,
|
||||
thread_idx,
|
||||
params.stride_h_divmod, params.stride_w_divmod,
|
||||
start_r, start_s,
|
||||
MatrixCoord(
|
||||
threadblock_tile_idx.m() * Mma::Shape::kM,
|
||||
@ -386,6 +389,7 @@ struct ImplicitGemmConvolutionStridedDgrad {
|
||||
params.ptr_D,
|
||||
ConvOutputIteratorParameter::extent(params.problem_size),
|
||||
thread_idx,
|
||||
params.stride_h_divmod, params.stride_w_divmod,
|
||||
start_r, start_s,
|
||||
threadblock_offset
|
||||
);
|
||||
@ -396,6 +400,7 @@ struct ImplicitGemmConvolutionStridedDgrad {
|
||||
params.ptr_C,
|
||||
ConvOutputIteratorParameter::extent(params.problem_size),
|
||||
thread_idx,
|
||||
params.stride_h_divmod, params.stride_w_divmod,
|
||||
start_r, start_s,
|
||||
threadblock_offset
|
||||
);
|
||||
|
||||
@ -130,7 +130,6 @@ private:
|
||||
int offset_p_[ThreadMap::Iterations::kStrided];
|
||||
int offset_q_[ThreadMap::Iterations::kStrided];
|
||||
|
||||
|
||||
public:
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
@ -139,6 +138,7 @@ public:
|
||||
Conv2dProblemSize const &problem_size,
|
||||
Element const *ptr,
|
||||
int thread_idx,
|
||||
FastDivmod const &stride_h_divmod, FastDivmod const &stride_w_divmod,
|
||||
int start_r, int start_s,
|
||||
MatrixCoord const &threadblock_offset = MatrixCoord() // threadblock offset - units are whole CTA tiles
|
||||
):
|
||||
@ -164,9 +164,12 @@ public:
|
||||
}
|
||||
|
||||
// Starting h, w positions for filter position in gemm_k=0
|
||||
int start_h = std::abs((problem_size_.pad_h - filter_r) % problem_size_.stride_h);
|
||||
int start_w = std::abs((problem_size_.pad_w - filter_s) % problem_size_.stride_w);
|
||||
|
||||
int start_h, start_w;
|
||||
strided_dgrad_starting_coords(
|
||||
problem_size_,
|
||||
stride_h_divmod, stride_w_divmod,
|
||||
filter_r, filter_s,
|
||||
start_h, start_w);
|
||||
|
||||
// Effective P and Q for filter position required for remapping NHW rows
|
||||
int P = (problem_size_.H - start_h + problem_size_.stride_h - 1) / problem_size_.stride_h;
|
||||
|
||||
@ -200,7 +200,27 @@ private:
|
||||
|
||||
public:
|
||||
|
||||
/// Constructor
|
||||
/// Constructor (output gradient (Dy) OperandA ctor)
|
||||
CUTLASS_HOST_DEVICE
|
||||
TileIteratorStridedDgrad(
|
||||
Params const ¶ms,
|
||||
ConvProblemSize const &problem_size,
|
||||
Element const *ptr,
|
||||
int thread_idx,
|
||||
FastDivmod const &stride_h_divmod, FastDivmod const &stride_w_divmod,
|
||||
int start_r, int start_s,
|
||||
MatrixCoord const &threadblock_offset = MatrixCoord()
|
||||
):
|
||||
tile_access_iterator_(
|
||||
params,
|
||||
problem_size,
|
||||
ptr,
|
||||
thread_idx,
|
||||
stride_h_divmod, stride_w_divmod,
|
||||
start_r, start_s,
|
||||
threadblock_offset) { }
|
||||
|
||||
/// Constructor (filter (w) OperandB ctor)
|
||||
CUTLASS_HOST_DEVICE
|
||||
TileIteratorStridedDgrad(
|
||||
Params const ¶ms,
|
||||
@ -210,7 +230,12 @@ public:
|
||||
int start_r, int start_s,
|
||||
MatrixCoord const &threadblock_offset = MatrixCoord()
|
||||
):
|
||||
tile_access_iterator_(params, problem_size, ptr, thread_idx, start_r, start_s, threadblock_offset) { }
|
||||
tile_access_iterator_(params,
|
||||
problem_size,
|
||||
ptr,
|
||||
thread_idx,
|
||||
start_r, start_s,
|
||||
threadblock_offset) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
static Params getParams(ConvProblemSize const &problem_size, Layout const &layout) {
|
||||
|
||||
@ -31,6 +31,12 @@
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#ifdef CUTLASS_NAMESPACE
|
||||
#define cutlass CUTLASS_NAMESPACE
|
||||
#endif
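This guard works together with the `CUTLASS_NAMESPACE` CMake cache variable and the `CUTLASS_NAMESPACE=${CUTLASS_NAMESPACE}` compile definition added in CMakeLists.txt above: when a custom namespace is configured, the preprocessor rewrites the `cutlass` token throughout the library. A hedged illustration (the namespace name below is hypothetical):

```cpp
// Assuming the build was configured with:
//   cmake .. -DCUTLASS_NAMESPACE=partner_cutlass
// the compile definition CUTLASS_NAMESPACE=partner_cutlass makes the macro
// above rewrite the `cutlass` namespace token, so two differently configured
// copies of the library can coexist in one binary without symbol clashes.

#include "cutlass/cutlass.h"

cutlass::Status report() {           // compiles as partner_cutlass::Status report()
  return cutlass::Status::kSuccess;  // compiles as partner_cutlass::Status::kSuccess
}
```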
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#define CUTLASS_UNUSED(expr) do { (void)(expr); } while (0)
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
|
||||
@ -174,12 +174,12 @@ public:
|
||||
// Convert to destination numeric type
|
||||
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
|
||||
|
||||
ComputeFragment converted_source = source_converter(source);
|
||||
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
|
||||
|
||||
if (Scale == ScaleType::Nothing)
|
||||
return destination_converter(converted_accumulator);
|
||||
|
||||
ComputeFragment converted_source = source_converter(source);
|
||||
|
||||
// Perform binary operations
|
||||
ComputeFragment intermediate;
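Reordering the conversion lets the functor skip converting the source fragment when `ScaleType::Nothing` is selected, since that mode ignores both alpha and beta. A scalar sketch of the two paths (illustrative names, not the actual CUTLASS functor):

```cpp
// Scalar sketch of the epilogue linear combination for a single element.
template <typename ElementOutput, typename ElementCompute, bool ScaleNothing>
ElementOutput linear_combination(ElementCompute accumulator,
                                 ElementCompute source,
                                 ElementCompute alpha,
                                 ElementCompute beta) {
  if (ScaleNothing) {
    // ScaleType::Nothing: D = accumulator, so the source fragment is never
    // needed and its conversion can be skipped (the point of the change above).
    return ElementOutput(accumulator);
  }
  // Default scaling: D = alpha * accumulator + beta * source.
  return ElementOutput(alpha * accumulator + beta * source);
}
```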
|
||||
|
||||
@ -309,9 +309,12 @@ struct DefaultEpilogueTensorOp {
|
||||
kElementsPerAccess
|
||||
>::Type;
|
||||
|
||||
static bool const UseCUDAStore = platform::is_same<ElementOutput, double>::value;
|
||||
|
||||
using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
|
||||
OutputTileThreadMap,
|
||||
ElementOutput
|
||||
ElementOutput,
|
||||
UseCUDAStore
|
||||
>;
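`UseCUDAStore` selects, in the epilogue `PredicatedTileIterator` shown further below, a plain guarded C++ store instead of the predicated `global_store` intrinsic, and it is enabled only for double-precision output. A tiny sketch of the selection logic (names hypothetical):

```cpp
// Sketch of the trait above: the plain-CUDA store path is chosen only when
// the epilogue writes double-precision output elements.
#include <type_traits>

template <typename ElementOutput>
struct PickStorePath {
  // Mirrors: static bool const UseCUDAStore = platform::is_same<ElementOutput, double>::value;
  static constexpr bool kUseCUDAStore = std::is_same<ElementOutput, double>::value;
};

static_assert(PickStorePath<double>::kUseCUDAStore,
              "double output takes the plain C++ store path");
static_assert(!PickStorePath<float>::kUseCUDAStore,
              "other output types keep the inline global_store path");
```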
|
||||
|
||||
using AccumulatorFragmentIterator = typename std::conditional<is_complex<ElementOutput>::value,
|
||||
|
||||
@ -62,7 +62,8 @@ namespace threadblock {
|
||||
///
|
||||
template <
|
||||
typename ThreadMap_, ///< Thread map (concept: OutputTileThreadMap)
|
||||
typename Element_ ///< Element data type
|
||||
typename Element_, ///< Element data type
|
||||
bool UseCUDAStore = false
|
||||
>
|
||||
class PredicatedTileIterator {
|
||||
public:
|
||||
@ -341,10 +342,17 @@ public:
|
||||
|
||||
bool guard = row_guard && mask_.predicates[column];
|
||||
|
||||
cutlass::arch::global_store<AccessType, sizeof(AccessType)>(
|
||||
frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column],
|
||||
(void *)&memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess],
|
||||
guard);
|
||||
if (UseCUDAStore) {
|
||||
if (guard) {
|
||||
memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess] =
|
||||
frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column];
|
||||
}
|
||||
} else {
|
||||
cutlass::arch::global_store<AccessType, sizeof(AccessType)>(
|
||||
frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column],
|
||||
(void *)&memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess],
|
||||
guard);
|
||||
}
|
||||
}
|
||||
|
||||
if (row + 1 < ThreadMap::Iterations::kRow) {
|
||||
|
||||
@ -222,6 +222,7 @@ public:
|
||||
Element *pointer,
|
||||
TensorCoord extent,
|
||||
int thread_idx,
|
||||
FastDivmod const &stride_h_divmod, FastDivmod const &stride_w_divmod,
|
||||
int start_r, int start_s,
|
||||
TensorCoord threadblock_offset = TensorCoord()
|
||||
):
|
||||
@ -238,9 +239,12 @@ public:
|
||||
s = (params_.problem_size.S - 1 - s);
|
||||
}
|
||||
|
||||
// check if start_h_ and start_w_ are always positive
|
||||
start_h_ = std::abs((params_.problem_size.pad_h - r) % params_.problem_size.stride_h);
|
||||
start_w_ = std::abs((params_.problem_size.pad_w - s) % params_.problem_size.stride_w);
|
||||
// compute starting coordinates in Dx start_h_ and start_w_
|
||||
strided_dgrad_starting_coords(
|
||||
params_.problem_size,
|
||||
stride_h_divmod, stride_w_divmod,
|
||||
r, s,
|
||||
start_h_, start_w_);
|
||||
|
||||
p_ = (params_.problem_size.H - start_h_ + params_.problem_size.stride_h - 1) / params_.problem_size.stride_h;
|
||||
q_ = (params_.problem_size.W - start_w_ + params_.problem_size.stride_w - 1) / params_.problem_size.stride_w;
|
||||
|
||||
@ -256,20 +256,7 @@ public:
|
||||
|
||||
|
||||
int offset = n * Detail::kLanesInQuad + pointer_offset / Policy::kElementsPerAccess;
|
||||
#if 0
|
||||
// Using inline PTX to avoid generic memory
|
||||
AccessType *smem_ptr = pointers_[ptr_idx];
|
||||
smem_ptr[offset] = frag_ptr[n];
|
||||
#else
|
||||
uint32_t smem_addr = arch::cutlass_get_smem_pointer(ptr);
|
||||
uint32_t const *data = reinterpret_cast<uint32_t const *>(frag_ptr + n);
|
||||
uint32_t offset_in_bytes = offset * sizeof(AccessType);
|
||||
|
||||
asm volatile(
|
||||
"{ .reg .u32 smem_ptr; add.u32 smem_ptr, %0, %1; st.shared.v2.u32 [smem_ptr], {%2, %3}; }\n"
|
||||
: : "r"(smem_addr), "r"(offset_in_bytes), "r"(data[0]), "r"(data[1])
|
||||
);
|
||||
#endif
|
||||
ptr[offset] = frag_ptr[n];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -455,14 +455,6 @@ public:
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
|
||||
result = cudaFuncSetAttribute(
|
||||
Kernel<GemmKernel>,
|
||||
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);
|
||||
|
||||
@ -445,14 +445,6 @@ public:
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
|
||||
result = cudaFuncSetAttribute(
|
||||
Kernel<GemmKernel>,
|
||||
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);
|
||||
|
||||
@ -423,14 +423,6 @@ public:
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
|
||||
result = cudaFuncSetAttribute(
|
||||
Kernel<GemmKernel>,
|
||||
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);
|
||||
|
||||
@ -437,14 +437,6 @@ public:
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
|
||||
result = cudaFuncSetAttribute(
|
||||
Kernel<GemmKernel>,
|
||||
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);
|
||||
|
||||
@ -438,14 +438,6 @@ public:
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
|
||||
result = cudaFuncSetAttribute(
|
||||
Kernel<GemmKernel>,
|
||||
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
|
||||
@ -352,14 +352,6 @@ public:
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
|
||||
result = cudaFuncSetAttribute(
|
||||
Kernel<GemmKernel>,
|
||||
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(gemm_params_);
|
||||
|
||||
@ -325,14 +325,6 @@ public:
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
|
||||
result = cudaFuncSetAttribute(
|
||||
Kernel<GemmKernel>,
|
||||
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
|
||||
@ -103,8 +103,8 @@ template <
|
||||
int Stages,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Use zfill or predicate for SM80 out-of-bound cp.async
|
||||
bool UseZfill = false,
|
||||
/// Use zfill or predicate for out-of-bound cp.async
|
||||
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
|
||||
///
|
||||
typename Enable = void>
|
||||
struct DefaultGemmWithKReduction {
|
||||
@ -116,7 +116,7 @@ struct DefaultGemmWithKReduction {
|
||||
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, kReduceKForA, arch::Sm80,
|
||||
ThreadblockShape, WarpShape, InstructionShape, Stages,
|
||||
Operator, false, UseZfill>::ThreadblockMma;
|
||||
Operator, false, SharedMemoryClear>::ThreadblockMma;
|
||||
|
||||
static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
|
||||
|
||||
|
||||
@ -34,6 +34,7 @@
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_coord.h"
|
||||
#include "cutlass/semaphore.h"
|
||||
#include "cutlass/arch/arch.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
@ -130,7 +130,6 @@ struct DefaultMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
|
||||
arch::OpClassSimt, ArchTag, ThreadblockShape, WarpShape,
|
||||
InstructionShape, 2, Operator, false, SharedMemoryClearOption::kNone> {
|
||||
|
||||
|
||||
static_assert(platform::is_same<LayoutC, layout::RowMajor>::value
|
||||
|| platform::is_same<LayoutC, layout::AffineRankN<2>>::value,
|
||||
"simt epilogue must be row major");
|
||||
|
||||
@ -141,8 +141,8 @@ struct DefaultMmaWithReductionCore {
|
||||
using SmemLayoutB = typename Base::SmemLayoutB;
|
||||
using WarpCount = typename Base::WarpCount;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpA = cutlass::arch::CacheOperation::Always;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB = cutlass::arch::CacheOperation::Always;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
|
||||
|
||||
// Define the warp-level tensor op
|
||||
using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaWithReductionTensorOp<
|
||||
|
||||
@ -82,9 +82,10 @@ template <
|
||||
/// when output layout is interleaved.
|
||||
bool AccumulatorsInRowMajor = false,
|
||||
/// Use zfill or predicate for SM80 out-of-bound cp.async
|
||||
bool UseZfill = false
|
||||
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone
|
||||
>
|
||||
struct DefaultMmaWithReduction {
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpA =
|
||||
((sizeof_bits<ElementA>::value * kAlignmentA) == 128)
|
||||
? cutlass::arch::CacheOperation::Global
|
||||
@ -122,7 +123,7 @@ struct DefaultMmaWithReduction {
|
||||
typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA,
|
||||
MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
|
||||
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor,
|
||||
typename MmaCore::MmaPolicy, Stages, UseZfill>;
|
||||
typename MmaCore::MmaPolicy, Stages, SharedMemoryClear>;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -303,10 +303,8 @@ public:
|
||||
for (int stage = 0; stage < Base::kStages - 1;
|
||||
++stage, --gemm_k_iterations) {
|
||||
|
||||
if (gemm_k_iterations == 0) {
|
||||
iterator_A.clear_mask();
|
||||
iterator_B.clear_mask();
|
||||
}
|
||||
iterator_A.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B.clear_mask(gemm_k_iterations == 0);
|
||||
|
||||
iterator_A.set_iteration_index(0);
|
||||
this->smem_iterator_A_.set_iteration_index(0);
|
||||
@ -447,10 +445,8 @@ public:
|
||||
++this->warp_tile_iterator_A_;
|
||||
++this->warp_tile_iterator_B_;
|
||||
|
||||
if (gemm_k_iterations == 0) {
|
||||
iterator_A.clear_mask();
|
||||
iterator_B.clear_mask();
|
||||
}
|
||||
iterator_A.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B.clear_mask(gemm_k_iterations == 0);
|
||||
|
||||
int smem_write_stage_idx = Base::kStages - 1;
|
||||
int smem_read_stage_idx = 0;
|
||||
@ -558,10 +554,8 @@ public:
|
||||
}
|
||||
|
||||
--gemm_k_iterations;
|
||||
if (gemm_k_iterations == 0) {
|
||||
iterator_A.clear_mask();
|
||||
iterator_B.clear_mask();
|
||||
}
|
||||
iterator_A.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B.clear_mask(gemm_k_iterations == 0);
|
||||
}
|
||||
|
||||
// Do any conversions feeding the first stage at the end of the loop so
|
||||
|
||||
@ -231,10 +231,8 @@ public:
|
||||
int smem_write_stage_idx = 1;
|
||||
|
||||
// Avoid reading out of bounds
|
||||
if (gemm_k_iterations <= 1) {
|
||||
iterator_A.clear_mask();
|
||||
iterator_B.clear_mask();
|
||||
}
|
||||
iterator_A.clear_mask(gemm_k_iterations <= 1);
|
||||
iterator_B.clear_mask(gemm_k_iterations <= 1);
|
||||
|
||||
// Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
|
||||
// shared memory loads (which have the tightest latency requirement).
|
||||
@ -302,10 +300,8 @@ public:
|
||||
++iterator_B;
|
||||
|
||||
// Avoid reading out of bounds if this was the last loop iteration
|
||||
if (gemm_k_iterations <= 2) {
|
||||
iterator_A.clear_mask();
|
||||
iterator_B.clear_mask();
|
||||
}
|
||||
iterator_A.clear_mask(gemm_k_iterations <= 2);
|
||||
iterator_B.clear_mask(gemm_k_iterations <= 2);
|
||||
}
|
||||
|
||||
warp_mma(accum, warp_frag_A[warp_mma_k % 2],
|
||||
|
||||
@ -370,12 +370,10 @@ public:
|
||||
for (int stage = 0; stage < Base::kStages - 1;
|
||||
++stage, --gemm_k_iterations) {
|
||||
|
||||
if (gemm_k_iterations == 0) {
|
||||
iterator_A_real.clear_mask();
|
||||
iterator_A_imag.clear_mask();
|
||||
iterator_B_real.clear_mask();
|
||||
iterator_B_imag.clear_mask();
|
||||
}
|
||||
iterator_A_real.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_A_imag.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B_real.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B_imag.clear_mask(gemm_k_iterations == 0);
|
||||
|
||||
iterator_A_real.set_iteration_index(0);
|
||||
iterator_A_imag.set_iteration_index(0);
|
||||
@ -501,12 +499,10 @@ public:
|
||||
++this->warp_tile_iterator_A_;
|
||||
++this->warp_tile_iterator_B_;
|
||||
|
||||
if (gemm_k_iterations == 0) {
|
||||
iterator_A_real.clear_mask();
|
||||
iterator_A_imag.clear_mask();
|
||||
iterator_B_real.clear_mask();
|
||||
iterator_B_imag.clear_mask();
|
||||
}
|
||||
iterator_A_real.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_A_imag.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B_real.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B_imag.clear_mask(gemm_k_iterations == 0);
|
||||
|
||||
// Start issuing the first group of the next stage outside of the mainloop
|
||||
copy_tiles_and_advance(iterator_A_real, iterator_A_imag, iterator_B_real, iterator_B_imag);
|
||||
@ -611,12 +607,10 @@ public:
|
||||
}
|
||||
|
||||
--gemm_k_iterations;
|
||||
if (gemm_k_iterations == 0) {
|
||||
iterator_A_real.clear_mask();
|
||||
iterator_A_imag.clear_mask();
|
||||
iterator_B_real.clear_mask();
|
||||
iterator_B_imag.clear_mask();
|
||||
}
|
||||
iterator_A_real.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_A_imag.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B_real.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B_imag.clear_mask(gemm_k_iterations == 0);
|
||||
}
|
||||
|
||||
warp_mma_planar_complex(
|
||||
|
||||
@ -308,13 +308,11 @@ public:
|
||||
int smem_write_stage_idx = 1;
|
||||
|
||||
// Avoid reading out of bounds
|
||||
if (gemm_k_iterations <= 1) {
|
||||
iterator_A_real.clear_mask();
|
||||
iterator_A_imag.clear_mask();
|
||||
|
||||
iterator_B_real.clear_mask();
|
||||
iterator_B_imag.clear_mask();
|
||||
}
|
||||
iterator_A_real.clear_mask(gemm_k_iterations <= 1);
|
||||
iterator_A_imag.clear_mask(gemm_k_iterations <= 1);
|
||||
|
||||
iterator_B_real.clear_mask(gemm_k_iterations <= 1);
|
||||
iterator_B_imag.clear_mask(gemm_k_iterations <= 1);
|
||||
|
||||
// Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
|
||||
// shared memory loads (which have the tightest latency requirement).
|
||||
@ -392,12 +390,10 @@ public:
|
||||
++iterator_B_imag;
|
||||
|
||||
// Avoid reading out of bounds if this was the last loop iteration
|
||||
if (gemm_k_iterations <= 2) {
|
||||
iterator_A_real.clear_mask();
|
||||
iterator_A_imag.clear_mask();
|
||||
iterator_B_real.clear_mask();
|
||||
iterator_B_imag.clear_mask();
|
||||
}
|
||||
iterator_A_real.clear_mask(gemm_k_iterations <= 2);
|
||||
iterator_A_imag.clear_mask(gemm_k_iterations <= 2);
|
||||
iterator_B_real.clear_mask(gemm_k_iterations <= 2);
|
||||
iterator_B_imag.clear_mask(gemm_k_iterations <= 2);
|
||||
}
|
||||
|
||||
warp_mma_planar_complex(
|
||||
|
||||
@ -196,10 +196,8 @@ public:
|
||||
Operator warp_mma;
|
||||
|
||||
// Avoid reading out of bounds
|
||||
if (gemm_k_iterations <= 1) {
|
||||
iterator_A.clear_mask();
|
||||
iterator_B.clear_mask();
|
||||
}
|
||||
iterator_A.clear_mask(gemm_k_iterations <= 1);
|
||||
iterator_B.clear_mask(gemm_k_iterations <= 1);
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
@ -247,10 +245,8 @@ public:
|
||||
++iterator_B;
|
||||
|
||||
// Avoid reading out of bounds if this was the last loop iteration
|
||||
if (gemm_k_iterations <= 2) {
|
||||
iterator_A.clear_mask();
|
||||
iterator_B.clear_mask();
|
||||
}
|
||||
iterator_A.clear_mask(gemm_k_iterations <= 2);
|
||||
iterator_B.clear_mask(gemm_k_iterations <= 2);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@ -379,11 +379,9 @@ public:
|
||||
for (int stage = 0; stage < Base::kStages - 1;
|
||||
++stage, --gemm_k_iterations) {
|
||||
|
||||
if (gemm_k_iterations == 0) {
|
||||
iterator_A.clear_mask();
|
||||
iterator_B.clear_mask();
|
||||
iterator_E.clear_mask();
|
||||
}
|
||||
iterator_A.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_E.clear_mask(gemm_k_iterations == 0);
|
||||
|
||||
iterator_A.set_iteration_index(0);
|
||||
this->smem_iterator_A_.set_iteration_index(0);
|
||||
@ -500,11 +498,9 @@ public:
|
||||
++this->warp_tile_iterator_B_;
|
||||
++this->warp_tile_iterator_E_;
|
||||
|
||||
if (gemm_k_iterations == 0) {
|
||||
iterator_A.clear_mask();
|
||||
iterator_B.clear_mask();
|
||||
iterator_E.clear_mask();
|
||||
}
|
||||
iterator_A.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_E.clear_mask(gemm_k_iterations == 0);
|
||||
|
||||
int smem_write_stage_idx = Base::kStages - 1;
|
||||
int smem_read_stage_idx = 0;
|
||||
@ -637,11 +633,9 @@ public:
|
||||
}
|
||||
|
||||
--gemm_k_iterations;
|
||||
if (gemm_k_iterations == 0) {
|
||||
iterator_A.clear_mask();
|
||||
iterator_B.clear_mask();
|
||||
iterator_E.clear_mask();
|
||||
}
|
||||
iterator_A.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_E.clear_mask(gemm_k_iterations == 0);
|
||||
}
|
||||
|
||||
// Do any conversions feeding the first stage at the end of the loop so
|
||||
|
||||
@ -78,7 +78,7 @@ template <
|
||||
/// Number of stages,
|
||||
int Stages,
|
||||
/// Use zfill or predicate for out-of-bound cp.async
|
||||
bool UseZfill = false,
|
||||
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
|
||||
/// Used for partial specialization
|
||||
typename Enable = bool>
|
||||
class MmaWithReductionMultistage :
|
||||
@ -234,7 +234,7 @@ public:
|
||||
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
|
||||
auto gmem_ptr = iterator_A.get();
|
||||
|
||||
if (UseZfill) {
|
||||
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
|
||||
dst_ptr + v, gmem_ptr, iterator_A.valid());
|
||||
} else {
|
||||
@ -269,7 +269,7 @@ public:
|
||||
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
|
||||
auto gmem_ptr = iterator_B.get();
|
||||
|
||||
if (UseZfill) {
|
||||
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
|
||||
dst_ptr + v, gmem_ptr, iterator_B.valid());
|
||||
} else {
|
||||
@ -302,16 +302,14 @@ public:
|
||||
//
|
||||
// Prologue
|
||||
//
|
||||
|
||||
// Issue several complete stages
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int stage = 0; stage < Base::kStages - 1;
|
||||
++stage, --gemm_k_iterations) {
|
||||
|
||||
if (gemm_k_iterations == 0) {
|
||||
iterator_A.clear_mask();
|
||||
iterator_B.clear_mask();
|
||||
}
|
||||
iterator_A.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B.clear_mask(gemm_k_iterations == 0);
|
||||
|
||||
iterator_A.set_iteration_index(0);
|
||||
this->smem_iterator_A_.set_iteration_index(0);
|
||||
@ -403,10 +401,8 @@ public:
|
||||
++this->warp_tile_iterator_A_;
|
||||
++this->warp_tile_iterator_B_;
|
||||
|
||||
if (gemm_k_iterations == 0) {
|
||||
iterator_A.clear_mask();
|
||||
iterator_B.clear_mask();
|
||||
}
|
||||
iterator_A.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B.clear_mask(gemm_k_iterations == 0);
|
||||
|
||||
int smem_write_stage_idx = Base::kStages - 1;
|
||||
int smem_read_stage_idx = 0;
|
||||
@ -515,10 +511,8 @@ public:
|
||||
}
|
||||
|
||||
--gemm_k_iterations;
|
||||
if (gemm_k_iterations == 0) {
|
||||
iterator_A.clear_mask();
|
||||
iterator_B.clear_mask();
|
||||
}
|
||||
iterator_A.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B.clear_mask(gemm_k_iterations == 0);
|
||||
}
|
||||
|
||||
// Do any conversions feeding the first stage at the end of the loop so
|
||||
@ -532,7 +526,7 @@ public:
|
||||
|
||||
}
|
||||
|
||||
if (UseZfill) {
|
||||
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
|
||||
// commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop
|
||||
cutlass::arch::cp_async_fence();
|
||||
cutlass::arch::cp_async_wait<0>();
|
||||
|
||||
@ -49,7 +49,6 @@ class MmaTensorOpFragmentIterator;
|
||||
|
||||
|
||||
// Partial specialization for col-major accumulator tile
|
||||
// And Element type is the same as Accumulator Element type
|
||||
|
||||
template <
|
||||
/// Shape of warp tile to load (concept: MatrixShape)
|
||||
@ -58,13 +57,15 @@ template <
|
||||
typename AccumulatorShape_,
|
||||
/// KBlocks columns to compute residual
|
||||
int KBlocksColumn_,
|
||||
/// Accumulator Element type
|
||||
typename ElementAccumulator_,
|
||||
/// Element type
|
||||
typename Element_,
|
||||
/// Shape of one matrix product operation (concept: MatrixShape)
|
||||
typename InstructionShape_,
|
||||
/// Output operation on fragment
|
||||
typename OutputOp_>
|
||||
class MmaTensorOpFragmentIterator<Shape_, AccumulatorShape_, KBlocksColumn_, Element_, Element_,
|
||||
class MmaTensorOpFragmentIterator<Shape_, AccumulatorShape_, KBlocksColumn_, ElementAccumulator_, Element_,
|
||||
cutlass::layout::ColumnMajor,
|
||||
InstructionShape_, OutputOp_> {
|
||||
public:
|
||||
@ -78,6 +79,9 @@ class MmaTensorOpFragmentIterator<Shape_, AccumulatorShape_, KBlocksColumn_, Ele
|
||||
/// KBlocks columns to compute residual
|
||||
static int const kKBlockColumn = KBlocksColumn_;
|
||||
|
||||
/// Accumulator Element type
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
|
||||
/// Element type
|
||||
using Element = Element_;
|
||||
|
||||
@ -143,13 +147,14 @@ public:
|
||||
using Fragment = Array<Element, Shape::kCount / kThreads>;
|
||||
|
||||
/// Accumulator Fragment object
|
||||
using AccumulatorFragment = Array<Element, AccumulatorShape::kCount / kThreads>;
|
||||
using AccumulatorFragment = Array<ElementAccumulator, AccumulatorShape::kCount / kThreads>;
|
||||
|
||||
|
||||
private:
|
||||
|
||||
/// Internal access type
|
||||
using AccessType = Array<Element, kElementsPerAccess>;
|
||||
using AccessType = Array<ElementAccumulator, kElementsPerAccess>;
|
||||
using FragmentAccessType = Array<Element, kElementsPerAccess>;
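This specialization now distinguishes the accumulator element type from the fragment element type, so a fused kernel can read, for example, float accumulators while handing half-precision fragments to the second GEMM. A small sketch of the type pairing with a converter bridging the two (sizes and types are arbitrary):

```cpp
// Sketch: separate array types for the accumulator source and the converted
// fragment, in the spirit of AccumulatorFragment / Fragment above.
#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/numeric_types.h"
#include "cutlass/numeric_conversion.h"

constexpr int kElements = 8;

using ElementAccumulator = float;            // what the tensor cores accumulate in
using Element            = cutlass::half_t;  // what the next stage consumes

using AccumulatorAccessType = cutlass::Array<ElementAccumulator, kElements>;
using FragmentAccessType    = cutlass::Array<Element, kElements>;

CUTLASS_HOST_DEVICE
FragmentAccessType convert(AccumulatorAccessType const &accumulators) {
  cutlass::NumericArrayConverter<Element, ElementAccumulator, kElements> converter;
  return converter(accumulators);  // float -> half_t, element by element
}
```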
|
||||
|
||||
private:
|
||||
//
|
||||
@ -203,10 +208,10 @@ public:
|
||||
if (output_op.is_source_needed()) //beta must be zero
|
||||
assert(0);
|
||||
|
||||
AccessType src_fragment;
|
||||
FragmentAccessType src_fragment;
|
||||
src_fragment.clear();
|
||||
|
||||
AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
|
||||
FragmentAccessType *frag_ptr = reinterpret_cast<FragmentAccessType *>(&frag);
|
||||
|
||||
int index = index_ * MmaIterations::kCount;
|
||||
|
||||
|
||||
@ -14030,15 +14030,15 @@ struct Matrix<Element_, 4, 4> {
|
||||
|
||||
/// Returns a perspective projection matrix typical of OpenGL applications
|
||||
CUTLASS_HOST_DEVICE
|
||||
static Matrix perspective(Element near, Element far, Element fovH, Element fovV) {
|
||||
static Matrix perspective(Element near_plane, Element far_plane, Element fovH, Element fovV) {
|
||||
Element aspect = fovH / fovV;
|
||||
Element f = Element(cos(fovV)) / Element(fovH);
|
||||
Element Q = near - far;
|
||||
Element Q = near_plane - far_plane;
|
||||
|
||||
return Matrix(
|
||||
f / aspect, 0, 0, 0,
|
||||
0, f, 0, 0,
|
||||
0, 0, (near + far) / Q, Element(2) * far * near / Q,
|
||||
0, 0, (near_plane + far_plane) / Q, Element(2) * far_plane * near_plane / Q,
|
||||
0, 0, -1, 0
|
||||
);
|
||||
}
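For reference, the matrix assembled above, written out with $n$ = `near_plane`, $f$ = `far_plane`, $Q = n - f$, and the code's scale factor $c = \cos(\mathrm{fovV}) / \mathrm{fovH}$; the rename itself only avoids the identifiers `near` and `far`, which some platform headers reserve as macros:

$$
P =
\begin{pmatrix}
c\,\dfrac{\mathrm{fovV}}{\mathrm{fovH}} & 0 & 0 & 0 \\
0 & c & 0 & 0 \\
0 & 0 & \dfrac{n+f}{Q} & \dfrac{2fn}{Q} \\
0 & 0 & -1 & 0
\end{pmatrix},
\qquad Q = n - f .
$$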
|
||||
|
||||
@ -245,10 +245,10 @@ class PredicatedTileAccessIteratorPredicates {
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask() {
|
||||
void clear_mask(bool enable = true) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kPredicateWordCount; ++i) {
|
||||
predicates_[i] = 0u;
|
||||
predicates_[i] = enable ? 0u : predicates_[i];
|
||||
}
|
||||
|
||||
}
|
||||
@ -551,8 +551,8 @@ class PredicatedTileAccessIterator<Shape_, Element_, layout::PitchLinear,
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask() {
|
||||
the_predicates.clear_mask();
|
||||
void clear_mask(bool enable = true) {
|
||||
the_predicates.clear_mask(enable);
|
||||
}
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
@ -741,7 +741,7 @@ class PredicatedTileAccessIterator<Shape_, Element_, layout::ColumnMajor,
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask() { iterator_.clear_mask(); }
|
||||
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
@ -922,7 +922,7 @@ class PredicatedTileAccessIterator<Shape_, Element_, layout::RowMajor,
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask() { iterator_.clear_mask(); }
|
||||
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
@ -1224,7 +1224,7 @@ class PredicatedTileAccessIterator<Shape_, Element_, layout::AffineRankN<2>,
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask() { the_predicates.clear_mask(); }
|
||||
void clear_mask(bool enable = true) { the_predicates.clear_mask(enable); }
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
@ -1401,7 +1401,7 @@ class PredicatedTileAccessIterator<Shape_, Element_, layout::AffineRank2ColumnMa
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask() { iterator_.clear_mask(); }
|
||||
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
@ -1578,7 +1578,7 @@ class PredicatedTileAccessIterator<Shape_, Element_, layout::AffineRank2RowMajor
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask() { iterator_.clear_mask(); }
|
||||
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
@ -1764,7 +1764,7 @@ class PredicatedTileAccessIterator<Shape_, Element_,
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask() { iterator_.clear_mask(); }
|
||||
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
@ -1948,7 +1948,7 @@ class PredicatedTileAccessIterator<Shape_, Element_,
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask() { iterator_.clear_mask(); }
|
||||
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
|
||||
@ -403,10 +403,10 @@ class PredicatedTileAccessIterator2dThreadTile<Shape_, Element_, layout::PitchLi
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask() {
|
||||
void clear_mask(bool enable = true) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kPredicateWordCount; ++i) {
|
||||
predicates_[i] = 0u;
|
||||
predicates_[i] = enable ? 0u : predicates_[i];
|
||||
}
|
||||
|
||||
}
|
||||
@ -617,7 +617,7 @@ class PredicatedTileAccessIterator2dThreadTile<Shape_, Element_, layout::ColumnM
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask() { iterator_.clear_mask(); }
|
||||
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
@ -796,7 +796,7 @@ class PredicatedTileAccessIterator2dThreadTile<Shape_, Element_, layout::RowMajo
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask() { iterator_.clear_mask(); }
|
||||
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
|
||||
@ -288,7 +288,7 @@ class PredicatedTileIterator<Shape_, Element_, layout::PitchLinear, AdvanceRank,
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask() { address_iterator_.clear_mask(); }
|
||||
void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); }
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
@ -530,8 +530,8 @@ public:
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask() {
|
||||
iterator_.clear_mask();
|
||||
void clear_mask(bool enable = true) {
|
||||
iterator_.clear_mask(enable);
|
||||
}
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
@ -738,8 +738,8 @@ public:
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask() {
|
||||
iterator_.clear_mask();
|
||||
void clear_mask(bool enable = true) {
|
||||
iterator_.clear_mask(enable);
|
||||
}
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
@ -946,7 +946,7 @@ class PredicatedTileIterator<Shape_, Element_, layout::AffineRankN<2>, AdvanceRa
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask() { address_iterator_.clear_mask(); }
|
||||
void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); }
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
@ -1184,8 +1184,8 @@ public:
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask() {
|
||||
iterator_.clear_mask();
|
||||
void clear_mask(bool enable = true) {
|
||||
iterator_.clear_mask(enable);
|
||||
}
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
@ -1388,8 +1388,8 @@ public:
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask() {
|
||||
iterator_.clear_mask();
|
||||
void clear_mask(bool enable = true) {
|
||||
iterator_.clear_mask(enable);
|
||||
}
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
@ -1600,7 +1600,7 @@ class PredicatedTileIterator<Shape_, Element_,
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask() { iterator_.clear_mask(); }
|
||||
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
@ -1785,7 +1785,7 @@ class PredicatedTileIterator<Shape_, Element_,
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask() { iterator_.clear_mask(); }
|
||||
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
|
||||
@ -293,7 +293,7 @@ class PredicatedTileIterator2dThreadTile<Shape_, Element_, layout::PitchLinear,
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask() { address_iterator_.clear_mask(); }
|
||||
void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); }
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
@ -525,8 +525,8 @@ public:
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask() {
|
||||
iterator_.clear_mask();
|
||||
void clear_mask(bool enable = true) {
|
||||
iterator_.clear_mask(enable);
|
||||
}
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
@ -721,8 +721,8 @@ public:
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
CUTLASS_HOST_DEVICE
|
||||
void clear_mask() {
|
||||
iterator_.clear_mask();
|
||||
void clear_mask(bool enable = true) {
|
||||
iterator_.clear_mask(enable);
|
||||
}
|
||||
|
||||
/// Clears the predicate set efficiently
|
||||
|
||||
@ -103,7 +103,6 @@ Profiling:

  --profiling-enabled=<bool>                  If true, profiling is actually conducted.

Verification:
  --verification-enabled=<bool>               Whether to perform verification checks.
@ -206,9 +206,12 @@ $ cmake .. -DCUTLASS_NVCC_ARCHS="50;53" # compiles for NVIDIA Maxwell G

## Clang

For experimental purposes, CUTLASS may be compiled with
[clang 8.0](https://github.com/llvm/llvm-project/releases/download/llvmorg-8.0.1/clang+llvm-8.0.1-amd64-unknown-freebsd11.tar.xz) using the
For experimental purposes, CUTLASS has been verified to compile with the following versions of Clang and CUDA.

* [clang 8.0](https://github.com/llvm/llvm-project/releases/download/llvmorg-8.0.1/clang+llvm-8.0.1-amd64-unknown-freebsd11.tar.xz) using the
[CUDA 10.0 Toolkit](https://developer.nvidia.com/cuda-10.0-download-archive).
* [clang release/13.x](https://github.com/llvm/llvm-project/tree/release/13.x) using [CUDA 11.4](https://developer.nvidia.com/cuda-toolkit-archive)

At this time, compiling with clang enables the CUTLASS SIMT GEMM kernels (sgemm, dgemm, hgemm, igemm)
but does not enable TensorCores.

@ -216,6 +219,8 @@ but does not enable TensorCores.
$ mkdir build && cd build

$ cmake -DCUDA_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ ..
# Add -DCMAKE_CXX_FLAGS=-D__NV_NO_HOST_COMPILER_CHECK=1 -DCMAKE_CUDA_FLAGS=-D__NV_NO_HOST_COMPILER_CHECK=1 if compiler
# checks fail during CMake configuration.

$ make test_unit -j
```
@ -26,9 +26,9 @@
|
||||
#pragma once
|
||||
#pragma warning (disable : 4068 ) /* disable unknown pragma warnings for Visual Studio */
|
||||
|
||||
#pragma diag_suppress boolean_controlling_expr_is_constant
|
||||
#pragma nv_diag_suppress boolean_controlling_expr_is_constant
|
||||
#include <gtest/gtest.h>
|
||||
#pragma diag_warning boolean_controlling_expr_is_constant
|
||||
#pragma nv_diag_warning boolean_controlling_expr_is_constant
|
||||
#pragma warning( disable : 4503)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -281,6 +281,22 @@ struct TestbedConv2dProblemSizes {
|
||||
{1, 1} // dilation (dilation_h, dilation_w)
|
||||
));
|
||||
|
||||
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
||||
{1, 8, 8, 8}, // input size (NHWC)
|
||||
{8, 3, 3, 8}, // filter size (KRSC)
|
||||
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
|
||||
{3, 3}, // stride (stride_h, stride_w)
|
||||
{1, 1} // dilation (dilation_h, dilation_w)
|
||||
));
|
||||
|
||||
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
||||
{1, 16, 16, 8}, // input size (NHWC)
|
||||
{8, 3, 3, 8}, // filter size (KRSC)
|
||||
{3, 3, 3, 3}, // padding (pad_h, _, pad_w, _)
|
||||
{3, 3}, // stride (stride_h, stride_w)
|
||||
{1, 1} // dilation (dilation_h, dilation_w)
|
||||
));
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////
|
||||
// Medium input size (1x16x16x128), filter size (1x1, 2x2, 3x3, 5x5), stride (1, 1)
|
||||
////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -389,6 +405,25 @@ struct TestbedConv2dProblemSizes {
|
||||
{1, 1} // dilation (dilation_h, dilation_w)
|
||||
));
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////
|
||||
// Medium input size padding > stride, asymmetric filter, padding and striding
|
||||
////////////////////////////////////////////////////////////////////////////////////
|
||||
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
||||
{1, 27, 31, 256}, // input size (NHWC)
|
||||
{512, 3, 3, 256}, // filter size (KRSC)
|
||||
{5, 5, 7, 7}, // padding (pad_h, _, pad_w, _)
|
||||
{3, 4}, // stride (stride_h, stride_w)
|
||||
{1, 1} // dilation (dilation_h, dilation_w)
|
||||
));
|
||||
|
||||
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
||||
{1, 27, 35, 256}, // input size (NHWC)
|
||||
{512, 7, 5, 256}, // filter size (KRSC)
|
||||
{11, 11, 7, 7}, // padding (pad_h, _, pad_w, _)
|
||||
{3, 5}, // stride (stride_h, stride_w)
|
||||
{1, 1} // dilation (dilation_h, dilation_w)
|
||||
));
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////
|
||||
// Medium input size *mixed* stride (1, 2) and (2, 1),
|
||||
// filter (3, 3), default padding
|
||||
@ -419,6 +454,14 @@ struct TestbedConv2dProblemSizes {
|
||||
{2, 2}, // stride (stride_h, stride_w)
|
||||
{1, 1} // dilation (dilation_h, dilation_w)
|
||||
));
|
||||
|
||||
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
||||
{1, 32, 32, 16}, // input size (NHWC)
|
||||
{32, 3, 3, 16}, // filter size (KRSC)
|
||||
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
|
||||
{6, 2}, // stride (stride_h, stride_w)
|
||||
{1, 1} // dilation (dilation_h, dilation_w)
|
||||
));
|
||||
|
||||
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
||||
{32, 32, 32, 32}, // input size (NHWC)
|
||||
|
||||
@ -78,23 +78,15 @@ TEST(SM80_Device_Conv2d_Strided_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32n
|
||||
|
||||
test::conv::device::Conv2dProblemVector problem_size_list;
|
||||
|
||||
|
||||
#if 0 // run specific problem size in the unit test first
|
||||
problem_size_list.push_back(cutlass::conv::Conv2dProblemSize(
|
||||
{1, 56, 56, 8}, // input size (NHWC)
|
||||
{8, 1, 1, 8}, // filter size (KRSC)
|
||||
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
||||
{2, 2}, // stride (stride_h, stride_w)
|
||||
{1, 1} // dilation (dilation_h, dilation_w)
|
||||
{1, 4, 4, 8}, // input size (NHWC)
|
||||
{8, 3, 3, 8}, // filter size (KRSC)
|
||||
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
||||
{3, 3}, // stride (stride_h, stride_w)
|
||||
{1, 1} // dilation (dilation_h, dilation_w)
|
||||
));
|
||||
|
||||
problem_size_list.push_back(cutlass::conv::Conv2dProblemSize(
|
||||
{1, 55, 55, 8}, // input size (NHWC)
|
||||
{8, 1, 1, 8}, // filter size (KRSC)
|
||||
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
||||
{2, 2}, // stride (stride_h, stride_w)
|
||||
{1, 1} // dilation (dilation_h, dilation_w)
|
||||
));
|
||||
|
||||
#endif
|
||||
|
||||
/// Run all unit test sizes with device-level Conv2d instance
|
||||
|
||||
@ -25,6 +25,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <stddef.h>
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace nvrtc {
|
||||
|
||||
@ -1312,6 +1312,7 @@ def GenerateSM80_TensorOp_16816(manifest, args):
|
||||
TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 128, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 64, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 128, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
@ -1392,7 +1393,7 @@ def GenerateSM80_SparseTensorOp_16832(manifest, args):
|
||||
max_cc = 1024
|
||||
max_cc_smem_limited = 80
|
||||
|
||||
alignment_constraints = [8, 4, 2]
|
||||
alignment_constraints = [8]
|
||||
|
||||
for math_inst in math_instructions:
|
||||
tile_descriptions = [
|
||||
@ -1967,6 +1968,8 @@ def GenerateSM80_TensorOp_1688(manifest, args):
|
||||
TileDescription([256, 64, 16], 4, [4, 1, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 256, 16], 4, [1, 4, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 128, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 128, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 64, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 128, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 64, 16], 10, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
@ -2051,6 +2054,8 @@ def GenerateSM80_TensorOp_1688_fast_math(manifest, args):
|
||||
TileDescription([256, 64, 16], 4, [4, 1, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 256, 16], 4, [1, 4, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 128, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 128, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([128, 64, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 128, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
TileDescription([ 64, 64, 16], 10, [2, 2, 1], math_inst, min_cc, max_cc),
|
||||
@ -2100,7 +2105,7 @@ def GenerateSM80_SparseTensorOp_16816_fast_math(manifest, args):
|
||||
max_cc = 1024
|
||||
max_cc_smem_limited = 80
|
||||
|
||||
alignment_constraints = [4, 2, 1]
|
||||
alignment_constraints = [4]
|
||||
|
||||
for math_inst in math_instructions:
|
||||
tile_descriptions = [
|
||||
|
||||
@ -20,7 +20,6 @@ class EmitOperationKindLibrary:
|
||||
self.generated_path = generated_path
|
||||
self.kind = kind
|
||||
self.args = args
|
||||
|
||||
self.emitters = {
|
||||
OperationKind.Gemm: EmitGemmConfigurationLibrary
|
||||
, OperationKind.Conv2d: EmitConv2dConfigurationLibrary
|
||||
@ -347,7 +346,7 @@ class Manifest:
|
||||
|
||||
with interface_emitters[target](generated_path, self.operation_count, self.args) as iface_emitter:
|
||||
for operation_kind, configurations in self.operations.items():
|
||||
iface_emitter.emit(OperationKindNames[operation_kind])
|
||||
iface_emitter.emit(OperationKindNames[operation_kind])
|
||||
|
||||
source_files += iface_emitter.source_files
|
||||
|
||||
|
||||
@ -186,12 +186,6 @@ public:
|
||||
GemmUniversalConfiguration const &config = *static_cast<GemmUniversalConfiguration const *>(host_workspace);
|
||||
GemmUniversalArguments const &args = *static_cast<GemmUniversalArguments const *>(arguments);
|
||||
|
||||
ElementCompute alpha;
|
||||
ElementCompute beta;
|
||||
|
||||
alpha = *static_cast<ElementCompute const *>(args.alpha);
|
||||
beta = *static_cast<ElementCompute const *>(args.beta);
|
||||
|
||||
TensorRefA ref_A{static_cast<ElementA *>(const_cast<void *>(args.A)), LayoutA(int(config.lda))};
|
||||
TensorRefB ref_B{static_cast<ElementB *>(const_cast<void *>(args.B)), LayoutB(int(config.ldb))};
|
||||
TensorRefC ref_C{static_cast<ElementC *>(const_cast<void *>(args.C)), LayoutC(int(config.ldc))};
|
||||
@ -212,16 +206,16 @@ public:
|
||||
InnerProductOp
|
||||
>(
|
||||
config.problem_size,
|
||||
alpha,
|
||||
*static_cast<ElementCompute const *>(args.alpha),
|
||||
ref_A,
|
||||
kTransformA,
|
||||
ref_B,
|
||||
kTransformB,
|
||||
beta,
|
||||
*static_cast<ElementCompute const *>(args.beta),
|
||||
ref_C,
|
||||
ref_D,
|
||||
ElementAccumulator(),
|
||||
config.batch_count,
|
||||
((config.mode == library::GemmUniversalMode::kBatched) ? config.batch_count : 1),
|
||||
args.batch_stride_A,
|
||||
args.batch_stride_B,
|
||||
args.batch_stride_C,
|
||||
@ -245,16 +239,16 @@ public:
|
||||
InnerProductOp
|
||||
>(
|
||||
config.problem_size,
|
||||
alpha,
|
||||
*static_cast<ElementCompute const *>(args.alpha),
|
||||
ref_A,
|
||||
kTransformA,
|
||||
ref_B,
|
||||
kTransformB,
|
||||
beta,
|
||||
*static_cast<ElementCompute const *>(args.beta),
|
||||
ref_C,
|
||||
ref_D,
|
||||
ElementAccumulator(),
|
||||
config.batch_count,
|
||||
((config.mode == library::GemmUniversalMode::kBatched) ? config.batch_count : 1),
|
||||
args.batch_stride_A,
|
||||
args.batch_stride_B,
|
||||
args.batch_stride_C,
|
||||
@ -263,7 +257,7 @@ public:
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
|
||||
return Status::kErrorNotSupported;
|
||||
}
|
||||
};
|
||||
|
||||
@ -791,7 +791,7 @@ bool GemmOperationProfiler::verify_with_reference_(
|
||||
handle.set_provider(provider);
|
||||
|
||||
Status status = handle.gemm_universal(
|
||||
library::GemmUniversalMode::kGemm,
|
||||
problem_.mode,
|
||||
gemm_workspace_.configuration.problem_size.m(),
|
||||
gemm_workspace_.configuration.problem_size.n(),
|
||||
gemm_workspace_.configuration.problem_size.k(),
|
||||
|
||||
@ -29,6 +29,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace profiler {
|
||||
|
||||
@ -425,7 +425,9 @@ void Options::Profiling::print_usage(std::ostream &out) const {
|
||||
<< " Number of ms to sleep between profiling periods (ms).\n\n"
|
||||
|
||||
<< " --profiling-enabled=<bool> "
|
||||
<< " If true, profiling is actually conducted.\n\n";
|
||||
<< " If true, profiling is actually conducted.\n\n"
|
||||
|
||||
;
|
||||
}
|
||||
|
||||
void Options::Profiling::print_options(std::ostream &out, int indent) const {
|
||||
|
||||
@ -32,6 +32,8 @@
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
/******************************************************************************
|
||||
|
||||