diff --git a/CMakeLists.txt b/CMakeLists.txt index 3c13b3a8..f12294d2 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -197,9 +197,19 @@ list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_TEST_LEVEL=${CUTLASS_TEST_LEVEL}) list(APPEND CUTLASS_CUDA_CLANG_FLAGS -DCUTLASS_TEST_LEVEL=${CUTLASS_TEST_LEVEL}) if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) -list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1) + message(STATUS "Enable caching of reference results in conv unit tests") + list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1) endif() + +set(CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED ON CACHE BOOL "Enable/Disable rigorous conv problem sizes in conv unit tests") + +if (CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED) + message(STATUS "Enable rigorous conv problem sizes in conv unit tests") + list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED=1) +endif() + + # # CUDA 10.1 introduces "mma" in PTX performing collective matrix multiply operations. # diff --git a/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm.cu b/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm.cu index 1b64c86c..d6c70006 100644 --- a/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm.cu +++ b/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm.cu @@ -296,7 +296,7 @@ int main() { return -1; } - if (!((props.major * 10 + props.minor) >= 80)) { + if (props.major * 10 + props.minor < 80) { std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80." << std::endl; notSupported = true; diff --git a/examples/17_fprop_per_channel_bias/fprop_per_channel_bias.cu b/examples/17_fprop_per_channel_bias/fprop_per_channel_bias.cu index a0bc28db..cce5edda 100644 --- a/examples/17_fprop_per_channel_bias/fprop_per_channel_bias.cu +++ b/examples/17_fprop_per_channel_bias/fprop_per_channel_bias.cu @@ -233,7 +233,7 @@ int run() { tensor_b.device_ref(), tensor_c_bias.device_ref(), tensor_ref_d.device_ref(), - alpha, 0 + alpha, ElementComputeEpilogue(0) ); // Wait for kernels to finish diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 6663d37a..90a0e9b2 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -45,6 +45,7 @@ function(cutlass_example_add_executable NAME) PRIVATE CUTLASS cutlass_tools_util_includes + $<$:nvidia::cublas> ) target_include_directories( diff --git a/include/cutlass/arch/wmma_sm75.h b/include/cutlass/arch/wmma_sm75.h index c06c051f..81febe21 100644 --- a/include/cutlass/arch/wmma_sm75.h +++ b/include/cutlass/arch/wmma_sm75.h @@ -109,6 +109,7 @@ struct Wmma< FragmentB const &B, FragmentC const &C) const { nvcuda::wmma::mma_sync(D, A, B, C); + } #else @@ -186,7 +187,6 @@ struct Wmma< FragmentA const &A, FragmentB const &B, FragmentC const &C) const { - nvcuda::wmma::bmma_sync(D, A, B, C, nvcuda::wmma::experimental::bmmaBitOpXOR, nvcuda::wmma::experimental::bmmaAccumulateOpPOPC); } diff --git a/include/cutlass/cutlass.h b/include/cutlass/cutlass.h index b7e447dd..25406616 100644 --- a/include/cutlass/cutlass.h +++ b/include/cutlass/cutlass.h @@ -109,6 +109,12 @@ static char const* cutlassGetStatusString(cutlass::Status status) { //////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifndef CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED +#define CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED 0 +#endif + + // CUDA 10.1 introduces the mma instruction #if 
!defined(CUTLASS_ENABLE_TENSOR_CORE_MMA) #define CUTLASS_ENABLE_TENSOR_CORE_MMA 0 diff --git a/include/cutlass/epilogue/thread/activation.h b/include/cutlass/epilogue/thread/activation.h index 6f53375c..a4e73b17 100644 --- a/include/cutlass/epilogue/thread/activation.h +++ b/include/cutlass/epilogue/thread/activation.h @@ -58,6 +58,7 @@ struct Identity { /// ReLu operator - propagates NaNs template struct ReLu { + static const bool kIsHeavy=false; CUTLASS_HOST_DEVICE T operator()(T const & threshold, T value) const { if (value < threshold) { @@ -76,6 +77,7 @@ struct ReLu { template struct ReLu> { + static const bool kIsHeavy=false; CUTLASS_HOST_DEVICE Array operator()(T const & threshold, Array const &frag) const { Array result; diff --git a/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h b/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h index acc84977..389eb26e 100644 --- a/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h +++ b/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h @@ -201,8 +201,8 @@ public: CUTLASS_PRAGMA_UNROLL for (int i = 0; i < kElementsPerAccess; ++i) { ElementCompute z = binary_op(alpha_ * tmp_Accum[i] + beta_ * tmp_C[i], V[i]); - result_Z[i] = z; - result_T[i] = skip_elementwise_ ? z : elementwise_op(z); + result_T[i] = z; + result_Z[i] = skip_elementwise_ ? z : elementwise_op(z); } NumericArrayConverter convert_z; @@ -230,8 +230,8 @@ public: CUTLASS_PRAGMA_UNROLL for (int i = 0; i < kElementsPerAccess; ++i) { ElementCompute z = binary_op(alpha_ * tmp_Accum[i], V[i]); - result_Z[i] = z; - result_T[i] = skip_elementwise_ ? z : elementwise_op(z); + result_T[i] = z; + result_Z[i] = skip_elementwise_ ? z : elementwise_op(z); } NumericArrayConverter convert_z; diff --git a/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h b/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h index a6152a23..71209b26 100644 --- a/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h +++ b/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h @@ -306,6 +306,7 @@ public: /// Debug printing CUTLASS_DEVICE static void print() { +#if 0 printf("BroadcastDetail {\n"); printf( " kColumnsPerThread: %d\nkRowsPerThread: %d\n,kThreadCount: %d\nkThreadsPerRow: %d\n" @@ -321,6 +322,7 @@ public: StorageShape::kCount ); printf("};\n"); +#endif } }; diff --git a/include/cutlass/epilogue/threadblock/epilogue_with_reduction.h b/include/cutlass/epilogue/threadblock/epilogue_with_reduction.h index aad6b405..3a4f3ac9 100644 --- a/include/cutlass/epilogue/threadblock/epilogue_with_reduction.h +++ b/include/cutlass/epilogue/threadblock/epilogue_with_reduction.h @@ -212,6 +212,7 @@ public: /// Debug printing CUTLASS_DEVICE static void print() { +#if 0 printf("ReductionDetail {\n"); printf( " kElementsPerAccess:%d\nkColumnsPerThread: %d\nkRowsPerThread: %d\n,kThreadCount: %d\nkThreadsPerRow: %d\n" @@ -228,6 +229,7 @@ public: StorageShape::kCount ); printf("};\n"); +#endif } }; diff --git a/include/cutlass/tfloat32.h b/include/cutlass/tfloat32.h index 18afc72d..3ddf61ff 100644 --- a/include/cutlass/tfloat32.h +++ b/include/cutlass/tfloat32.h @@ -363,8 +363,13 @@ tfloat32_t operator+(tfloat32_t const& lhs, tfloat32_t const& rhs) { CUTLASS_HOST_DEVICE tfloat32_t operator-(tfloat32_t const& lhs) { - float x = -reinterpret_cast(lhs); - return *reinterpret_cast(&x); + union u_tff32 { + float val_f32; + tfloat32_t val_tf; + CUTLASS_HOST_DEVICE u_tff32() : val_f32(0) { } + }; + union u_tff32 x; 
x.val_f32 = -reinterpret_cast(lhs); + return x.val_tf; } CUTLASS_HOST_DEVICE diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt index 4120ec0d..c3edf277 100644 --- a/test/unit/CMakeLists.txt +++ b/test/unit/CMakeLists.txt @@ -35,6 +35,7 @@ target_link_libraries( cutlass_tools_util_includes $<$:nvidia::cublas> gtest + cudart ) cutlass_add_library( diff --git a/test/unit/common/cutlass_unit_test.h b/test/unit/common/cutlass_unit_test.h index 3259a3ba..8b8e1927 100644 --- a/test/unit/common/cutlass_unit_test.h +++ b/test/unit/common/cutlass_unit_test.h @@ -31,6 +31,8 @@ #pragma nv_diag_warning boolean_controlling_expr_is_constant #pragma warning( disable : 4503) +#include +#include ///////////////////////////////////////////////////////////////////////////////////////////////// /// Sets flags for Unit test @@ -38,6 +40,13 @@ void FilterArchitecture(); ///////////////////////////////////////////////////////////////////////////////////////////////// +/// Reads environment variable `CUTLASS_UNIT_TEST_PROBLEM_COUNT` to control the number and order +// of problem sizes run by CUTLASS unit tests +int CutlassUnitTestProblemCount(); + +///////////////////////////////////////////////////////////////////////////////////////////////// + + // active test macro #define CUTLASS_TEST_LEVEL_ACTIVE(LEVEL,NAME_STATIC,NAME_DYNAMIC,...) \ TEST(NAME_STATIC,L##LEVEL##_##NAME_DYNAMIC) __VA_ARGS__ diff --git a/test/unit/common/filter_architecture.cpp b/test/unit/common/filter_architecture.cpp index 42a39d4e..73c2f934 100644 --- a/test/unit/common/filter_architecture.cpp +++ b/test/unit/common/filter_architecture.cpp @@ -91,3 +91,14 @@ void FilterArchitecture() { } ///////////////////////////////////////////////////////////////////////////////////////////////// + +int CutlassUnitTestProblemCount() { + if(const char* problem_count = std::getenv("CUTLASS_UNIT_TEST_PROBLEM_COUNT")) { + + return std::stoi(problem_count); + } + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/conv/device/conv2d_dgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm50.cu b/test/unit/conv/device/conv2d_dgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm50.cu index cd1b22d3..97f8609f 100644 --- a/test/unit/conv/device/conv2d_dgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm50.cu +++ b/test/unit/conv/device/conv2d_dgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm50.cu @@ -35,7 +35,6 @@ #include "conv2d_testbed.h" - //////////////////////////////////////////////////////////////////////////////// TEST(SM50_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32, 64x64_8x2_32x64x8) { diff --git a/test/unit/conv/device/conv2d_fprop_with_broadcast_sm75.cu b/test/unit/conv/device/conv2d_fprop_with_broadcast_sm75.cu index 1ec56cb5..64710edb 100644 --- a/test/unit/conv/device/conv2d_fprop_with_broadcast_sm75.cu +++ b/test/unit/conv/device/conv2d_fprop_with_broadcast_sm75.cu @@ -56,7 +56,7 @@ TEST(SM75_Device_Conv2d_Fprop_With_Broadcast_Analytic_ImplicitGemm_f16nhwc_f16nh cutlass::half_t, cutlass::half_t, 8, - cutlass::epilogue::thread::GELU_taylor + cutlass::epilogue::thread::ReLu >; /// Device-level Conv2d instance diff --git a/test/unit/conv/device/conv2d_problems.h b/test/unit/conv/device/conv2d_problems.h index 4fef15c8..651d5bed 100644 --- a/test/unit/conv/device/conv2d_problems.h +++ b/test/unit/conv/device/conv2d_problems.h @@ -36,8 +36,6 @@ #include "cutlass/conv/convolution.h" #include 
"cutlass/conv/conv2d_problem_size.h" -#define CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED 1 - namespace test { namespace conv { namespace device { diff --git a/test/unit/conv/device/conv2d_testbed.h b/test/unit/conv/device/conv2d_testbed.h index 318db31a..1f441c23 100644 --- a/test/unit/conv/device/conv2d_testbed.h +++ b/test/unit/conv/device/conv2d_testbed.h @@ -99,6 +99,8 @@ public: cutlass::HostTensor tensor_D_computed; cutlass::HostTensor tensor_D_reference; + int tested_problem_count; + public: TestbedConv2d( @@ -107,7 +109,7 @@ public: cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, uint64_t seed_ = 2080 ): - init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_), tested_problem_count(0) { } @@ -220,7 +222,10 @@ public: return true; } -#if 0 //display conv2d problem size for debugging + // increment tested problem count run by the testbed + tested_problem_count++; + +#if 0 // display conv2d problem size for debugging std::cout << problem_size << std::endl << "alpha, beta: (" << alpha << ", " << beta << ")" << std::endl << "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") << std::endl @@ -537,78 +542,96 @@ bool TestAllConv2d( // Vector of conv2d problem sizes to avoid duplicate runs Conv2dProblemVector conv_tested_sizes; - Conv2dProblemVector const *problem_vectors[] = { - &conv_test_sizes, // run user specified sizes - &conv_problems.conv2d_default_sizes, // run default and cudnn bug sizes - //&conv_problems.conv2d_resnet50_sizes, // run resnet50 sizes + // Vectors of Conv2dProblemVector (lenient/easiest to rigorous problem sizes) + std::vector problem_vectors = { + conv_test_sizes, // run user specified sizes + conv_problems.conv2d_default_sizes, // run default and cudnn bug sizes + //conv_problems.conv2d_resnet50_sizes, // run resnet50 sizes #if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED - &conv_problems.conv2d_rigorous_sizes, // run large and rigorous sizes if enabled + conv_problems.conv2d_rigorous_sizes, // run large and rigorous sizes if enabled #endif }; + // Flatten 2D problem_vectors into a 1D problem_sizes + std::vector problem_sizes; + for (auto problem_vector : problem_vectors) { + for(auto conv_problem : problem_vector) { + problem_sizes.push_back(conv_problem); + } + } + + // If CUTLASS_UNIT_TEST_PROBLEM_COUNT is set reverse the order (rigorous to lenient) + // run the most rigorous problem size first + if (CutlassUnitTestProblemCount()) { + std::reverse(problem_sizes.begin(), problem_sizes.end()); + } + // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) - for (Conv2dProblemVector const * problem_vector : problem_vectors) { + for(auto conv_problem : problem_sizes) { - // Run conv testbed on default convolution sizes - for(auto conv_problem : *problem_vector) { + // Skip blacklist and avoid duplicate problem sizes + if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() || + std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) { + continue; + } - // Skip blacklist and avoid duplicate problem sizes - if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() || - std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) { + // + // Procedurally disable certain cases + // 
+
+      // CUTLASS DGRAD's *unity* stride specialization only supports stride {1, 1}
+      if ((ImplicitGemm::kConvolutionalOperator ==
+            cutlass::conv::Operator::kDgrad) &&
+          (ImplicitGemm::ImplicitGemmKernel::Mma::IteratorA::kStrideSupport ==
+            cutlass::conv::StrideSupport::kUnity)) {
+        if (!((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) {
+          continue;
+        }
+      }
+
+      // CUTLASS DGRAD's *strided* stride specialization supports all strides {stride_h, stride_w}
+      // Although strided dgrad works for all stride combinations, we are only going
+      // to run strided dgrad for non-unity strides
+      if ((ImplicitGemm::kConvolutionalOperator ==
+            cutlass::conv::Operator::kDgrad) &&
+          (ImplicitGemm::ImplicitGemmKernel::Mma::IteratorA::kStrideSupport ==
+            cutlass::conv::StrideSupport::kStrided)) {
+        if (((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) {
+          continue;
+        }
+      }
+
+      //
+      // Test
+      //
+      // push back tested problem size to avoid re-running duplicates
+      conv_tested_sizes.push_back(conv_problem);
+
+      // test mode = xcross
+      passed = testbed.run(
+        conv_problem,
+        cutlass::conv::SplitKMode::kSerial);
 
-      // CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1}
-      if ((ImplicitGemm::kConvolutionalOperator ==
-          cutlass::conv::Operator::kDgrad) &&
-          (ImplicitGemm::ImplicitGemmKernel::Mma::IteratorA::kStrideSupport ==
-          cutlass::conv::StrideSupport::kUnity)) {
-        if (!((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) {
-          continue;
-        }
-      }
-
-      // CUTLASS DGRAD's *strided* stride specialization supports all stride {stride_h, stride_w}
-      // Although strided dgrad works for all stride combinations, we are only going
-      // to run strided dgrad for non-unity strides
-      if ((ImplicitGemm::kConvolutionalOperator ==
-          cutlass::conv::Operator::kDgrad) &&
-          (ImplicitGemm::ImplicitGemmKernel::Mma::IteratorA::kStrideSupport ==
-          cutlass::conv::StrideSupport::kStrided)) {
-        if (((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) {
-          continue;
-        }
-      }
-
-      //
-      // Test
-      //
-      // push back tested problem size to avoid re-running duplicates
-      conv_tested_sizes.push_back(conv_problem);
-
-      // test mode = xcross
-      passed = testbed.run(
-        conv_problem,
-        cutlass::conv::SplitKMode::kSerial);
+      if (!passed) {
+        return false;
+      }
 
-      if (!passed) {
-        return false;
-      }
-
-      // test mode = convolution
-      passed = testbed.run(
-        conv_problem.reset_mode(cutlass::conv::Mode::kConvolution),
-        cutlass::conv::SplitKMode::kSerial);
-
-      if (!passed) {
-        return false;
-      }
+      // test mode = convolution
+      passed = testbed.run(
+        conv_problem.reset_mode(cutlass::conv::Mode::kConvolution),
+        cutlass::conv::SplitKMode::kSerial);
+
+      if (!passed) {
+        return false;
+      }
+
+      // If CUTLASS_UNIT_TEST_PROBLEM_COUNT is set, limit the number of tested problem sizes
+      if (CutlassUnitTestProblemCount() &&
+          testbed.tested_problem_count > CutlassUnitTestProblemCount()) {
+        return true;
       }
     }
 
+
   // CUTLASS DGRAD's *strided* specialization does not support split-k mode
   if ((ImplicitGemm::kConvolutionalOperator ==
@@ -677,6 +700,12 @@ bool TestAllConv2d(
       if (!passed) {
         return false;
       }
+
+      // If CUTLASS_UNIT_TEST_PROBLEM_COUNT is set, limit the number of tested problem sizes
+      if (CutlassUnitTestProblemCount() &&
+          testbed.tested_problem_count > CutlassUnitTestProblemCount()) {
+        return true;
+      }
     }
   }
 }
diff --git a/test/unit/conv/device/conv2d_with_broadcast_testbed.h b/test/unit/conv/device/conv2d_with_broadcast_testbed.h
index bb79a1cf..bd9596a7 100644
--- a/test/unit/conv/device/conv2d_with_broadcast_testbed.h +++ b/test/unit/conv/device/conv2d_with_broadcast_testbed.h @@ -23,7 +23,11 @@ * **************************************************************************************************/ /*! \file - \brief Implicit GEMM testbed + \brief Implicit GEMM for fused epilogue broadcast testbed + + Parallel split-k is not tested because we can just use regular conv kernel + when we need to use parallel-splitk. Broadcast can happen in the reduction + kernel. */ #pragma once @@ -53,7 +57,46 @@ namespace test { namespace conv { namespace device { +///////////////////////////////////////////////////////////////////////////////////////////////// + template +struct Conv2dWithBroadcastReferenceOp { + + using OutputOp = typename Conv2d::EpilogueOutputOp; + + using ElementCompute = typename OutputOp::ElementCompute; + using ElementZ = typename OutputOp::ElementZ; + using ElementT = typename OutputOp::ElementT; + + typename OutputOp::BinaryOp binary_op; + typename OutputOp::ElementwiseOp elementwise_op; + + Conv2dWithBroadcastReferenceOp() { } + + void operator()(ElementZ &Z, ElementT &T, ElementCompute conv2d, ElementCompute bias) { + ElementCompute t_full = binary_op(conv2d, bias); + T = ElementT(t_full); + + ElementCompute z_full = elementwise_op(t_full); + Z = ElementZ(z_full); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Fused testbed +// +// Y = CONV(AB, C) +// +// T[n, p, q, k] = ReductionOp(Y[n, p, q, k], Broadcast[k]) +// +// Z[n, p, q, k] = Elementwise(T[n, p, q, k]) +// + +template < + typename Conv2d, + typename ReferenceOp = Conv2dWithBroadcastReferenceOp +> class TestbedConv2dWithBroadcast { public: @@ -66,6 +109,8 @@ public: using ElementAccumulator = typename Conv2d::ElementAccumulator; using ElementCompute = typename Conv2d::ElementCompute; using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp; + using ElementZ = typename EpilogueOutputOp::ElementZ; + using ElementT = typename EpilogueOutputOp::ElementT; static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator; @@ -80,8 +125,13 @@ public: cutlass::HostTensor tensor_A; cutlass::HostTensor tensor_B; cutlass::HostTensor tensor_C; - cutlass::HostTensor tensor_D_computed; - cutlass::HostTensor tensor_D_reference; + cutlass::HostTensor tensor_C_reference; + cutlass::HostTensor tensor_Z_computed; + cutlass::HostTensor tensor_Z_reference; + cutlass::HostTensor tensor_T_computed; + cutlass::HostTensor tensor_T_reference; + cutlass::HostTensor tensor_Y_reference; + cutlass::HostTensor tensor_Broadcast; // Input Broadcast public: @@ -147,18 +197,44 @@ public: tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size)); tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); - tensor_D_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); - tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_C_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_Z_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_Z_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_T_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + 
tensor_T_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_Y_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_Broadcast.resize({ + 1, + 1, + 1, + implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size).c(), + }); initialize_tensor(tensor_A.host_view(), init_A, seed); initialize_tensor(tensor_B.host_view(), init_B, seed * 17); initialize_tensor(tensor_C.host_view(), init_C, seed * 39); - + initialize_tensor(tensor_Broadcast.host_view(), init_C, seed * 39); + + for (int n = 0; n < tensor_C_reference.extent().n(); ++n) { + for (int p = 0; p < tensor_C_reference.extent().h(); ++p) { + for (int q = 0; q < tensor_C_reference.extent().w(); ++q) { + for (int k = 0; k < tensor_C_reference.extent().c(); ++k) { + tensor_C_reference.at({n, p, q, k}) = ElementAccumulator(tensor_C.at({n, p, q, k})); + } + } + } + } + tensor_A.sync_device(); tensor_B.sync_device(); tensor_C.sync_device(); - tensor_D_computed.sync_device(); - tensor_D_reference.sync_device(); + tensor_Broadcast.sync_device(); + tensor_C_reference.sync_device(); + tensor_Z_computed.sync_device(); + tensor_Z_reference.sync_device(); + tensor_T_computed.sync_device(); + tensor_T_reference.sync_device(); + tensor_Y_reference.sync_device(); } bool sufficient() const { @@ -215,18 +291,21 @@ public: // configure the operator Conv2d conv2d_op; - typename Conv2d::Arguments conv2d_args( problem_size, tensor_A.device_ref(), tensor_B.device_ref(), tensor_C.device_ref(), - tensor_D_computed.device_ref(), + tensor_Z_computed.device_ref(), {alpha, beta}, - split_k_mode + split_k_mode, + tensor_Broadcast.device_data(), + tensor_T_computed.device_data(), + 0, // This must be zero + implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size).c() ); - // find workspace requirement for parallel split-k reduction + // initialize the kernel size_t workspace_size = Conv2d::get_workspace_size(conv2d_args); cutlass::device_memory::allocation workspace(workspace_size); @@ -239,22 +318,6 @@ public: return true; } - // conv2d operation with parallel split-k-mode - if (split_k_mode == cutlass::conv::SplitKMode::kParallel) { - - // conv2d output is written to workspace in global memory - conv2d_args.ref_D.reset(reinterpret_cast(workspace.get())); - // accumulate mma for each cta in k-dimension (1.0 * A * B) - conv2d_args.output_op = {ElementCompute(1), ElementCompute(0)}; - // update conv2d operator arguments - status = conv2d_op.update(conv2d_args, workspace.get()); - } - - EXPECT_TRUE(status == cutlass::Status::kSuccess); - if (status != cutlass::Status::kSuccess) { - return false; - } - // run conv2d operator status = conv2d_op(); @@ -269,52 +332,13 @@ public: EXPECT_EQ(result, cudaSuccess) << " device reference error: " << cudaGetErrorString(result); - tensor_D_computed.sync_host(); + tensor_T_computed.sync_host(); + tensor_Z_computed.sync_host(); // - // Reference check - support caching results + // Reference check // - CachedTestKey cached_test_key = CreateCachedConv2dWithBroadcastTestKey< - ElementA, LayoutA, - ElementB, LayoutB, - ElementC, LayoutC, - ElementAccumulator, - ElementCompute - >( - kConvolutionalOperator, - problem_size, - alpha, - beta, - tensor_A.host_view(), - tensor_B.host_view(), - tensor_C.host_view() - ); - - // - // Look for the cached key - // - - bool cached_result_loaded = false; - CachedTestResult cached_test_result; - - std::string conv2d_result_cache_name = - std::string("cached_results_") + CUTLASS_TARGET_NAME + ".txt"; 
- - if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { - - CachedTestResultListing cached_results(conv2d_result_cache_name); - - auto cached = cached_results.find(cached_test_key); - - cached_result_loaded = cached.first; - if (cached_result_loaded) { - cached_test_result = cached.second; - } - } - - if (!cached_result_loaded) { - #if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED cutlass::reference::device::Conv2d< @@ -322,22 +346,22 @@ public: LayoutA, ElementB, LayoutB, - ElementC, + ElementAccumulator, LayoutC, - ElementCompute, + ElementAccumulator, ElementAccumulator >( kConvolutionalOperator, problem_size, tensor_A.device_ref(), tensor_B.device_ref(), - tensor_C.device_ref(), - tensor_D_reference.device_ref(), + tensor_C_reference.device_ref(), + tensor_Y_reference.device_ref(), alpha, beta); // sync host (copy device data to host) for dumping error output in case of mismatches - tensor_D_reference.sync_host(); + tensor_Y_reference.sync_host(); #else @@ -346,48 +370,50 @@ public: LayoutA, ElementB, LayoutB, - ElementC, + ElementAccumulator, LayoutC, - ElementCompute, + ElementAccumulator, ElementAccumulator >( kConvolutionalOperator, problem_size, tensor_A.host_ref(), tensor_B.host_ref(), - tensor_C.host_ref(), - tensor_D_reference.host_ref(), + tensor_C_reference.host_ref(), + tensor_Y_reference.host_ref(), alpha, beta); #endif + ReferenceOp reference_op; - if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { - - cached_test_result.D = TensorHash(tensor_D_reference.host_view()); - - CachedTestResultListing cached_results(conv2d_result_cache_name); - - cached_results.append(cached_test_key, cached_test_result); - cached_results.write(conv2d_result_cache_name); + // compute tensor Z and tensor T + for (int n = 0; n < problem_size.N; ++n) { + for (int p = 0; p < problem_size.P; ++p) { + for (int q = 0; q < problem_size.Q; ++q) { + for (int k = 0; k < problem_size.K; ++k) { + + ElementZ z; + ElementT t; + + reference_op(z, t, tensor_Y_reference.at({n, p, q, k}), tensor_Broadcast.at({0, 0, 0, k})); + + tensor_Z_reference.at({n, p, q, k}) = z; + tensor_T_reference.at({n, p, q, k}) = t; + } + } } - } // if (!cached_result_loaded) - - - uint32_t tensor_D_hash = TensorHash(tensor_D_computed.host_view()); - - if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { - passed = (tensor_D_hash == cached_test_result.D); - - EXPECT_EQ(tensor_D_hash, cached_test_result.D) - << "Hash-based comparison failed for key:" << "\n" << cached_test_key << "\n"; } - else { - passed = cutlass::reference::host::TensorEquals( - tensor_D_computed.host_view(), - tensor_D_reference.host_view()); - } + passed = cutlass::reference::host::TensorEquals( + tensor_T_computed.host_view(), + tensor_T_reference.host_view()); + + EXPECT_TRUE(passed); + + passed = cutlass::reference::host::TensorEquals( + tensor_Z_computed.host_view(), + tensor_Z_reference.host_view()); EXPECT_TRUE(passed); @@ -435,14 +461,16 @@ public: << "\nA:\n" << tensor_A.host_view() << "\n" << "\nB:\n" << tensor_B.host_view() << "\n" << "\nC:\n" << tensor_C.host_view() << "\n" - << "\nD reference:\n" << tensor_D_reference.host_view() << "\n" - << "\nD computed:\n" << tensor_D_computed.host_view() << "\n"; - + << "\nBroadcast:\n" << tensor_Broadcast.host_view() << "\n" + << "\nY reference:\n" << tensor_Y_reference.host_view() << "\n" + << "\nT reference:\n" << tensor_T_reference.host_view() << "\n" + << "\nT computed:\n" << tensor_T_computed.host_view() << "\n" + << "\nZ reference:\n" << tensor_Z_reference.host_view() << "\n" + << "\nZ computed:\n" << tensor_Z_computed.host_view() << 
"\n"; } return passed; } - }; ///////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -582,8 +610,7 @@ bool TestAllConv2dWithBroadcast( ); cutlass::conv::SplitKMode split_k_modes [] = { - cutlass::conv::SplitKMode::kSerial, - cutlass::conv::SplitKMode::kParallel, + cutlass::conv::SplitKMode::kSerial }; int split_k_slices[] = { diff --git a/test/unit/gemm/device/testbed_gemm_with_broadcast.h b/test/unit/gemm/device/testbed_gemm_with_broadcast.h index ca443a81..d0fa8293 100644 --- a/test/unit/gemm/device/testbed_gemm_with_broadcast.h +++ b/test/unit/gemm/device/testbed_gemm_with_broadcast.h @@ -69,11 +69,11 @@ struct GemmWithBroadcastReferenceOp { void operator()(ElementZ &Z, ElementT &T, ElementCompute gemm, ElementCompute bias) { - ElementCompute z_full = binary_op(gemm, bias); - Z = ElementZ(z_full); - - ElementCompute t_full = elementwise_op(z_full); + ElementCompute t_full = binary_op(gemm, bias); T = ElementT(t_full); + + ElementCompute z_full = elementwise_op(t_full); + Z = ElementZ(z_full); } }; @@ -83,9 +83,9 @@ struct GemmWithBroadcastReferenceOp { // // Y = GEMM(AB, C) // -// Z[i, j] = ReductionOp(Y[i, j], Broadcast[i]) +// T[i, j] = ReductionOp(Y[i, j], Broadcast[i]) // -// T[i, j] = Elementwise(Z[i, j]) +// Z[i, j] = Elementwise(T[i, j]) // template < @@ -101,7 +101,6 @@ struct TestbedGemmWithBroadcast { using ElementZ = typename OutputOp::ElementZ; using ElementT = typename OutputOp::ElementT; - /// Initialization cutlass::Distribution::Kind init_A; cutlass::Distribution::Kind init_B; @@ -343,7 +342,6 @@ struct TestbedGemmWithBroadcast { ReferenceOp reference_op; - // compute tensor Z and tensor T for (int m = 0; m < problem_size.m(); ++m) { for (int n = 0; n < problem_size.n(); ++n) { diff --git a/test/unit/util/CMakeLists.txt b/test/unit/util/CMakeLists.txt index f46eb30e..ee41d08d 100644 --- a/test/unit/util/CMakeLists.txt +++ b/test/unit/util/CMakeLists.txt @@ -24,3 +24,8 @@ cutlass_test_unit_add_executable( cutlass_test_unit_util tensor_reduce.cu ) + +cutlass_test_unit_add_executable( + cutlass_test_unit_levels + cutlass_test_levels.cu + ) diff --git a/test/unit/util/cutlass_test_levels.cu b/test/unit/util/cutlass_test_levels.cu new file mode 100644 index 00000000..9bf9faa3 --- /dev/null +++ b/test/unit/util/cutlass_test_levels.cu @@ -0,0 +1,71 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#include + +#include "../common/cutlass_unit_test.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM75_CUTLASS_TEST, level_not_specified) { + + EXPECT_TRUE(true); +} + +TEST(SM80_CUTLASS_TEST, level_not_specified) { + + EXPECT_TRUE(true); +} + +CUTLASS_TEST_L0(SM75_CUTLASS_TEST, level0, { + + EXPECT_TRUE(true); +}) + +CUTLASS_TEST_L1(SM75_CUTLASS_TEST, level1, { + + EXPECT_TRUE(true); +}) + +CUTLASS_TEST_L2(SM75_CUTLASS_TEST, level2, { + + EXPECT_TRUE(true); +}) + +CUTLASS_TEST_L0(SM80_CUTLASS_TEST, level0, { + + EXPECT_TRUE(true); +}) + +CUTLASS_TEST_L1(SM80_CUTLASS_TEST, level1, { + + EXPECT_TRUE(true); +}) + +CUTLASS_TEST_L2(SM80_CUTLASS_TEST, level2, { + + EXPECT_TRUE(true); +}) +//////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/tools/library/scripts/generator.py b/tools/library/scripts/generator.py index 333dae7b..4b2c7805 100644 --- a/tools/library/scripts/generator.py +++ b/tools/library/scripts/generator.py @@ -1056,7 +1056,6 @@ def GenerateSM75_TensorOp_8832_TN(manifest, args): min_cc = 75 max_cc = 1024 - alignment_constraints = [32,] for math_inst in math_instructions: @@ -1136,7 +1135,6 @@ def GenerateSM75_TensorOp_8832_Interleaved(manifest, args): min_cc = 75 max_cc = 1024 - alignment_constraints = [32,] for math_inst in math_instructions: @@ -1907,7 +1905,6 @@ def GenerateSM80_TensorOp_16864_Interleaved(manifest, args): min_cc = 80 max_cc = 1024 - alignment_constraints = [32,] for math_inst in math_instructions: diff --git a/tools/profiler/src/cublas_helpers.cu b/tools/profiler/src/cublas_helpers.cu index 94261e18..f1e481a8 100644 --- a/tools/profiler/src/cublas_helpers.cu +++ b/tools/profiler/src/cublas_helpers.cu @@ -102,6 +102,12 @@ bool get_cublas_datatype(cublasDataType_t &data_type, library::NumericTypeID ele data_type = CUDA_R_16F; return true; + case library::NumericTypeID::kBF16: + break; + + case library::NumericTypeID::kTF32: + break; + case library::NumericTypeID::kF32: data_type = CUDA_R_32F; return true;
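
Note on CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED: the previously hard-coded `#define CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED 1` in conv2d_problems.h is removed in favor of a CMake option (default ON), with a guarded default added in cutlass.h so translation units compile whether or not the build passes the define. A minimal sketch of the resulting preprocessor pattern, mirroring the cutlass.h and conv2d_testbed.h hunks above:

    // Default-off fallback so the macro is always defined (as in cutlass.h):
    #ifndef CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED
    #define CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED 0
    #endif

    #if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED
    // conv2d_rigorous_sizes are appended to the tested problem vectors
    #endif

Turning the option off at configure time skips the large, rigorous conv problem sizes without editing source.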
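Note on the CUTLASS_UNIT_TEST_PROBLEM_COUNT hook added in filter_architecture.cpp and conv2d_testbed.h: the environment variable caps how many problem sizes each testbed verifies, and the flattened problem list is reversed first so the most rigorous sizes run before the cap is reached. A self-contained sketch of the mechanism (the early-exit helper below is a paraphrase of the inline check, not a function the patch defines):

    #include <cstdlib>
    #include <string>

    // Cap from CUTLASS_UNIT_TEST_PROBLEM_COUNT; 0 means "run every problem size".
    // Caveat: std::stoi throws std::invalid_argument on a non-numeric value,
    // which the patch does not guard against.
    int CutlassUnitTestProblemCount() {
      if (const char *problem_count = std::getenv("CUTLASS_UNIT_TEST_PROBLEM_COUNT")) {
        return std::stoi(problem_count);
      }
      return 0;
    }

    // Paraphrase of the early exit inserted after each verified problem size:
    bool should_stop(int tested_problem_count) {
      int cap = CutlassUnitTestProblemCount();
      return cap != 0 && tested_problem_count > cap;  // stop early, report pass
    }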
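Note on the Z/T swap in linear_combination_bias_elementwise.h and testbed_gemm_with_broadcast.h: after this patch, T carries the binary-op (reduction) result and Z carries the elementwise result, matching the corrected doc comment T[i, j] = ReductionOp(Y[i, j], Broadcast[i]); Z[i, j] = Elementwise(T[i, j]). A host-side sketch of the corrected ordering, with plus and ReLU standing in for the configurable BinaryOp and ElementwiseOp (assumed choices for illustration, not fixed by the patch):

    #include <algorithm>

    // T = BinaryOp(gemm_output, bias); Z = ElementwiseOp(T).
    void reference_epilogue(float &Z, float &T, float gemm, float bias) {
      T = gemm + bias;        // BinaryOp, e.g. plus
      Z = std::max(T, 0.0f);  // ElementwiseOp, e.g. ReLu
    }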
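Schematically, the reworked conv2d_with_broadcast_testbed.h verifies T and Z separately against a host reference built from the unfused convolution output Y, per the comment block added near the top of that file. A flattened-NHWK sketch under the same assumptions (plus as the binary op; parameter names are illustrative):

    #include <cstddef>
    #include <functional>
    #include <vector>

    // Host reference for the fused broadcast epilogue:
    //   T[n, p, q, k] = BinaryOp(Y[n, p, q, k], Broadcast[k])
    //   Z[n, p, q, k] = ElementwiseOp(T[n, p, q, k])
    void reference_broadcast_epilogue(
        std::vector<float> const &Y,          // conv output, N*P*Q*K elements
        std::vector<float> const &broadcast,  // K elements, one per channel
        std::vector<float> &T, std::vector<float> &Z, std::size_t K,
        std::function<float(float)> const &elementwise) {
      for (std::size_t i = 0; i < Y.size(); ++i) {
        T[i] = Y[i] + broadcast[i % K];  // BinaryOp assumed to be plus
        Z[i] = elementwise(T[i]);
      }
    }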
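Note on the tfloat32.h change: the previous unary minus type-punned through reinterpret_cast between unrelated types, which is undefined behavior under C++ strict-aliasing rules; the new u_tff32 union makes the float round trip explicit instead. A reduced illustration of the same idiom (std::bit_cast or memcpy would be the fully portable alternatives):

    #include <cstdint>

    // Writing one union member and reading another is the punning pattern that
    // NVCC/GCC accept as an extension; reinterpret_cast between unrelated
    // object types is not guaranteed to work.
    std::uint32_t float_bits(float x) {
      union { float f; std::uint32_t u; } pun;
      pun.f = x;
      return pun.u;
    }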