diff --git a/include/cutlass/aligned_buffer.h b/include/cutlass/aligned_buffer.h
index f869d388..751e72a4 100644
--- a/include/cutlass/aligned_buffer.h
+++ b/include/cutlass/aligned_buffer.h
@@ -80,9 +80,9 @@ public:
   typedef value_type *pointer;
   typedef value_type const * const_pointer;
 
-  using Array = Array<T, N>;
-  using reference = typename Array::reference;
-  using const_reference = typename Array::const_reference;
+  using ArrayType = Array<T, N>;
+  using reference = typename ArrayType::reference;
+  using const_reference = typename ArrayType::const_reference;
 
 public:
 
diff --git a/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h
index a825f4ce..ccb5cb07 100644
--- a/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h
+++ b/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h
@@ -633,7 +633,7 @@ public:
 
     CUTLASS_PRAGMA_UNROLL
    for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
-      clear_mask(v_idx, filter_k_ >= problem_size.K);
+      clear_mask(v_idx, filter_k_ + v_idx * AccessType::kElements >= problem_size.K);
     }
 
     set_iteration_index(0);
diff --git a/include/cutlass/epilogue/thread/conversion_op.h b/include/cutlass/epilogue/thread/conversion_op.h
index ba3738ed..438047bf 100644
--- a/include/cutlass/epilogue/thread/conversion_op.h
+++ b/include/cutlass/epilogue/thread/conversion_op.h
@@ -91,11 +91,11 @@ public:
   Convert(Params const &params = Params()) {
 
   }
-  
+
   /// Functionally required for serial reduction in the epilogue
   CUTLASS_HOST_DEVICE
   void set_k_partition(int k_partition, int k_partition_count) {
-  
+
   }
 
   /// Returns true if source is needed based on state of runtime arguments
diff --git a/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h b/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h
index d70a1989..f7e5c2fd 100644
--- a/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h
+++ b/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h
@@ -671,7 +671,7 @@ public:
         state_[1] = 0;
         ++state_[2];
         byte_pointer_ += params_.advance_cluster;
-        store_byte_pointer_ += params_.advance_group;
+        store_byte_pointer_ += params_.advance_cluster;
 
         thread_start_row_ += ThreadMap::Count::kGroup * ThreadMap::Shape::kGroup
           * ThreadMap::Count::kRow * ThreadMap::Shape::kRow;
@@ -679,7 +679,7 @@
         if (state_[2] == ThreadMap::Count::kCluster) {
           state_[2] = 0;
           byte_pointer_ += params_.advance_tile;
-          store_byte_pointer_ += params_.advance_group;
+          store_byte_pointer_ += params_.advance_tile;
         }
       }
     }
diff --git a/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu b/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu
index b7e9b94b..29649282 100644
--- a/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu
+++ b/test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu
@@ -240,4 +240,59 @@ TEST(SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_ten
 
 ////////////////////////////////////////////////////////////////////////////////
 
+TEST(SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_align2,
+  128x64_64x3_64x32x64) {
+
+  /// Conv operation element types for the Gemm equivalent (ImplicitGemm)
+  using ElementA = cutlass::half_t;
+  using ElementB = cutlass::half_t;
+  using ElementC = cutlass::half_t;
+  using ElementAccumulator = cutlass::half_t;
+  using ElementCompute = cutlass::half_t;
+
+  /// Device-level Conv2d instance
+  using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad<
+    ElementA, cutlass::layout::TensorNHWC,
+    ElementB, cutlass::layout::TensorNHWC,
+    ElementC, cutlass::layout::TensorNHWC,
+    ElementAccumulator,
+    cutlass::arch::OpClassTensorOp,
+    cutlass::arch::Sm80,
+    cutlass::gemm::GemmShape<128, 64, 32>,
+    cutlass::gemm::GemmShape<64, 32, 32>,
+    cutlass::gemm::GemmShape<16, 8, 16>,
+    cutlass::epilogue::thread::LinearCombination<
+      ElementC,
+      2,
+      ElementAccumulator,
+      ElementCompute
+    >,
+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
+    3,
+    cutlass::arch::OpMultiplyAdd,
+    cutlass::conv::IteratorAlgorithm::kOptimized,
+    cutlass::conv::StrideSupport::kUnity,
+    2,
+    2
+  >::Kernel;
+
+  using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution<Conv2dDgradKernel>;
+
+  test::conv::device::Conv2dProblemVector problem_size_list;
+
+  // run specific problem size in the unit test first
+  problem_size_list.push_back(cutlass::conv::Conv2dProblemSize(
+    {35, 100, 50, 64},    // input size (NHWC)
+    {22, 1, 1, 64},       // filter size (KRSC)
+    {0, 0, 0, 0},         // padding (pad_h, _, pad_w, _)
+    {1, 1},               // stride (stride_h, stride_w)
+    {1, 1}                // dilation (dilation_h, dilation_w)
+  ));
+
+  /// Run all unit test sizes with device-level Conv2d instance
+  EXPECT_TRUE(test::conv::device::TestAllConv2d<Conv2dDgrad>(problem_size_list));
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
 #endif  // CUTLASS_ARCH_MMA_SM80_SUPPORTED