Updates to fused epilogue (#383)

* Enhancements and fixes to the fused GEMM and Convolution epilogues; the corrected Z/T output ordering is sketched after the commit metadata below.
* Explicitly list cudart as a unit-test library dependency.
Andrew Kerr 2021-12-17 16:04:43 -05:00 committed by GitHub
parent 4e666e1dfd
commit ec4f7e5194
24 changed files with 372 additions and 193 deletions
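The central behavioral change is the ordering of the two fused-epilogue outputs: T now holds the result of the binary (broadcast) operation, and Z holds the elementwise activation applied to T. Below is a minimal sketch of that ordering, assuming a plus binary op and a ReLU elementwise op; the type and struct names are illustrative, not the CUTLASS ones.

// Sketch only, not the CUTLASS implementation. Models the corrected
// fused-epilogue output ordering exercised by the reference operators
// in this change:
//   Y = alpha * accumulator + beta * source
//   T = BinaryOp(Y, broadcast)      // e.g. per-channel bias add
//   Z = ElementwiseOp(T)            // e.g. ReLU
#include <algorithm>

using ElementCompute = float;  // assumed compute type

struct FusedEpilogueSketch {
  ElementCompute alpha = 1.0f;
  ElementCompute beta  = 0.0f;

  void operator()(ElementCompute accum, ElementCompute source,
                  ElementCompute broadcast,
                  ElementCompute &T, ElementCompute &Z) const {
    ElementCompute y = alpha * accum + beta * source;  // linear combination
    T = y + broadcast;                                  // binary op (plus)
    Z = std::max(T, ElementCompute(0));                 // elementwise op (ReLU)
  }
};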

View File

@ -197,9 +197,19 @@ list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_TEST_LEVEL=${CUTLASS_TEST_LEVEL})
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -DCUTLASS_TEST_LEVEL=${CUTLASS_TEST_LEVEL})
if (CUTLASS_TEST_ENABLE_CACHED_RESULTS)
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1)
message(STATUS "Enable caching of reference results in conv unit tests")
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1)
endif()
set(CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED ON CACHE BOOL "Enable/Disable rigorous conv problem sizes in conv unit tests")
if (CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED)
message(STATUS "Enable rigorous conv problem sizes in conv unit tests")
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED=1)
endif()
#
# CUDA 10.1 introduces "mma" in PTX performing collective matrix multiply operations.
#

View File

@ -296,7 +296,7 @@ int main() {
return -1;
}
if (!((props.major * 10 + props.minor) >= 80)) {
if (props.major * 10 + props.minor < 80) {
std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80."
<< std::endl;
notSupported = true;

View File

@ -233,7 +233,7 @@ int run() {
tensor_b.device_ref(),
tensor_c_bias.device_ref(),
tensor_ref_d.device_ref(),
alpha, 0
alpha, ElementComputeEpilogue(0)
);
// Wait for kernels to finish

View File

@ -45,6 +45,7 @@ function(cutlass_example_add_executable NAME)
PRIVATE
CUTLASS
cutlass_tools_util_includes
$<$<BOOL:${CUTLASS_ENABLE_CUBLAS}>:nvidia::cublas>
)
target_include_directories(

View File

@ -109,6 +109,7 @@ struct Wmma<
FragmentB const &B,
FragmentC const &C) const {
nvcuda::wmma::mma_sync(D, A, B, C);
}
#else
@ -186,7 +187,6 @@ struct Wmma<
FragmentA const &A,
FragmentB const &B,
FragmentC const &C) const {
nvcuda::wmma::bmma_sync(D, A, B, C, nvcuda::wmma::experimental::bmmaBitOpXOR,
nvcuda::wmma::experimental::bmmaAccumulateOpPOPC);
}

View File

@ -109,6 +109,12 @@ static char const* cutlassGetStatusString(cutlass::Status status) {
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifndef CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED
#define CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED 0
#endif
// CUDA 10.1 introduces the mma instruction
#if !defined(CUTLASS_ENABLE_TENSOR_CORE_MMA)
#define CUTLASS_ENABLE_TENSOR_CORE_MMA 0

View File

@ -58,6 +58,7 @@ struct Identity {
/// ReLu operator - propagates NaNs
template <typename T>
struct ReLu {
static const bool kIsHeavy=false;
CUTLASS_HOST_DEVICE
T operator()(T const & threshold, T value) const {
if (value < threshold) {
@ -76,6 +77,7 @@ struct ReLu {
template <typename T, int N>
struct ReLu<Array<T, N>> {
static const bool kIsHeavy=false;
CUTLASS_HOST_DEVICE
Array<T, N> operator()(T const & threshold, Array<T, N> const &frag) const {
Array<T, N> result;

View File

@ -201,8 +201,8 @@ public:
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kElementsPerAccess; ++i) {
ElementCompute z = binary_op(alpha_ * tmp_Accum[i] + beta_ * tmp_C[i], V[i]);
result_Z[i] = z;
result_T[i] = skip_elementwise_ ? z : elementwise_op(z);
result_T[i] = z;
result_Z[i] = skip_elementwise_ ? z : elementwise_op(z);
}
NumericArrayConverter<ElementZ, ElementCompute, kElementsPerAccess> convert_z;
@ -230,8 +230,8 @@ public:
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kElementsPerAccess; ++i) {
ElementCompute z = binary_op(alpha_ * tmp_Accum[i], V[i]);
result_Z[i] = z;
result_T[i] = skip_elementwise_ ? z : elementwise_op(z);
result_T[i] = z;
result_Z[i] = skip_elementwise_ ? z : elementwise_op(z);
}
NumericArrayConverter<ElementZ, ElementCompute, kElementsPerAccess> convert_z;

View File

@ -306,6 +306,7 @@ public:
/// Debug printing
CUTLASS_DEVICE
static void print() {
#if 0
printf("BroadcastDetail {\n");
printf(
" kColumnsPerThread: %d\nkRowsPerThread: %d\n,kThreadCount: %d\nkThreadsPerRow: %d\n"
@ -321,6 +322,7 @@ public:
StorageShape::kCount
);
printf("};\n");
#endif
}
};

View File

@ -212,6 +212,7 @@ public:
/// Debug printing
CUTLASS_DEVICE
static void print() {
#if 0
printf("ReductionDetail {\n");
printf(
" kElementsPerAccess:%d\nkColumnsPerThread: %d\nkRowsPerThread: %d\n,kThreadCount: %d\nkThreadsPerRow: %d\n"
@ -228,6 +229,7 @@ public:
StorageShape::kCount
);
printf("};\n");
#endif
}
};

View File

@ -363,8 +363,13 @@ tfloat32_t operator+(tfloat32_t const& lhs, tfloat32_t const& rhs) {
CUTLASS_HOST_DEVICE
tfloat32_t operator-(tfloat32_t const& lhs) {
float x = -reinterpret_cast<float const &>(lhs);
return *reinterpret_cast<tfloat32_t *>(&x);
union u_tff32 {
float val_f32;
tfloat32_t val_tf;
CUTLASS_HOST_DEVICE u_tff32() : val_f32(0) { }
};
union u_tff32 x; x.val_f32 = -reinterpret_cast<float const &>(lhs);
return x.val_tf;
}
CUTLASS_HOST_DEVICE

View File

@ -35,6 +35,7 @@ target_link_libraries(
cutlass_tools_util_includes
$<$<BOOL:${CUTLASS_ENABLE_CUBLAS}>:nvidia::cublas>
gtest
cudart
)
cutlass_add_library(

View File

@ -31,6 +31,8 @@
#pragma nv_diag_warning boolean_controlling_expr_is_constant
#pragma warning( disable : 4503)
#include <cstdlib>
#include <string>
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Sets flags for Unit test
@ -38,6 +40,13 @@ void FilterArchitecture();
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Reads environment variable `CUTLASS_UNIT_TEST_PROBLEM_COUNT` to control the number and order
// of problem sizes run by CUTLASS unit tests
int CutlassUnitTestProblemCount();
/////////////////////////////////////////////////////////////////////////////////////////////////
// active test macro
#define CUTLASS_TEST_LEVEL_ACTIVE(LEVEL,NAME_STATIC,NAME_DYNAMIC,...) \
TEST(NAME_STATIC,L##LEVEL##_##NAME_DYNAMIC) __VA_ARGS__

View File

@ -91,3 +91,14 @@ void FilterArchitecture() {
}
/////////////////////////////////////////////////////////////////////////////////////////////////
int CutlassUnitTestProblemCount() {
if(const char* problem_count = std::getenv("CUTLASS_UNIT_TEST_PROBLEM_COUNT")) {
return std::stoi(problem_count);
}
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
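For context, here is a compressed sketch of how the new CUTLASS_UNIT_TEST_PROBLEM_COUNT hook is consumed by the conv testbed changes later in this commit; Problem and RunProblem are hypothetical stand-ins for the real problem type and per-problem testbed invocation.

// Sketch only: Problem and RunProblem are hypothetical placeholders.
#include <algorithm>
#include <cstdlib>
#include <string>
#include <vector>

struct Problem {};
static bool RunProblem(Problem const &) { return true; }

// Mirrors CutlassUnitTestProblemCount(): 0 means "no limit".
static int ProblemCountLimit() {
  if (char const *count = std::getenv("CUTLASS_UNIT_TEST_PROBLEM_COUNT")) {
    return std::stoi(count);
  }
  return 0;
}

static bool RunSweep(std::vector<Problem> problems) {
  // When a limit is requested, run the most rigorous (last) sizes first.
  if (ProblemCountLimit()) {
    std::reverse(problems.begin(), problems.end());
  }
  int tested = 0;
  for (Problem const &p : problems) {
    if (!RunProblem(p)) { return false; }
    ++tested;
    // Stop early once enough problems have been tested.
    if (ProblemCountLimit() && tested > ProblemCountLimit()) {
      return true;
    }
  }
  return true;
}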

View File

@ -35,7 +35,6 @@
#include "conv2d_testbed.h"
////////////////////////////////////////////////////////////////////////////////
TEST(SM50_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32,
64x64_8x2_32x64x8) {

View File

@ -56,7 +56,7 @@ TEST(SM75_Device_Conv2d_Fprop_With_Broadcast_Analytic_ImplicitGemm_f16nhwc_f16nh
cutlass::half_t,
cutlass::half_t,
8,
cutlass::epilogue::thread::GELU_taylor<float>
cutlass::epilogue::thread::ReLu<float>
>;
/// Device-level Conv2d instance

View File

@ -36,8 +36,6 @@
#include "cutlass/conv/convolution.h"
#include "cutlass/conv/conv2d_problem_size.h"
#define CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED 1
namespace test {
namespace conv {
namespace device {

View File

@ -99,6 +99,8 @@ public:
cutlass::HostTensor<ElementC, LayoutC> tensor_D_computed;
cutlass::HostTensor<ElementC, LayoutC> tensor_D_reference;
int tested_problem_count;
public:
TestbedConv2d(
@ -107,7 +109,7 @@ public:
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
uint64_t seed_ = 2080
):
init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) {
init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_), tested_problem_count(0) {
}
@ -220,7 +222,10 @@ public:
return true;
}
#if 0 //display conv2d problem size for debugging
// increment the count of problems tested by this testbed
tested_problem_count++;
#if 0 // display conv2d problem size for debugging
std::cout << problem_size << std::endl
<< "alpha, beta: (" << alpha << ", " << beta << ")" << std::endl
<< "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") << std::endl
@ -537,20 +542,32 @@ bool TestAllConv2d(
// Vector of conv2d problem sizes to avoid duplicate runs
Conv2dProblemVector conv_tested_sizes;
Conv2dProblemVector const *problem_vectors[] = {
&conv_test_sizes, // run user specified sizes
&conv_problems.conv2d_default_sizes, // run default and cudnn bug sizes
//&conv_problems.conv2d_resnet50_sizes, // run resnet50 sizes
// Vectors of Conv2dProblemVector (lenient/easiest to rigorous problem sizes)
std::vector<Conv2dProblemVector> problem_vectors = {
conv_test_sizes, // run user specified sizes
conv_problems.conv2d_default_sizes, // run default and cudnn bug sizes
//conv_problems.conv2d_resnet50_sizes, // run resnet50 sizes
#if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED
&conv_problems.conv2d_rigorous_sizes, // run large and rigorous sizes if enabled
conv_problems.conv2d_rigorous_sizes, // run large and rigorous sizes if enabled
#endif
};
// Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0)
for (Conv2dProblemVector const * problem_vector : problem_vectors) {
// Flatten 2D problem_vectors into a 1D problem_sizes
std::vector<cutlass::conv::Conv2dProblemSize> problem_sizes;
for (auto problem_vector : problem_vectors) {
for(auto conv_problem : problem_vector) {
problem_sizes.push_back(conv_problem);
}
}
// Run conv testbed on default convolution sizes
for(auto conv_problem : *problem_vector) {
// If CUTLASS_UNIT_TEST_PROBLEM_COUNT is set, reverse the order (rigorous to lenient)
// so that the most rigorous problem sizes run first
if (CutlassUnitTestProblemCount()) {
std::reverse(problem_sizes.begin(), problem_sizes.end());
}
// Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0)
for(auto conv_problem : problem_sizes) {
// Skip blacklist and avoid duplicate problem sizes
if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() ||
@ -607,9 +624,15 @@ bool TestAllConv2d(
if (!passed) {
return false;
}
// If CUTLASS_UNIT_TEST_PROBLEM_COUNT is set, limit the number of tested problems
if (CutlassUnitTestProblemCount() &&
testbed.tested_problem_count > CutlassUnitTestProblemCount()) {
return true;
}
}
// CUTLASS DGRAD's *strided* specialization does not support split-k mode
if ((ImplicitGemm::kConvolutionalOperator ==
cutlass::conv::Operator::kDgrad) &&
@ -677,6 +700,12 @@ bool TestAllConv2d(
if (!passed) {
return false;
}
// If CUTLASS_UNIT_TEST_PROBLEM_COUNT is set, limit the number of tested problems
if (CutlassUnitTestProblemCount() &&
testbed.tested_problem_count > CutlassUnitTestProblemCount()) {
return true;
}
}
}
}

View File

@ -23,7 +23,11 @@
*
**************************************************************************************************/
/*! \file
\brief Implicit GEMM testbed
\brief Implicit GEMM for fused epilogue broadcast testbed
Parallel split-k is not tested because a regular conv kernel can be used
when parallel split-k is needed; the broadcast can then happen in the
reduction kernel.
*/
#pragma once
@ -53,7 +57,46 @@ namespace test {
namespace conv {
namespace device {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Conv2d>
struct Conv2dWithBroadcastReferenceOp {
using OutputOp = typename Conv2d::EpilogueOutputOp;
using ElementCompute = typename OutputOp::ElementCompute;
using ElementZ = typename OutputOp::ElementZ;
using ElementT = typename OutputOp::ElementT;
typename OutputOp::BinaryOp binary_op;
typename OutputOp::ElementwiseOp elementwise_op;
Conv2dWithBroadcastReferenceOp() { }
void operator()(ElementZ &Z, ElementT &T, ElementCompute conv2d, ElementCompute bias) {
ElementCompute t_full = binary_op(conv2d, bias);
T = ElementT(t_full);
ElementCompute z_full = elementwise_op(t_full);
Z = ElementZ(z_full);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// Fused testbed
//
// Y = CONV(AB, C)
//
// T[n, p, q, k] = ReductionOp(Y[n, p, q, k], Broadcast[k])
//
// Z[n, p, q, k] = Elementwise(T[n, p, q, k])
//
template <
typename Conv2d,
typename ReferenceOp = Conv2dWithBroadcastReferenceOp<Conv2d>
>
class TestbedConv2dWithBroadcast {
public:
@ -66,6 +109,8 @@ public:
using ElementAccumulator = typename Conv2d::ElementAccumulator;
using ElementCompute = typename Conv2d::ElementCompute;
using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp;
using ElementZ = typename EpilogueOutputOp::ElementZ;
using ElementT = typename EpilogueOutputOp::ElementT;
static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator;
@ -80,8 +125,13 @@ public:
cutlass::HostTensor<ElementA, LayoutA> tensor_A;
cutlass::HostTensor<ElementB, LayoutB> tensor_B;
cutlass::HostTensor<ElementC, LayoutC> tensor_C;
cutlass::HostTensor<ElementC, LayoutC> tensor_D_computed;
cutlass::HostTensor<ElementC, LayoutC> tensor_D_reference;
cutlass::HostTensor<ElementAccumulator, LayoutC> tensor_C_reference;
cutlass::HostTensor<ElementZ, LayoutC> tensor_Z_computed;
cutlass::HostTensor<ElementZ, LayoutC> tensor_Z_reference;
cutlass::HostTensor<ElementT, LayoutC> tensor_T_computed;
cutlass::HostTensor<ElementT, LayoutC> tensor_T_reference;
cutlass::HostTensor<ElementAccumulator, LayoutC> tensor_Y_reference;
cutlass::HostTensor<ElementC, LayoutC> tensor_Broadcast; // Input Broadcast
public:
@ -147,18 +197,44 @@ public:
tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size));
tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size));
tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
tensor_D_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
tensor_C_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
tensor_Z_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
tensor_Z_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
tensor_T_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
tensor_T_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
tensor_Y_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
tensor_Broadcast.resize({
1,
1,
1,
implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size).c(),
});
initialize_tensor(tensor_A.host_view(), init_A, seed);
initialize_tensor(tensor_B.host_view(), init_B, seed * 17);
initialize_tensor(tensor_C.host_view(), init_C, seed * 39);
initialize_tensor(tensor_Broadcast.host_view(), init_C, seed * 39);
for (int n = 0; n < tensor_C_reference.extent().n(); ++n) {
for (int p = 0; p < tensor_C_reference.extent().h(); ++p) {
for (int q = 0; q < tensor_C_reference.extent().w(); ++q) {
for (int k = 0; k < tensor_C_reference.extent().c(); ++k) {
tensor_C_reference.at({n, p, q, k}) = ElementAccumulator(tensor_C.at({n, p, q, k}));
}
}
}
}
tensor_A.sync_device();
tensor_B.sync_device();
tensor_C.sync_device();
tensor_D_computed.sync_device();
tensor_D_reference.sync_device();
tensor_Broadcast.sync_device();
tensor_C_reference.sync_device();
tensor_Z_computed.sync_device();
tensor_Z_reference.sync_device();
tensor_T_computed.sync_device();
tensor_T_reference.sync_device();
tensor_Y_reference.sync_device();
}
bool sufficient() const {
@ -215,18 +291,21 @@ public:
// configure the operator
Conv2d conv2d_op;
typename Conv2d::Arguments conv2d_args(
problem_size,
tensor_A.device_ref(),
tensor_B.device_ref(),
tensor_C.device_ref(),
tensor_D_computed.device_ref(),
tensor_Z_computed.device_ref(),
{alpha, beta},
split_k_mode
split_k_mode,
tensor_Broadcast.device_data(),
tensor_T_computed.device_data(),
0, // This must be zero
implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size).c()
);
// find workspace requirement for parallel split-k reduction
// initialize the kernel
size_t workspace_size = Conv2d::get_workspace_size(conv2d_args);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
@ -239,22 +318,6 @@ public:
return true;
}
// conv2d operation with parallel split-k-mode
if (split_k_mode == cutlass::conv::SplitKMode::kParallel) {
// conv2d output is written to workspace in global memory
conv2d_args.ref_D.reset(reinterpret_cast<ElementC*>(workspace.get()));
// accumulate mma for each cta in k-dimension (1.0 * A * B)
conv2d_args.output_op = {ElementCompute(1), ElementCompute(0)};
// update conv2d operator arguments
status = conv2d_op.update(conv2d_args, workspace.get());
}
EXPECT_TRUE(status == cutlass::Status::kSuccess);
if (status != cutlass::Status::kSuccess) {
return false;
}
// run conv2d operator
status = conv2d_op();
@ -269,52 +332,13 @@ public:
EXPECT_EQ(result, cudaSuccess) << " device reference error: "
<< cudaGetErrorString(result);
tensor_D_computed.sync_host();
tensor_T_computed.sync_host();
tensor_Z_computed.sync_host();
//
// Reference check - support caching results
// Reference check
//
CachedTestKey cached_test_key = CreateCachedConv2dWithBroadcastTestKey<
ElementA, LayoutA,
ElementB, LayoutB,
ElementC, LayoutC,
ElementAccumulator,
ElementCompute
>(
kConvolutionalOperator,
problem_size,
alpha,
beta,
tensor_A.host_view(),
tensor_B.host_view(),
tensor_C.host_view()
);
//
// Look for the cached key
//
bool cached_result_loaded = false;
CachedTestResult cached_test_result;
std::string conv2d_result_cache_name =
std::string("cached_results_") + CUTLASS_TARGET_NAME + ".txt";
if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) {
CachedTestResultListing cached_results(conv2d_result_cache_name);
auto cached = cached_results.find(cached_test_key);
cached_result_loaded = cached.first;
if (cached_result_loaded) {
cached_test_result = cached.second;
}
}
if (!cached_result_loaded) {
#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED
cutlass::reference::device::Conv2d<
@ -322,22 +346,22 @@ public:
LayoutA,
ElementB,
LayoutB,
ElementC,
ElementAccumulator,
LayoutC,
ElementCompute,
ElementAccumulator,
ElementAccumulator
>(
kConvolutionalOperator,
problem_size,
tensor_A.device_ref(),
tensor_B.device_ref(),
tensor_C.device_ref(),
tensor_D_reference.device_ref(),
tensor_C_reference.device_ref(),
tensor_Y_reference.device_ref(),
alpha,
beta);
// sync host (copy device data to host) for dumping error output in case of mismatches
tensor_D_reference.sync_host();
tensor_Y_reference.sync_host();
#else
@ -346,48 +370,50 @@ public:
LayoutA,
ElementB,
LayoutB,
ElementC,
ElementAccumulator,
LayoutC,
ElementCompute,
ElementAccumulator,
ElementAccumulator
>(
kConvolutionalOperator,
problem_size,
tensor_A.host_ref(),
tensor_B.host_ref(),
tensor_C.host_ref(),
tensor_D_reference.host_ref(),
tensor_C_reference.host_ref(),
tensor_Y_reference.host_ref(),
alpha,
beta);
#endif
ReferenceOp reference_op;
if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) {
// compute tensor Z and tensor T
for (int n = 0; n < problem_size.N; ++n) {
for (int p = 0; p < problem_size.P; ++p) {
for (int q = 0; q < problem_size.Q; ++q) {
for (int k = 0; k < problem_size.K; ++k) {
cached_test_result.D = TensorHash(tensor_D_reference.host_view());
ElementZ z;
ElementT t;
CachedTestResultListing cached_results(conv2d_result_cache_name);
reference_op(z, t, tensor_Y_reference.at({n, p, q, k}), tensor_Broadcast.at({0, 0, 0, k}));
cached_results.append(cached_test_key, cached_test_result);
cached_results.write(conv2d_result_cache_name);
tensor_Z_reference.at({n, p, q, k}) = z;
tensor_T_reference.at({n, p, q, k}) = t;
}
}
}
} // if (!cached_result_loaded)
uint32_t tensor_D_hash = TensorHash(tensor_D_computed.host_view());
if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) {
passed = (tensor_D_hash == cached_test_result.D);
EXPECT_EQ(tensor_D_hash, cached_test_result.D)
<< "Hash-based comparison failed for key:" << "\n" << cached_test_key << "\n";
}
else {
passed = cutlass::reference::host::TensorEquals(
tensor_D_computed.host_view(),
tensor_D_reference.host_view());
}
tensor_T_computed.host_view(),
tensor_T_reference.host_view());
EXPECT_TRUE(passed);
passed = cutlass::reference::host::TensorEquals(
tensor_Z_computed.host_view(),
tensor_Z_reference.host_view());
EXPECT_TRUE(passed);
@ -435,14 +461,16 @@ public:
<< "\nA:\n" << tensor_A.host_view() << "\n"
<< "\nB:\n" << tensor_B.host_view() << "\n"
<< "\nC:\n" << tensor_C.host_view() << "\n"
<< "\nD reference:\n" << tensor_D_reference.host_view() << "\n"
<< "\nD computed:\n" << tensor_D_computed.host_view() << "\n";
<< "\nBroadcast:\n" << tensor_Broadcast.host_view() << "\n"
<< "\nY reference:\n" << tensor_Y_reference.host_view() << "\n"
<< "\nT reference:\n" << tensor_T_reference.host_view() << "\n"
<< "\nT computed:\n" << tensor_T_computed.host_view() << "\n"
<< "\nZ reference:\n" << tensor_Z_reference.host_view() << "\n"
<< "\nZ computed:\n" << tensor_Z_computed.host_view() << "\n";
}
return passed;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////////////
@ -582,8 +610,7 @@ bool TestAllConv2dWithBroadcast(
);
cutlass::conv::SplitKMode split_k_modes [] = {
cutlass::conv::SplitKMode::kSerial,
cutlass::conv::SplitKMode::kParallel,
cutlass::conv::SplitKMode::kSerial
};
int split_k_slices[] = {

View File

@ -69,11 +69,11 @@ struct GemmWithBroadcastReferenceOp {
void operator()(ElementZ &Z, ElementT &T, ElementCompute gemm, ElementCompute bias) {
ElementCompute z_full = binary_op(gemm, bias);
Z = ElementZ(z_full);
ElementCompute t_full = elementwise_op(z_full);
ElementCompute t_full = binary_op(gemm, bias);
T = ElementT(t_full);
ElementCompute z_full = elementwise_op(t_full);
Z = ElementZ(z_full);
}
};
@ -83,9 +83,9 @@ struct GemmWithBroadcastReferenceOp {
//
// Y = GEMM(AB, C)
//
// Z[i, j] = ReductionOp(Y[i, j], Broadcast[i])
// T[i, j] = ReductionOp(Y[i, j], Broadcast[i])
//
// T[i, j] = Elementwise(Z[i, j])
// Z[i, j] = Elementwise(T[i, j])
//
template <
@ -101,7 +101,6 @@ struct TestbedGemmWithBroadcast {
using ElementZ = typename OutputOp::ElementZ;
using ElementT = typename OutputOp::ElementT;
/// Initialization
cutlass::Distribution::Kind init_A;
cutlass::Distribution::Kind init_B;
@ -343,7 +342,6 @@ struct TestbedGemmWithBroadcast {
ReferenceOp reference_op;
// compute tensor Z and tensor T
for (int m = 0; m < problem_size.m(); ++m) {
for (int n = 0; n < problem_size.n(); ++n) {

View File

@ -24,3 +24,8 @@ cutlass_test_unit_add_executable(
cutlass_test_unit_util
tensor_reduce.cu
)
cutlass_test_unit_add_executable(
cutlass_test_unit_levels
cutlass_test_levels.cu
)

View File

@ -0,0 +1,71 @@
/***************************************************************************************************
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include <complex>
#include "../common/cutlass_unit_test.h"
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(SM75_CUTLASS_TEST, level_not_specified) {
EXPECT_TRUE(true);
}
TEST(SM80_CUTLASS_TEST, level_not_specified) {
EXPECT_TRUE(true);
}
CUTLASS_TEST_L0(SM75_CUTLASS_TEST, level0, {
EXPECT_TRUE(true);
})
CUTLASS_TEST_L1(SM75_CUTLASS_TEST, level1, {
EXPECT_TRUE(true);
})
CUTLASS_TEST_L2(SM75_CUTLASS_TEST, level2, {
EXPECT_TRUE(true);
})
CUTLASS_TEST_L0(SM80_CUTLASS_TEST, level0, {
EXPECT_TRUE(true);
})
CUTLASS_TEST_L1(SM80_CUTLASS_TEST, level1, {
EXPECT_TRUE(true);
})
CUTLASS_TEST_L2(SM80_CUTLASS_TEST, level2, {
EXPECT_TRUE(true);
})
////////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -1056,7 +1056,6 @@ def GenerateSM75_TensorOp_8832_TN(manifest, args):
min_cc = 75
max_cc = 1024
alignment_constraints = [32,]
for math_inst in math_instructions:
@ -1136,7 +1135,6 @@ def GenerateSM75_TensorOp_8832_Interleaved(manifest, args):
min_cc = 75
max_cc = 1024
alignment_constraints = [32,]
for math_inst in math_instructions:
@ -1907,7 +1905,6 @@ def GenerateSM80_TensorOp_16864_Interleaved(manifest, args):
min_cc = 80
max_cc = 1024
alignment_constraints = [32,]
for math_inst in math_instructions:

View File

@ -102,6 +102,12 @@ bool get_cublas_datatype(cublasDataType_t &data_type, library::NumericTypeID ele
data_type = CUDA_R_16F;
return true;
case library::NumericTypeID::kBF16:
break;
case library::NumericTypeID::kTF32:
break;
case library::NumericTypeID::kF32:
data_type = CUDA_R_32F;
return true;