CUTLASS 2.10 bug fixes and minor updates. (#626)

This commit is contained in:
Andrew Kerr 2022-09-15 16:20:33 -04:00 committed by GitHub
parent 2cc2c7ba1f
commit fc9ebc645b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 63 additions and 8 deletions

View File

@@ -80,9 +80,9 @@ public:
typedef value_type *pointer;
typedef value_type const * const_pointer;
using Array = Array<T, N>;
using reference = typename Array::reference;
using const_reference = typename Array::const_reference;
using ArrayType = Array<T, N>;
using reference = typename ArrayType::reference;
using const_reference = typename ArrayType::const_reference;
public:

View File

@@ -633,7 +633,7 @@ public:
CUTLASS_PRAGMA_UNROLL
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
clear_mask(v_idx, filter_k_ >= problem_size.K);
clear_mask(v_idx, filter_k_ + v_idx * AccessType::kElements >= problem_size.K);
}
set_iteration_index(0);

View File

@@ -671,7 +671,7 @@ public:
state_[1] = 0;
++state_[2];
byte_pointer_ += params_.advance_cluster;
store_byte_pointer_ += params_.advance_group;
store_byte_pointer_ += params_.advance_cluster;
thread_start_row_ += ThreadMap::Count::kGroup *
ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow;
@@ -679,7 +679,7 @@ public:
if (state_[2] == ThreadMap::Count::kCluster) {
state_[2] = 0;
byte_pointer_ += params_.advance_tile;
store_byte_pointer_ += params_.advance_group;
store_byte_pointer_ += params_.advance_tile;
}
}
}

View File

@@ -240,4 +240,59 @@ TEST(SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_ten
////////////////////////////////////////////////////////////////////////////////
// Regression test: optimized-iterator dgrad ImplicitGemm, fp16 NHWC in/out,
// tensor-op fp16 accumulation, with access alignment 2 (the trailing two
// template arguments below) instead of the full vector width.
TEST(SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_align2,
128x64_64x3_64x32x64) {
/// Conv operation element types for the Gemm equivalent (ImplicitGemm)
using ElementA = cutlass::half_t;
using ElementB = cutlass::half_t;
using ElementC = cutlass::half_t;
using ElementAccumulator = cutlass::half_t;
using ElementCompute = cutlass::half_t;
/// Device-level Conv2d instance
// NOTE(review): the three GemmShape arguments are presumably the
// threadblock, warp, and instruction (MMA) tile shapes, and the trailing
// `3` the pipeline-stage count — confirm against DefaultConv2dDgrad's
// parameter list; they are not visible in this file.
using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad<
ElementA, cutlass::layout::TensorNHWC,
ElementB, cutlass::layout::TensorNHWC,
ElementC, cutlass::layout::TensorNHWC,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 64, 32>,
cutlass::gemm::GemmShape<64, 32, 32>,
cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementC,
2, // epilogue vector length matches the reduced alignment of 2
ElementAccumulator,
ElementCompute
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
3,
cutlass::arch::OpMultiplyAdd,
cutlass::conv::IteratorAlgorithm::kOptimized,
cutlass::conv::StrideSupport::kUnity,
2, // alignment of operand A (elements per access) — the case under test
2 // alignment of operand B
>::Kernel;
using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution<Conv2dDgradKernel>;
test::conv::device::Conv2dProblemVector problem_size_list;
// run specific problem size in the unit test first
// C = 64 and K = 22: K is even (satisfies alignment 2) but not a multiple
// of a wider vector width, exercising the partial-vector predication path.
problem_size_list.push_back(cutlass::conv::Conv2dProblemSize(
{35, 100, 50, 64}, // input size (NHWC)
{22, 1, 1, 64}, // filter size (KRSC)
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
{1, 1}, // stride (stride_h, stride_w)
{1, 1} // dilation (dilation_h, dilation_w)
));
/// Run all unit test sizes with device-level Conv2d instance
EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dDgrad>(problem_size_list));
}
////////////////////////////////////////////////////////////////////////////////
#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED