Removed trivial copy constructors on parameter classes to enable device-side launch of CUTLASS kernels (#366)

* Removed trivial copy constructors on parameter classes to enable device-side launch of CUTLASS kernels (a short sketch of why trivial copyability matters follows the changed-files summary below)

* Added SFINAE to the `TensorRef(NonConstTensorRef const&)` constructor to avoid making it a copy-constructor for device code (a standalone sketch of the trick follows the last diff hunk)

* std => platform (use the `platform::` equivalents of the std type traits so the SFINAE check also compiles for device code)

* fix affine2

* really fix affine2

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
HouQiming 2022-03-01 10:34:02 +08:00 committed by GitHub
parent e96f00586c
commit 96a11a1ef3
4 changed files with 13 additions and 13 deletions
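
Background for the first commit-message bullet, as a minimal standalone sketch (the `ParamsWithUserCopy` and `ParamsTrivial` types below are made up for illustration and are not CUTLASS code): a user-provided copy constructor, even one that only copies each member, makes a type non-trivially-copyable, and kernel parameter objects generally have to stay trivially copyable to be passed by value when a kernel is launched from device code. Deleting the hand-written constructor, as this commit does for `Coord` below, leaves the implicitly-declared copy constructor, which copies the same bytes but keeps the type trivial.

#include <type_traits>

// Hypothetical parameter struct with a hand-written, member-wise copy
// constructor, mirroring the Coord copy constructor removed in this commit.
struct ParamsWithUserCopy {
  int idx[4];
  ParamsWithUserCopy() = default;
  ParamsWithUserCopy(ParamsWithUserCopy const &other) {  // user-provided
    for (int i = 0; i < 4; ++i) { idx[i] = other.idx[i]; }
  }
};

// Same data, relying on the implicitly-declared copy constructor instead.
struct ParamsTrivial {
  int idx[4];
};

// The user-provided copy constructor makes the type non-trivially copyable,
// even though it copies exactly the same bytes as the implicit one.
static_assert(!std::is_trivially_copyable<ParamsWithUserCopy>::value, "");
static_assert( std::is_trivially_copyable<ParamsTrivial>::value, "");

int main() { return 0; }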

View File

@@ -94,7 +94,7 @@ using SmArch = cutlass::arch::Sm80;
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 16>; // Threadblock tile shape
// This code section describes tile size a warp will compute
-using WarpShape = cutlass::gemm::GemmShape<64, 32, 16>; // Warp tile shape
+using WarpShape = cutlass::gemm::GemmShape<32, 64, 16>; // Warp tile shape
// This code section describes the size of MMA op
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; // TensorCore instruction shape
@@ -110,7 +110,8 @@ using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
ElementOutput, // Data type of output matrix.
1, // The number of elements per memory
// access has. It has to be 1 for
// affine2.
// affine2.
ElementAccumulator,
ElementComputeEpilogue>;
using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmUniversal<
@@ -226,7 +227,7 @@ int run() {
tensor_b.device_ref().data(), // <- reference to matrix B on device
tensor_c.device_ref().data(), // <- reference to matrix C on device
tensor_d.device_ref().data(), // <- reference to matrix D on device
-tensor_a.layout().capacity(problem_size.mn()),
+tensor_a.layout().capacity(problem_size.mk()),
tensor_b.layout().capacity(problem_size.kn()),
tensor_c.layout().capacity(problem_size.mn()),
tensor_d.layout().capacity(problem_size.mn()),
@@ -302,7 +303,7 @@ int run() {
CUTLASS_CHECK(status);
-return 0;
+return (pass ? 0 : -1);
}
int main(int argc, char const **args) {

View File

@@ -94,14 +94,6 @@ public:
}
}
-/// Copy constructor
-CUTLASS_HOST_DEVICE
-Coord(Coord<kRank, Index, LongIndex> const &coord) {
-  for (int i = 0; i < kRank; ++i) {
-    idx[i] = coord[i];
-  }
-}
/// Returns a slice of the Coord which may be larger or smaller in rank
/// than this.
template <int Slice>

View File

@@ -162,6 +162,10 @@
ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D),
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D),
stride_a(stride_a), stride_b(stride_b), stride_c(stride_c), stride_d(stride_d) {
+lda = 0;
+ldb = 0;
+ldc = 0;
+ldd = 0;
CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size);
}

View File

@@ -219,9 +219,12 @@ class TensorRef {
}
/// Converting constructor from TensorRef to non-constant data.
+template<typename _Magic = int>
CUTLASS_HOST_DEVICE
TensorRef(
-NonConstTensorRef const &ref ///< TensorRef to non-const data
+NonConstTensorRef const &ref, ///< TensorRef to non-const data
+/// SFINAE trick to avoid creating a copy-constructor when Element_ is already non-const
+_Magic magic = (typename platform::enable_if< ! platform::is_same<NonConstTensorRef, TensorRef<Element_, Layout_> >::value, _Magic>::type)0
):
ptr_(ref.data()), layout_(ref.layout()) { }
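
To illustrate the second and third commit-message bullets, here is a standalone sketch (the `Ref` class is hypothetical, not the CUTLASS `TensorRef`, and it uses the standard `<type_traits>` where the commit uses the `platform::` equivalents so the same check compiles in device code): because the converting constructor is a template, it can never be the copy constructor, and the SFINAE'd default argument keeps the single-argument form available only when `Element` is const, i.e. when `NonConstRef` is a genuinely different type. Copies of `Ref<int>` therefore go through the implicitly-declared, trivial copy constructor.

#include <type_traits>

template <typename Element>
class Ref {
public:
  using NonConstRef = Ref<typename std::remove_const<Element>::type>;

  Ref() = default;

  /// Converting constructor from a reference to non-const data. The defaulted
  /// _Magic argument is only well-formed when NonConstRef differs from
  /// Ref<Element>, so this constructor never stands in for the copy constructor.
  template <typename _Magic = int>
  Ref(NonConstRef const & /*ref*/,
      _Magic /*magic*/ = (typename std::enable_if<
          !std::is_same<NonConstRef, Ref<Element> >::value, _Magic>::type)0) {}
};

// The implicit copy constructor is untouched, so Ref<int> stays trivially
// copyable, which is the property the first bullet is after.
static_assert(std::is_trivially_copyable<Ref<int> >::value, "");

int main() {
  Ref<int> a;
  Ref<int> b(a);        // picks the implicit (trivial) copy constructor
  Ref<int const> c(a);  // picks the converting constructor
  (void)b; (void)c;
  return 0;
}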