CUTLASS 2.10 bug fixes and minor updates. (#626)
This commit is contained in:
parent
2cc2c7ba1f
commit
fc9ebc645b
@ -80,9 +80,9 @@ public:
|
||||
typedef value_type *pointer;
|
||||
typedef value_type const * const_pointer;
|
||||
|
||||
using Array = Array<T, N>;
|
||||
using reference = typename Array::reference;
|
||||
using const_reference = typename Array::const_reference;
|
||||
using ArrayType = Array<T, N>;
|
||||
using reference = typename ArrayType::reference;
|
||||
using const_reference = typename ArrayType::const_reference;
|
||||
|
||||
public:
|
||||
|
||||
|
@ -633,7 +633,7 @@ public:
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
|
||||
clear_mask(v_idx, filter_k_ >= problem_size.K);
|
||||
clear_mask(v_idx, filter_k_ + v_idx * AccessType::kElements >= problem_size.K);
|
||||
}
|
||||
|
||||
set_iteration_index(0);
|
||||
|
@ -91,11 +91,11 @@ public:
|
||||
Convert(Params const &params = Params()) {
|
||||
|
||||
}
|
||||
|
||||
|
||||
/// Functionally required for serial reduction in the epilogue
///
/// NOTE(review): no statements appear in the body as shown in this view, so
/// both `k_partition` and `k_partition_count` look unused here — confirm
/// against the full header.
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_k_partition(int k_partition, int k_partition_count) {
|
||||
|
||||
|
||||
}
|
||||
|
||||
/// Returns true if source is needed based on state of runtime arguments
|
||||
|
@ -671,7 +671,7 @@ public:
|
||||
state_[1] = 0;
|
||||
++state_[2];
|
||||
byte_pointer_ += params_.advance_cluster;
|
||||
store_byte_pointer_ += params_.advance_group;
|
||||
store_byte_pointer_ += params_.advance_cluster;
|
||||
|
||||
thread_start_row_ += ThreadMap::Count::kGroup *
|
||||
ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow;
|
||||
@ -679,7 +679,7 @@ public:
|
||||
if (state_[2] == ThreadMap::Count::kCluster) {
|
||||
state_[2] = 0;
|
||||
byte_pointer_ += params_.advance_tile;
|
||||
store_byte_pointer_ += params_.advance_group;
|
||||
store_byte_pointer_ += params_.advance_tile;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -240,4 +240,59 @@ TEST(SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_ten
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Conv2d Dgrad unit test in which every element type (A, B, C, accumulator,
/// compute) is cutlass::half_t, every tensor uses TensorNHWC, and the kernel
/// is instantiated for Sm80 / OpClassTensorOp with
/// IteratorAlgorithm::kOptimized and StrideSupport::kUnity.
///
/// NOTE(review): the test name encodes "128x64_64x3_64x32x64", while the
/// GemmShape arguments below use 32 as their last extent (128x64x32 and
/// 64x32x32) — confirm which one is intended.
TEST(SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_align2,
|
||||
128x64_64x3_64x32x64) {
|
||||
|
||||
/// Conv operation element types for the Gemm equivalent (ImplicitGemm)
|
||||
using ElementA = cutlass::half_t;
|
||||
using ElementB = cutlass::half_t;
|
||||
using ElementC = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
using ElementCompute = cutlass::half_t;
|
||||
|
||||
/// Device-level Conv2d instance
|
||||
using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad<
|
||||
ElementA, cutlass::layout::TensorNHWC,
|
||||
ElementB, cutlass::layout::TensorNHWC,
|
||||
ElementC, cutlass::layout::TensorNHWC,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
// The three GemmShape arguments are presumably the threadblock, warp and
// instruction tile shapes, in that order — TODO confirm against
// DefaultConv2dDgrad.
cutlass::gemm::GemmShape<128, 64, 32>,
|
||||
cutlass::gemm::GemmShape<64, 32, 32>,
|
||||
cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementC,
|
||||
// presumably the number of elements per epilogue access — TODO confirm
2,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
// presumably the pipeline stage count (the "x3" in the test name) — TODO confirm
3,
|
||||
cutlass::arch::OpMultiplyAdd,
|
||||
cutlass::conv::IteratorAlgorithm::kOptimized,
|
||||
cutlass::conv::StrideSupport::kUnity,
|
||||
// presumably the A and B access alignments that the "align2" suffix of the
// test name refers to — TODO confirm
2,
|
||||
2
|
||||
>::Kernel;
|
||||
|
||||
using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution<Conv2dDgradKernel>;
|
||||
|
||||
test::conv::device::Conv2dProblemVector problem_size_list;
|
||||
|
||||
// run specific problem size in the unit test first
// (1x1 filter, unit stride and dilation, zero padding, filter K extent of 22)
|
||||
problem_size_list.push_back(cutlass::conv::Conv2dProblemSize(
|
||||
{35, 100, 50, 64}, // input size (NHWC)
|
||||
{22, 1, 1, 64}, // filter size (KRSC)
|
||||
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
||||
{1, 1}, // stride (stride_h, stride_w)
|
||||
{1, 1} // dilation (dilation_h, dilation_w)
|
||||
));
|
||||
|
||||
/// Run all unit test sizes with device-level Conv2d instance
// The explicit problem pushed above is handed to TestAllConv2d as its argument.
|
||||
EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dDgrad>(problem_size_list));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED
|
||||
|
Loading…
Reference in New Issue
Block a user