add verification of the reduction tensor (#489)

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
Haicheng Wu 2022-05-06 13:24:51 -04:00 committed by GitHub
parent ddd8f9cf41
commit 6023038bae


@@ -37,6 +37,7 @@
 #include "cutlass/cutlass.h"
 #include "cutlass/conv/device/implicit_gemm_convolution.h"
+#include "cutlass/reduction/device/tensor_reduce.h"
 #include "cutlass/reduction/device/reduce_split_k.h"
 #include "cutlass/reduction/thread/reduction_operators.h"
@@ -88,8 +89,9 @@ public:
   cutlass::HostTensor<ElementB, LayoutB> tensor_B;
   cutlass::HostTensor<ElementC, LayoutC> tensor_C;
-  cutlass::HostTensor<ElementAccumulator, cutlass::layout::RowMajor> tensor_Reduction;
+  cutlass::HostTensor<ElementAccumulator, LayoutC> tensor_Reduction;
   cutlass::HostTensor<ElementT, cutlass::layout::RowMajor> tensor_Tensor;
+  cutlass::HostTensor<ElementAccumulator, LayoutC> tensor_Final_Reduction;
   cutlass::HostTensor<ElementC, LayoutC> tensor_D_computed;
   cutlass::HostTensor<ElementC, LayoutC> tensor_D_reference;
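
Note: the testbed relies on cutlass::HostTensor, which keeps mirrored host and device allocations for one logical tensor; that is what lets the checks below copy kernel output back for host-side comparison. A minimal sketch of that pairing, not part of this diff (the fp32 NHWC shape is an assumption for illustration):

// Sketch: one logical tensor with mirrored host/device storage.
#include "cutlass/layout/tensor.h"
#include "cutlass/util/host_tensor.h"

int main() {
  cutlass::HostTensor<float, cutlass::layout::TensorNHWC> t({1, 2, 2, 4});

  t.at({0, 0, 0, 0}) = 1.0f;   // write on the host...
  t.sync_device();             // ...and mirror the data to device memory

  // Kernels consume t.device_ref(); t.sync_host() copies results back
  // so host-side checks such as TensorEquals can inspect them.
  return 0;
}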
@@ -115,23 +117,8 @@ public:
     if (dist_kind == cutlass::Distribution::Uniform) {
 
-      int scope;
-      int bits = cutlass::sizeof_bits<Element>::value;
+      int scope = 2;
 
-      if (bits <= 8) {
-        scope = 2;
-      }
-      else if (bits == 16) {
-        if (cutlass::sizeof_bits<ElementAccumulator>::value <= 16) {
-          scope = 3;
-        }
-        else {
-          scope = 5;
-        }
-      }
-      else {
-        scope = 8;
-      }
-
       cutlass::reference::host::TensorFillRandomUniform(
         view, seed, scope, -scope, 0);
     }
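
Note: the hunk above replaces the per-type scope selection with a fixed scope of 2, so every element type is filled from the same small range, presumably to keep the new reduction over all of NPQ exact in the accumulator type. A hedged sketch of the fill call in isolation (element type, shape, and seed are assumptions):

// Sketch: uniform random fill in [-scope, scope], as the testbed does.
#include "cutlass/layout/tensor.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/tensor_fill.h"

int main() {
  cutlass::HostTensor<float, cutlass::layout::TensorNHWC> tensor({1, 4, 4, 8});

  int scope = 2;

  // The trailing 0 truncates all fractional bits, so only whole-number
  // values are generated, which keeps comparisons deterministic.
  cutlass::reference::host::TensorFillRandomUniform(
    tensor.host_view(), /*seed=*/2022, scope, -scope, 0);

  return 0;
}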
@@ -159,8 +146,17 @@ public:
     tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
 
     tensor_Reduction.resize({
-      (problem_size.N * problem_size.P * problem_size.Q),
-      (problem_size.K - 1 + Conv2d::ThreadblockShape::kN) / Conv2d::ThreadblockShape::kN
+      1,
+      1,
+      (problem_size.N * problem_size.P * problem_size.Q - 1 + Conv2d::ThreadblockShape::kM) / Conv2d::ThreadblockShape::kM,
+      (problem_size.K)
     });
+
+    tensor_Final_Reduction.resize({
+      1,
+      1,
+      1,
+      (problem_size.K)
+    });
 
     tensor_Tensor.resize({(problem_size.N * problem_size.P * problem_size.Q), problem_size.K});
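
Note: the partial reduction tensor now holds one row of K partial sums per threadblock tile along the implicit GEMM M dimension (NPQ). A worked example of the extent arithmetic, with assumed sizes N=1, P=Q=14, K=64, and ThreadblockShape::kM=64:

// Worked example of the tensor_Reduction extent computed above.
int NPQ = 1 * 14 * 14;               // implicit GEMM M dimension = 196
int tiles_m = (NPQ - 1 + 64) / 64;   // ceil(196 / 64) = 4 threadblock tiles
// tensor_Reduction extent: {1, 1, 4, 64} -- 4 rows of 64 partial sums.
// Reducing over dimension 2 collapses those rows into the
// {1, 1, 1, 64} extent of tensor_Final_Reduction.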
@@ -291,52 +287,36 @@ public:
     EXPECT_EQ(result, cudaSuccess) << " device reference error: "
       << cudaGetErrorString(result);
 
+    // Final reduction over the partial reduction tensor
+    using Functor = cutlass::plus<ElementAccumulator>;
+
+    using TensorReduction = cutlass::reduction::device::TensorReduction<
+      ElementAccumulator,
+      ElementAccumulator,
+      LayoutC,
+      Functor,
+      8,
+      ElementAccumulator
+    >;
+
+    TensorReduction reduction(tensor_Reduction.extent(), 2);
+
+    cutlass::DeviceAllocation<uint8_t> reduction_device_workspace(reduction.workspace_size());
+
+    status = reduction.reduce(
+      tensor_Final_Reduction.device_ref(),
+      tensor_Reduction.device_ref(),
+      reduction_device_workspace.get(),
+      ElementAccumulator());
+
+    EXPECT_EQ(status, cutlass::Status::kSuccess);
+    EXPECT_EQ(cudaDeviceSynchronize(), cudaSuccess);
+
+    //
+    // Reference check
+    //
+
     tensor_D_computed.sync_host();
 
-    //
-    // Reference check - support caching results
-    //
-
-    CachedTestKey cached_test_key = CreateCachedConv2dWithReductionTestKey<
-      ElementA, LayoutA,
-      ElementB, LayoutB,
-      ElementC, LayoutC,
-      ElementAccumulator,
-      ElementCompute
-    >(
-      kConvolutionalOperator,
-      problem_size,
-      alpha,
-      beta,
-      tensor_A.host_view(),
-      tensor_B.host_view(),
-      tensor_C.host_view()
-    );
-
-    //
-    // Look for the cached key
-    //
-
-    bool cached_result_loaded = false;
-    CachedTestResult cached_test_result;
-
-    std::string conv2d_result_cache_name =
-      std::string("cached_results_") + CUTLASS_TARGET_NAME + ".txt";
-
-    if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) {
-
-      CachedTestResultListing cached_results(conv2d_result_cache_name);
-
-      auto cached = cached_results.find(cached_test_key);
-
-      cached_result_loaded = cached.first;
-      if (cached_result_loaded) {
-        cached_test_result = cached.second;
-      }
-    }
-
-    if (!cached_result_loaded) {
-
 #if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED
 
     cutlass::reference::device::Conv2d<
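
Note: the TensorReduction block added above is the generic device-side tensor reduction declared in tensor_reduce.h. A self-contained sketch of the same pattern, reducing dimension 2 of an NHWC tensor (the element types, shapes, and vector length here are assumptions for illustration, not the testbed's template parameters):

// Sketch: device-side sum-reduction over dimension 2 of an NHWC tensor.
#include "cutlass/cutlass.h"
#include "cutlass/functional.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/reduction/device/tensor_reduce.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/util/host_tensor.h"

int main() {
  using Layout = cutlass::layout::TensorNHWC;

  using TensorReduction = cutlass::reduction::device::TensorReduction<
    float,                    // destination element
    float,                    // source element
    Layout,
    cutlass::plus<float>,     // reduction functor
    1,                        // vector length (the testbed uses 8)
    float                     // accumulation type
  >;

  cutlass::HostTensor<float, Layout> src({1, 1, 4, 64});
  cutlass::HostTensor<float, Layout> dst({1, 1, 1, 64});
  src.sync_device();

  // Reduce over dimension 2 (the per-threadblock partial rows above).
  TensorReduction reduction(src.extent(), 2);

  cutlass::DeviceAllocation<uint8_t> workspace(reduction.workspace_size());

  cutlass::Status status = reduction.reduce(
    dst.device_ref(), src.device_ref(), workspace.get(), /*identity=*/float());

  return status == cutlass::Status::kSuccess ? 0 : 1;
}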
@@ -384,32 +364,44 @@ public:
 #endif
 
-      if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) {
-
-        cached_test_result.D = TensorHash(tensor_D_reference.host_view());
-
-        CachedTestResultListing cached_results(conv2d_result_cache_name);
-
-        cached_results.append(cached_test_key, cached_test_result);
-        cached_results.write(conv2d_result_cache_name);
-      }
-    } // if (!cached_result_loaded)
-
-    uint32_t tensor_D_hash = TensorHash(tensor_D_computed.host_view());
-
-    if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) {
-      passed = (tensor_D_hash == cached_test_result.D);
-
-      EXPECT_EQ(tensor_D_hash, cached_test_result.D)
-        << "Hash-based comparison failed for key:" << "\n" << cached_test_key << "\n";
-    }
-    else {
-
-      passed = cutlass::reference::host::TensorEquals(
-        tensor_D_computed.host_view(),
-        tensor_D_reference.host_view());
-    }
+    passed = cutlass::reference::host::TensorEquals(
+      tensor_D_computed.host_view(),
+      tensor_D_reference.host_view());
+
+    EXPECT_TRUE(passed);
+
+    //
+    // Reference check on reduction results
+    //
+
+    tensor_Reduction.sync_host();
+    tensor_Final_Reduction.sync_host();
+
+    // compute backwards for reduction results
+    cutlass::HostTensor<ElementAccumulator, LayoutC> reference_Reduction;
+    reference_Reduction.resize({
+      1,
+      1,
+      1,
+      (problem_size.K)
+    });
+
+    for (int k = 0; k < problem_size.K; ++k) {
+      ElementAccumulator reduced_value = ElementAccumulator();
+      for (int n = 0; n < problem_size.N; ++n) {
+        for (int p = 0; p < problem_size.P; ++p) {
+          for (int q = 0; q < problem_size.Q; ++q) {
+            reduced_value += tensor_D_reference.at({n, p, q, k});
+          }
+        }
+      }
+      reference_Reduction.at({0, 0, 0, k}) = reduced_value;
+    }
+
+    passed = cutlass::reference::host::TensorEquals(
+      tensor_Final_Reduction.host_view(),
+      reference_Reduction.host_view()
+    );
 
     EXPECT_TRUE(passed);
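
Note: the loop added above "computes backwards": for each output channel k it sums the reference output D over (n, p, q) on the host and compares against the device-side final reduction. The same check as a standalone helper (a hypothetical function; fp32 NHWC is assumed):

// Sketch: host reference for the per-channel reduction being verified.
#include "cutlass/layout/tensor.h"
#include "cutlass/util/host_tensor.h"

// Sums D[n, p, q, k] over all n, p, q -- the value the kernel's
// reduction tensor is expected to hold for channel k.
float reduce_channel(
    cutlass::HostTensor<float, cutlass::layout::TensorNHWC> &D,
    int N, int P, int Q, int k) {

  float sum = 0.0f;
  for (int n = 0; n < N; ++n) {
    for (int p = 0; p < P; ++p) {
      for (int q = 0; q < Q; ++q) {
        sum += D.at({n, p, q, k});
      }
    }
  }
  return sum;
}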
@@ -458,13 +450,13 @@ public:
         << "\nB:\n" << tensor_B.host_view() << "\n"
         << "\nC:\n" << tensor_C.host_view() << "\n"
         << "\nD reference:\n" << tensor_D_reference.host_view() << "\n"
-        << "\nD computed:\n" << tensor_D_computed.host_view() << "\n";
+        << "\nD computed:\n" << tensor_D_computed.host_view() << "\n"
+        << "\nreduction reference:\n" << reference_Reduction.host_view() << "\n"
+        << "\nreduction computed:\n" << tensor_Reduction.host_view() << "\n";
     }
 
     return passed;
   }
 };
/////////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -603,9 +595,9 @@ bool TestAllConv2dWithReduction(
     {1, 1}    // dilation (dilation_h, dilation_w)
   );
 
+  // Parallel SplitK is not tested.
   cutlass::conv::SplitKMode split_k_modes [] = {
     cutlass::conv::SplitKMode::kSerial,
-    cutlass::conv::SplitKMode::kParallel,
   };
 
   int split_k_slices[] = {