Updates to fused epilogue (#383)
* Enhancements and fixes to the fused GEMM and Convolution epilogues.
* Explicitly list cudart as a unit test library dependency.
parent 4e666e1dfd
commit ec4f7e5194
@@ -197,9 +197,19 @@ list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_TEST_LEVEL=${CUTLASS_TEST_LEVEL})
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -DCUTLASS_TEST_LEVEL=${CUTLASS_TEST_LEVEL})

if (CUTLASS_TEST_ENABLE_CACHED_RESULTS)
  list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1)
  message(STATUS "Enable caching of reference results in conv unit tests")
  list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1)
endif()

set(CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED ON CACHE BOOL "Enable/Disable rigorous conv problem sizes in conv unit tests")

if (CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED)
  message(STATUS "Enable rigorous conv problem sizes in conv unit tests")
  list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED=1)
endif()

#
# CUDA 10.1 introduces "mma" in PTX performing collective matrix multiply operations.
#

@@ -296,7 +296,7 @@ int main() {
    return -1;
  }

  if (!((props.major * 10 + props.minor) >= 80)) {
  if (props.major * 10 + props.minor < 80) {
    std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80."
              << std::endl;
    notSupported = true;
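A minimal, self-contained sketch of the same guard, using only standard CUDA runtime calls (the helper name below is hypothetical and not part of the example being patched):

#include <cuda_runtime.h>
#include <iostream>

// Hypothetical helper: returns true when device 0 can run Ampere Tensor Core kernels (SM80+).
bool device_supports_sm80() {
  cudaDeviceProp props;
  cudaError_t error = cudaGetDeviceProperties(&props, 0);
  if (error != cudaSuccess) {
    std::cerr << "cudaGetDeviceProperties() failed: " << cudaGetErrorString(error) << std::endl;
    return false;
  }
  // Same condition as the simplified check above: compute capability must be at least 80.
  if (props.major * 10 + props.minor < 80) {
    std::cerr << "Ampere Tensor Core operations require compute capability 80 or higher." << std::endl;
    return false;
  }
  return true;
}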
@@ -233,7 +233,7 @@ int run() {
    tensor_b.device_ref(),
    tensor_c_bias.device_ref(),
    tensor_ref_d.device_ref(),
    alpha, 0
    alpha, ElementComputeEpilogue(0)
  );

  // Wait for kernels to finish

@@ -45,6 +45,7 @@ function(cutlass_example_add_executable NAME)
  PRIVATE
  CUTLASS
  cutlass_tools_util_includes
  $<$<BOOL:${CUTLASS_ENABLE_CUBLAS}>:nvidia::cublas>
  )

target_include_directories(

@@ -109,6 +109,7 @@ struct Wmma<
    FragmentB const &B,
    FragmentC const &C) const {
    nvcuda::wmma::mma_sync(D, A, B, C);

  }

#else

@@ -186,7 +187,6 @@ struct Wmma<
    FragmentA const &A,
    FragmentB const &B,
    FragmentC const &C) const {

    nvcuda::wmma::bmma_sync(D, A, B, C, nvcuda::wmma::experimental::bmmaBitOpXOR,
                            nvcuda::wmma::experimental::bmmaAccumulateOpPOPC);
  }

@@ -109,6 +109,12 @@ static char const* cutlassGetStatusString(cutlass::Status status) {

////////////////////////////////////////////////////////////////////////////////////////////////////

#ifndef CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED
#define CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED 0
#endif

// CUDA 10.1 introduces the mma instruction
#if !defined(CUTLASS_ENABLE_TENSOR_CORE_MMA)
#define CUTLASS_ENABLE_TENSOR_CORE_MMA 0

@@ -58,6 +58,7 @@ struct Identity {
/// ReLu operator - propagates NaNs
template <typename T>
struct ReLu {
  static const bool kIsHeavy=false;
  CUTLASS_HOST_DEVICE
  T operator()(T const & threshold, T value) const {
    if (value < threshold) {

@@ -76,6 +77,7 @@ struct ReLu {

template <typename T, int N>
struct ReLu<Array<T, N>> {
  static const bool kIsHeavy=false;
  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(T const & threshold, Array<T, N> const &frag) const {
    Array<T, N> result;
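A hedged usage sketch of the scalar activation shown above (assuming, as in recent CUTLASS releases, that the functor lives in cutlass/epilogue/thread/activation.h and that the first operand is the clamp threshold):

#include "cutlass/epilogue/thread/activation.h"

// Clamp a value at zero using the ReLu functor from the hunk above.
float apply_relu(float x) {
  cutlass::epilogue::thread::ReLu<float> relu;
  return relu(0.0f, x);   // returns x when x >= 0, otherwise 0
}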
@@ -201,8 +201,8 @@ public:
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kElementsPerAccess; ++i) {
      ElementCompute z = binary_op(alpha_ * tmp_Accum[i] + beta_ * tmp_C[i], V[i]);
      result_Z[i] = z;
      result_T[i] = skip_elementwise_ ? z : elementwise_op(z);
      result_T[i] = z;
      result_Z[i] = skip_elementwise_ ? z : elementwise_op(z);
    }

    NumericArrayConverter<ElementZ, ElementCompute, kElementsPerAccess> convert_z;

@@ -230,8 +230,8 @@ public:
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kElementsPerAccess; ++i) {
      ElementCompute z = binary_op(alpha_ * tmp_Accum[i], V[i]);
      result_Z[i] = z;
      result_T[i] = skip_elementwise_ ? z : elementwise_op(z);
      result_T[i] = z;
      result_Z[i] = skip_elementwise_ ? z : elementwise_op(z);
    }

    NumericArrayConverter<ElementZ, ElementCompute, kElementsPerAccess> convert_z;
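The two hunks above swap which fragment receives the raw binary-op result: T now carries the pre-activation value and Z carries the activated value. A scalar model of the corrected ordering, written here purely for illustration (the names and the choice of addition plus ReLU are assumptions, not the library's definitions):

#include <algorithm>

// Illustrative per-element model of the fused epilogue after the fix:
//   t = binary_op(alpha * accum + beta * source, broadcast)   -> stored in T
//   z = elementwise_op(t)                                     -> stored in Z
void fused_epilogue_element(float accum, float source, float broadcast,
                            float alpha, float beta, bool skip_elementwise,
                            float &Z, float &T) {
  float t = alpha * accum + beta * source + broadcast;  // binary_op modeled as addition
  T = t;                                                // pre-activation output
  Z = skip_elementwise ? t : std::max(t, 0.0f);         // elementwise_op modeled as ReLU
}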
@@ -306,6 +306,7 @@ public:
  /// Debug printing
  CUTLASS_DEVICE
  static void print() {
#if 0
    printf("BroadcastDetail {\n");
    printf(
      " kColumnsPerThread: %d\nkRowsPerThread: %d\n,kThreadCount: %d\nkThreadsPerRow: %d\n"

@@ -321,6 +322,7 @@ public:
      StorageShape::kCount
    );
    printf("};\n");
#endif
  }
};

@@ -212,6 +212,7 @@ public:
  /// Debug printing
  CUTLASS_DEVICE
  static void print() {
#if 0
    printf("ReductionDetail {\n");
    printf(
      " kElementsPerAccess:%d\nkColumnsPerThread: %d\nkRowsPerThread: %d\n,kThreadCount: %d\nkThreadsPerRow: %d\n"

@@ -228,6 +229,7 @@ public:
      StorageShape::kCount
    );
    printf("};\n");
#endif
  }
};

@@ -363,8 +363,13 @@ tfloat32_t operator+(tfloat32_t const& lhs, tfloat32_t const& rhs) {

CUTLASS_HOST_DEVICE
tfloat32_t operator-(tfloat32_t const& lhs) {
  float x = -reinterpret_cast<float const &>(lhs);
  return *reinterpret_cast<tfloat32_t *>(&x);
  union u_tff32 {
    float val_f32;
    tfloat32_t val_tf;
    CUTLASS_HOST_DEVICE u_tff32() : val_f32(0) { }
  };
  union u_tff32 x; x.val_f32 = -reinterpret_cast<float const &>(lhs);
  return x.val_tf;
}

CUTLASS_HOST_DEVICE
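The unary-minus fix stops reading the result through a reinterpret_cast of a local float and instead round-trips through a union, mirroring common bit-level type punning. A standalone sketch of the same pattern with a hypothetical 32-bit storage type (not the CUTLASS type itself):

#include <cstdint>

// Stand-in for a storage-only numeric wrapper such as tfloat32_t.
struct StorageF32 {
  uint32_t storage;
};

// Negate a float and return its bit pattern wrapped in the storage type,
// following the union-based approach used in the hunk above.
StorageF32 negate_to_storage(float value) {
  union {
    float      as_float;
    StorageF32 as_storage;
  } u;
  u.as_float = -value;   // write the negated float
  return u.as_storage;   // read back the 32-bit representation
}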
@@ -35,6 +35,7 @@ target_link_libraries(
  cutlass_tools_util_includes
  $<$<BOOL:${CUTLASS_ENABLE_CUBLAS}>:nvidia::cublas>
  gtest
  cudart
  )

cutlass_add_library(

@@ -31,6 +31,8 @@
#pragma nv_diag_warning boolean_controlling_expr_is_constant
#pragma warning( disable : 4503)

#include <cstdlib>
#include <string>
/////////////////////////////////////////////////////////////////////////////////////////////////

/// Sets flags for Unit test

@@ -38,6 +40,13 @@ void FilterArchitecture();

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Reads environment variable `CUTLASS_UNIT_TEST_PROBLEM_COUNT` to control the number and order
/// of problem sizes run by CUTLASS unit tests
int CutlassUnitTestProblemCount();

/////////////////////////////////////////////////////////////////////////////////////////////////

// active test macro
#define CUTLASS_TEST_LEVEL_ACTIVE(LEVEL,NAME_STATIC,NAME_DYNAMIC,...) \
  TEST(NAME_STATIC,L##LEVEL##_##NAME_DYNAMIC) __VA_ARGS__
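As a hedged illustration of how the level macros are meant to expand (assuming CUTLASS_TEST_L1 forwards to the active variant above when the configured test level is high enough):

#include <gtest/gtest.h>

// Writing CUTLASS_TEST_L1(SM80_CUTLASS_TEST, level1, { EXPECT_TRUE(true); })
// with the active macro selected would effectively produce an ordinary
// GoogleTest case whose name encodes the level:
TEST(SM80_CUTLASS_TEST, L1_level1) {
  EXPECT_TRUE(true);
}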
@@ -91,3 +91,14 @@ void FilterArchitecture() {
}

/////////////////////////////////////////////////////////////////////////////////////////////////

int CutlassUnitTestProblemCount() {
  if (const char* problem_count = std::getenv("CUTLASS_UNIT_TEST_PROBLEM_COUNT")) {
    return std::stoi(problem_count);
  }

  return 0;
}

/////////////////////////////////////////////////////////////////////////////////////////////////
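A short, hedged example of how a test driver might consume this helper to cap the number of problems it runs (the function and loop below are illustrative, not the exact testbed code):

#include <algorithm>
#include <vector>

int CutlassUnitTestProblemCount();  // declared as in the hunk above

// Run at most CUTLASS_UNIT_TEST_PROBLEM_COUNT problems, most rigorous first,
// when the environment variable is set; otherwise run everything in order.
template <typename Problem, typename RunFn>
void run_problems(std::vector<Problem> problems, RunFn run) {
  int limit = CutlassUnitTestProblemCount();
  if (limit) {
    // Rigorous sizes are appended last, so reverse to exercise them first.
    std::reverse(problems.begin(), problems.end());
  }
  int tested = 0;
  for (auto const &problem : problems) {
    run(problem);
    if (limit && ++tested >= limit) {
      break;  // requested problem count reached
    }
  }
}

Setting CUTLASS_UNIT_TEST_PROBLEM_COUNT=1 in the environment would then run only the most rigorous size per test.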
@@ -35,7 +35,6 @@

#include "conv2d_testbed.h"

////////////////////////////////////////////////////////////////////////////////
TEST(SM50_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32,
  64x64_8x2_32x64x8) {

@@ -56,7 +56,7 @@ TEST(SM75_Device_Conv2d_Fprop_With_Broadcast_Analytic_ImplicitGemm_f16nhwc_f16nh
  cutlass::half_t,
  cutlass::half_t,
  8,
  cutlass::epilogue::thread::GELU_taylor<float>
  cutlass::epilogue::thread::ReLu<float>
>;

/// Device-level Conv2d instance

@@ -36,8 +36,6 @@
#include "cutlass/conv/convolution.h"
#include "cutlass/conv/conv2d_problem_size.h"

#define CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED 1

namespace test {
namespace conv {
namespace device {
@@ -99,6 +99,8 @@ public:
  cutlass::HostTensor<ElementC, LayoutC> tensor_D_computed;
  cutlass::HostTensor<ElementC, LayoutC> tensor_D_reference;

  int tested_problem_count;

public:

  TestbedConv2d(

@@ -107,7 +109,7 @@ public:
    cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
    uint64_t seed_ = 2080
  ):
    init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) {
    init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_), tested_problem_count(0) {

  }

@@ -220,7 +222,10 @@ public:
      return true;
    }

#if 0 //display conv2d problem size for debugging
    // increment tested problem count run by the testbed
    tested_problem_count++;

#if 0 // display conv2d problem size for debugging
    std::cout << problem_size << std::endl
      << "alpha, beta: (" << alpha << ", " << beta << ")" << std::endl
      << "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") << std::endl
@@ -537,20 +542,32 @@ bool TestAllConv2d(
  // Vector of conv2d problem sizes to avoid duplicate runs
  Conv2dProblemVector conv_tested_sizes;

  Conv2dProblemVector const *problem_vectors[] = {
    &conv_test_sizes,                      // run user specified sizes
    &conv_problems.conv2d_default_sizes,   // run default and cudnn bug sizes
    //&conv_problems.conv2d_resnet50_sizes, // run resnet50 sizes
  // Vectors of Conv2dProblemVector (lenient/easiest to rigorous problem sizes)
  std::vector<Conv2dProblemVector> problem_vectors = {
    conv_test_sizes,                       // run user specified sizes
    conv_problems.conv2d_default_sizes,    // run default and cudnn bug sizes
    //conv_problems.conv2d_resnet50_sizes,  // run resnet50 sizes
#if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED
    &conv_problems.conv2d_rigorous_sizes,  // run large and rigorous sizes if enabled
    conv_problems.conv2d_rigorous_sizes,   // run large and rigorous sizes if enabled
#endif
  };

  // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0)
  for (Conv2dProblemVector const * problem_vector : problem_vectors) {
  // Flatten 2D problem_vectors into a 1D problem_sizes
  std::vector<cutlass::conv::Conv2dProblemSize> problem_sizes;
  for (auto problem_vector : problem_vectors) {
    for (auto conv_problem : problem_vector) {
      problem_sizes.push_back(conv_problem);
    }
  }

    // Run conv testbed on default convolution sizes
    for (auto conv_problem : *problem_vector) {
  // If CUTLASS_UNIT_TEST_PROBLEM_COUNT is set reverse the order (rigorous to lenient)
  // run the most rigorous problem size first
  if (CutlassUnitTestProblemCount()) {
    std::reverse(problem_sizes.begin(), problem_sizes.end());
  }

  // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0)
  for (auto conv_problem : problem_sizes) {

    // Skip blacklist and avoid duplicate problem sizes
    if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() ||

@@ -607,9 +624,15 @@ bool TestAllConv2d(
    if (!passed) {
      return false;
    }

    // If CUTLASS_UNIT_TEST_PROBLEM_COUNT is set, reduce the number of tested problem sizes
    if (CutlassUnitTestProblemCount() &&
        testbed.tested_problem_count > CutlassUnitTestProblemCount()) {
      return true;
    }
  }

  // CUTLASS DGRAD's *strided* specialization does not support split-k mode
  if ((ImplicitGemm::kConvolutionalOperator ==
      cutlass::conv::Operator::kDgrad) &&

@@ -677,6 +700,12 @@ bool TestAllConv2d(
      if (!passed) {
        return false;
      }

      // If CUTLASS_UNIT_TEST_PROBLEM_COUNT is set, reduce the number of tested problem sizes
      if (CutlassUnitTestProblemCount() &&
          testbed.tested_problem_count > CutlassUnitTestProblemCount()) {
        return true;
      }
    }
  }
}
@@ -23,7 +23,11 @@
 *
 **************************************************************************************************/
/*! \file
  \brief Implicit GEMM testbed
  \brief Implicit GEMM for fused epilogue broadcast testbed

  Parallel split-k is not tested because we can just use regular conv kernel
  when we need to use parallel-splitk. Broadcast can happen in the reduction
  kernel.
*/
#pragma once

@@ -53,7 +57,46 @@ namespace test {
namespace conv {
namespace device {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Conv2d>
struct Conv2dWithBroadcastReferenceOp {

  using OutputOp = typename Conv2d::EpilogueOutputOp;

  using ElementCompute = typename OutputOp::ElementCompute;
  using ElementZ = typename OutputOp::ElementZ;
  using ElementT = typename OutputOp::ElementT;

  typename OutputOp::BinaryOp binary_op;
  typename OutputOp::ElementwiseOp elementwise_op;

  Conv2dWithBroadcastReferenceOp() { }

  void operator()(ElementZ &Z, ElementT &T, ElementCompute conv2d, ElementCompute bias) {
    ElementCompute t_full = binary_op(conv2d, bias);
    T = ElementT(t_full);

    ElementCompute z_full = elementwise_op(t_full);
    Z = ElementZ(z_full);
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

// Fused testbed
//
// Y = CONV(AB, C)
//
// T[n, p, q, k] = ReductionOp(Y[n, p, q, k], Broadcast[k])
//
// Z[n, p, q, k] = Elementwise(T[n, p, q, k])
//

template <
  typename Conv2d,
  typename ReferenceOp = Conv2dWithBroadcastReferenceOp<Conv2d>
>
class TestbedConv2dWithBroadcast {
public:
@@ -66,6 +109,8 @@ public:
  using ElementAccumulator = typename Conv2d::ElementAccumulator;
  using ElementCompute = typename Conv2d::ElementCompute;
  using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp;
  using ElementZ = typename EpilogueOutputOp::ElementZ;
  using ElementT = typename EpilogueOutputOp::ElementT;

  static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator;

@@ -80,8 +125,13 @@ public:
  cutlass::HostTensor<ElementA, LayoutA> tensor_A;
  cutlass::HostTensor<ElementB, LayoutB> tensor_B;
  cutlass::HostTensor<ElementC, LayoutC> tensor_C;
  cutlass::HostTensor<ElementC, LayoutC> tensor_D_computed;
  cutlass::HostTensor<ElementC, LayoutC> tensor_D_reference;
  cutlass::HostTensor<ElementAccumulator, LayoutC> tensor_C_reference;
  cutlass::HostTensor<ElementZ, LayoutC> tensor_Z_computed;
  cutlass::HostTensor<ElementZ, LayoutC> tensor_Z_reference;
  cutlass::HostTensor<ElementT, LayoutC> tensor_T_computed;
  cutlass::HostTensor<ElementT, LayoutC> tensor_T_reference;
  cutlass::HostTensor<ElementAccumulator, LayoutC> tensor_Y_reference;
  cutlass::HostTensor<ElementC, LayoutC> tensor_Broadcast; // Input Broadcast

public:

@@ -147,18 +197,44 @@ public:
    tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size));
    tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size));
    tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
    tensor_D_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
    tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
    tensor_C_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
    tensor_Z_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
    tensor_Z_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
    tensor_T_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
    tensor_T_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
    tensor_Y_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
    tensor_Broadcast.resize({
      1,
      1,
      1,
      implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size).c(),
    });

    initialize_tensor(tensor_A.host_view(), init_A, seed);
    initialize_tensor(tensor_B.host_view(), init_B, seed * 17);
    initialize_tensor(tensor_C.host_view(), init_C, seed * 39);
    initialize_tensor(tensor_Broadcast.host_view(), init_C, seed * 39);

    for (int n = 0; n < tensor_C_reference.extent().n(); ++n) {
      for (int p = 0; p < tensor_C_reference.extent().h(); ++p) {
        for (int q = 0; q < tensor_C_reference.extent().w(); ++q) {
          for (int k = 0; k < tensor_C_reference.extent().c(); ++k) {
            tensor_C_reference.at({n, p, q, k}) = ElementAccumulator(tensor_C.at({n, p, q, k}));
          }
        }
      }
    }

    tensor_A.sync_device();
    tensor_B.sync_device();
    tensor_C.sync_device();
    tensor_D_computed.sync_device();
    tensor_D_reference.sync_device();
    tensor_Broadcast.sync_device();
    tensor_C_reference.sync_device();
    tensor_Z_computed.sync_device();
    tensor_Z_reference.sync_device();
    tensor_T_computed.sync_device();
    tensor_T_reference.sync_device();
    tensor_Y_reference.sync_device();
  }
  bool sufficient() const {

@@ -215,18 +291,21 @@ public:

    // configure the operator
    Conv2d conv2d_op;

    typename Conv2d::Arguments conv2d_args(
      problem_size,
      tensor_A.device_ref(),
      tensor_B.device_ref(),
      tensor_C.device_ref(),
      tensor_D_computed.device_ref(),
      tensor_Z_computed.device_ref(),
      {alpha, beta},
      split_k_mode
      split_k_mode,
      tensor_Broadcast.device_data(),
      tensor_T_computed.device_data(),
      0, // This must be zero
      implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size).c()
    );

    // find workspace requirement for parallel split-k reduction
    // initialize the kernel
    size_t workspace_size = Conv2d::get_workspace_size(conv2d_args);
    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

@@ -239,22 +318,6 @@ public:
      return true;
    }

    // conv2d operation with parallel split-k-mode
    if (split_k_mode == cutlass::conv::SplitKMode::kParallel) {

      // conv2d output is written to workspace in global memory
      conv2d_args.ref_D.reset(reinterpret_cast<ElementC*>(workspace.get()));
      // accumulate mma for each cta in k-dimension (1.0 * A * B)
      conv2d_args.output_op = {ElementCompute(1), ElementCompute(0)};
      // update conv2d operator arguments
      status = conv2d_op.update(conv2d_args, workspace.get());
    }

    EXPECT_TRUE(status == cutlass::Status::kSuccess);
    if (status != cutlass::Status::kSuccess) {
      return false;
    }

    // run conv2d operator
    status = conv2d_op();
@@ -269,52 +332,13 @@ public:
    EXPECT_EQ(result, cudaSuccess) << " device reference error: "
                                   << cudaGetErrorString(result);

    tensor_D_computed.sync_host();
    tensor_T_computed.sync_host();
    tensor_Z_computed.sync_host();

    //
    // Reference check - support caching results
    // Reference check
    //

    CachedTestKey cached_test_key = CreateCachedConv2dWithBroadcastTestKey<
      ElementA, LayoutA,
      ElementB, LayoutB,
      ElementC, LayoutC,
      ElementAccumulator,
      ElementCompute
    >(
      kConvolutionalOperator,
      problem_size,
      alpha,
      beta,
      tensor_A.host_view(),
      tensor_B.host_view(),
      tensor_C.host_view()
    );

    //
    // Look for the cached key
    //

    bool cached_result_loaded = false;
    CachedTestResult cached_test_result;

    std::string conv2d_result_cache_name =
      std::string("cached_results_") + CUTLASS_TARGET_NAME + ".txt";

    if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) {

      CachedTestResultListing cached_results(conv2d_result_cache_name);

      auto cached = cached_results.find(cached_test_key);

      cached_result_loaded = cached.first;
      if (cached_result_loaded) {
        cached_test_result = cached.second;
      }
    }

    if (!cached_result_loaded) {

#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED

      cutlass::reference::device::Conv2d<

@@ -322,22 +346,22 @@ public:
        LayoutA,
        ElementB,
        LayoutB,
        ElementC,
        ElementAccumulator,
        LayoutC,
        ElementCompute,
        ElementAccumulator,
        ElementAccumulator
      >(
        kConvolutionalOperator,
        problem_size,
        tensor_A.device_ref(),
        tensor_B.device_ref(),
        tensor_C.device_ref(),
        tensor_D_reference.device_ref(),
        tensor_C_reference.device_ref(),
        tensor_Y_reference.device_ref(),
        alpha,
        beta);

      // sync host (copy device data to host) for dumping error output in case of mismatches
      tensor_D_reference.sync_host();
      tensor_Y_reference.sync_host();

#else
@@ -346,48 +370,50 @@ public:
        LayoutA,
        ElementB,
        LayoutB,
        ElementC,
        ElementAccumulator,
        LayoutC,
        ElementCompute,
        ElementAccumulator,
        ElementAccumulator
      >(
        kConvolutionalOperator,
        problem_size,
        tensor_A.host_ref(),
        tensor_B.host_ref(),
        tensor_C.host_ref(),
        tensor_D_reference.host_ref(),
        tensor_C_reference.host_ref(),
        tensor_Y_reference.host_ref(),
        alpha,
        beta);

#endif
      ReferenceOp reference_op;

      if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) {
      // compute tensor Z and tensor T
      for (int n = 0; n < problem_size.N; ++n) {
        for (int p = 0; p < problem_size.P; ++p) {
          for (int q = 0; q < problem_size.Q; ++q) {
            for (int k = 0; k < problem_size.K; ++k) {

        cached_test_result.D = TensorHash(tensor_D_reference.host_view());
              ElementZ z;
              ElementT t;

        CachedTestResultListing cached_results(conv2d_result_cache_name);
              reference_op(z, t, tensor_Y_reference.at({n, p, q, k}), tensor_Broadcast.at({0, 0, 0, k}));

        cached_results.append(cached_test_key, cached_test_result);
        cached_results.write(conv2d_result_cache_name);
              tensor_Z_reference.at({n, p, q, k}) = z;
              tensor_T_reference.at({n, p, q, k}) = t;
            }
          }
        }
      } // if (!cached_result_loaded)

    uint32_t tensor_D_hash = TensorHash(tensor_D_computed.host_view());

    if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) {
      passed = (tensor_D_hash == cached_test_result.D);

      EXPECT_EQ(tensor_D_hash, cached_test_result.D)
        << "Hash-based comparison failed for key:" << "\n" << cached_test_key << "\n";
    }
    else {

    passed = cutlass::reference::host::TensorEquals(
      tensor_D_computed.host_view(),
      tensor_D_reference.host_view());
    }
      tensor_T_computed.host_view(),
      tensor_T_reference.host_view());

    EXPECT_TRUE(passed);

    passed = cutlass::reference::host::TensorEquals(
      tensor_Z_computed.host_view(),
      tensor_Z_reference.host_view());

    EXPECT_TRUE(passed);
@@ -435,14 +461,16 @@ public:
      << "\nA:\n" << tensor_A.host_view() << "\n"
      << "\nB:\n" << tensor_B.host_view() << "\n"
      << "\nC:\n" << tensor_C.host_view() << "\n"
      << "\nD reference:\n" << tensor_D_reference.host_view() << "\n"
      << "\nD computed:\n" << tensor_D_computed.host_view() << "\n";
      << "\nBroadcast:\n" << tensor_Broadcast.host_view() << "\n"
      << "\nY reference:\n" << tensor_Y_reference.host_view() << "\n"
      << "\nT reference:\n" << tensor_T_reference.host_view() << "\n"
      << "\nT computed:\n" << tensor_T_computed.host_view() << "\n"
      << "\nZ reference:\n" << tensor_Z_reference.host_view() << "\n"
      << "\nZ computed:\n" << tensor_Z_computed.host_view() << "\n";
    }

    return passed;
  }

};

/////////////////////////////////////////////////////////////////////////////////////////////////////////

@@ -582,8 +610,7 @@ bool TestAllConv2dWithBroadcast(
  );

  cutlass::conv::SplitKMode split_k_modes [] = {
    cutlass::conv::SplitKMode::kSerial,
    cutlass::conv::SplitKMode::kParallel,
    cutlass::conv::SplitKMode::kSerial
  };

  int split_k_slices[] = {

@@ -69,11 +69,11 @@ struct GemmWithBroadcastReferenceOp {

  void operator()(ElementZ &Z, ElementT &T, ElementCompute gemm, ElementCompute bias) {

    ElementCompute z_full = binary_op(gemm, bias);
    Z = ElementZ(z_full);

    ElementCompute t_full = elementwise_op(z_full);
    ElementCompute t_full = binary_op(gemm, bias);
    T = ElementT(t_full);

    ElementCompute z_full = elementwise_op(t_full);
    Z = ElementZ(z_full);
  }
};
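A tiny worked check of the corrected ordering in the reference op above, with addition standing in for the binary op and ReLU for the elementwise op (values chosen only for illustration):

#include <algorithm>
#include <cassert>

int main() {
  float gemm = 2.5f, bias = -3.0f;
  float T = gemm + bias;          // binary_op first: T = -0.5
  float Z = std::max(T, 0.0f);    // elementwise_op on T: Z = 0.0
  assert(T == -0.5f && Z == 0.0f);
  return 0;
}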
@@ -83,9 +83,9 @@ struct GemmWithBroadcastReferenceOp {
//
// Y = GEMM(AB, C)
//
// Z[i, j] = ReductionOp(Y[i, j], Broadcast[i])
// T[i, j] = ReductionOp(Y[i, j], Broadcast[i])
//
// T[i, j] = Elementwise(Z[i, j])
// Z[i, j] = Elementwise(T[i, j])
//

template <

@@ -101,7 +101,6 @@ struct TestbedGemmWithBroadcast {
  using ElementZ = typename OutputOp::ElementZ;
  using ElementT = typename OutputOp::ElementT;

  /// Initialization
  cutlass::Distribution::Kind init_A;
  cutlass::Distribution::Kind init_B;

@@ -343,7 +342,6 @@ struct TestbedGemmWithBroadcast {

    ReferenceOp reference_op;

    // compute tensor Z and tensor T
    for (int m = 0; m < problem_size.m(); ++m) {
      for (int n = 0; n < problem_size.n(); ++n) {

@@ -24,3 +24,8 @@ cutlass_test_unit_add_executable(
  cutlass_test_unit_util
  tensor_reduce.cu
)

cutlass_test_unit_add_executable(
  cutlass_test_unit_levels
  cutlass_test_levels.cu
)

test/unit/util/cutlass_test_levels.cu (new file, 71 lines)
@@ -0,0 +1,71 @@
/***************************************************************************************************
 * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#include <complex>

#include "../common/cutlass_unit_test.h"

////////////////////////////////////////////////////////////////////////////////////////////////////

TEST(SM75_CUTLASS_TEST, level_not_specified) {

  EXPECT_TRUE(true);
}

TEST(SM80_CUTLASS_TEST, level_not_specified) {

  EXPECT_TRUE(true);
}

CUTLASS_TEST_L0(SM75_CUTLASS_TEST, level0, {

  EXPECT_TRUE(true);
})

CUTLASS_TEST_L1(SM75_CUTLASS_TEST, level1, {

  EXPECT_TRUE(true);
})

CUTLASS_TEST_L2(SM75_CUTLASS_TEST, level2, {

  EXPECT_TRUE(true);
})

CUTLASS_TEST_L0(SM80_CUTLASS_TEST, level0, {

  EXPECT_TRUE(true);
})

CUTLASS_TEST_L1(SM80_CUTLASS_TEST, level1, {

  EXPECT_TRUE(true);
})

CUTLASS_TEST_L2(SM80_CUTLASS_TEST, level2, {

  EXPECT_TRUE(true);
})

////////////////////////////////////////////////////////////////////////////////////////////////////

@@ -1056,7 +1056,6 @@ def GenerateSM75_TensorOp_8832_TN(manifest, args):

  min_cc = 75
  max_cc = 1024

  alignment_constraints = [32,]

  for math_inst in math_instructions:

@@ -1136,7 +1135,6 @@ def GenerateSM75_TensorOp_8832_Interleaved(manifest, args):

  min_cc = 75
  max_cc = 1024

  alignment_constraints = [32,]

  for math_inst in math_instructions:

@@ -1907,7 +1905,6 @@ def GenerateSM80_TensorOp_16864_Interleaved(manifest, args):

  min_cc = 80
  max_cc = 1024

  alignment_constraints = [32,]

  for math_inst in math_instructions:

@@ -102,6 +102,12 @@ bool get_cublas_datatype(cublasDataType_t &data_type, library::NumericTypeID ele
      data_type = CUDA_R_16F;
      return true;

    case library::NumericTypeID::kBF16:
      break;

    case library::NumericTypeID::kTF32:
      break;

    case library::NumericTypeID::kF32:
      data_type = CUDA_R_32F;
      return true;