Removed trivial copy constructors on parameter classes to enable device-side launch of CUTLASS kernels (#366)
* Removed trivial copy constructors on parameter classes to enable device-side launch of CUTLASS kernels * Added SFINAE to the `TensorRef(NonConstTensorRef const&)` constructor to avoid making it a copy-constructor for device code * std => platform * fix affine2 * really fix affine2 Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
parent
e96f00586c
commit
96a11a1ef3
@ -94,7 +94,7 @@ using SmArch = cutlass::arch::Sm80;
|
||||
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 16>; // Threadblock tile shape
|
||||
|
||||
// This code section describes tile size a warp will compute
|
||||
using WarpShape = cutlass::gemm::GemmShape<64, 32, 16>; // Warp tile shape
|
||||
using WarpShape = cutlass::gemm::GemmShape<32, 64, 16>; // Warp tile shape
|
||||
|
||||
// This code section describes the size of MMA op
|
||||
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; // TensorCore instruction shape
|
||||
@ -110,7 +110,8 @@ using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, // Data type of output matrix.
|
||||
1, // The number of elements per memory
|
||||
// access has. It has to be 1 for
|
||||
// affine2.
|
||||
// affine2.
|
||||
ElementAccumulator,
|
||||
ElementComputeEpilogue>;
|
||||
|
||||
using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmUniversal<
|
||||
@ -226,7 +227,7 @@ int run() {
|
||||
tensor_b.device_ref().data(), // <- reference to matrix B on device
|
||||
tensor_c.device_ref().data(), // <- reference to matrix C on device
|
||||
tensor_d.device_ref().data(), // <- reference to matrix D on device
|
||||
tensor_a.layout().capacity(problem_size.mn()),
|
||||
tensor_a.layout().capacity(problem_size.mk()),
|
||||
tensor_b.layout().capacity(problem_size.kn()),
|
||||
tensor_c.layout().capacity(problem_size.mn()),
|
||||
tensor_d.layout().capacity(problem_size.mn()),
|
||||
@ -302,7 +303,7 @@ int run() {
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
return 0;
|
||||
return (pass ? 0 : -1);
|
||||
}
|
||||
|
||||
int main(int argc, char const **args) {
|
||||
|
@ -94,14 +94,6 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
/// Copy constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord(Coord<kRank, Index, LongIndex> const &coord) {
|
||||
for (int i = 0; i < kRank; ++i) {
|
||||
idx[i] = coord[i];
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a slice of the Coord which may be larger or smaller in rank
|
||||
/// than this.
|
||||
template <int Slice>
|
||||
|
@ -162,6 +162,10 @@ public:
|
||||
ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D),
|
||||
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D),
|
||||
stride_a(stride_a), stride_b(stride_b), stride_c(stride_c), stride_d(stride_d) {
|
||||
lda = 0;
|
||||
ldb = 0;
|
||||
ldc = 0;
|
||||
ldd = 0;
|
||||
|
||||
CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size);
|
||||
}
|
||||
|
@ -219,9 +219,12 @@ class TensorRef {
|
||||
}
|
||||
|
||||
/// Converting constructor from TensorRef to non-constant data.
|
||||
template<typename _Magic = int>
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRef(
|
||||
NonConstTensorRef const &ref ///< TensorRef to non-const data
|
||||
NonConstTensorRef const &ref, ///< TensorRef to non-const data
|
||||
///SFINAE trick to avoid creating a copy-constructor when Element_ is already non-const
|
||||
_Magic magic = (typename platform::enable_if< ! platform::is_same<NonConstTensorRef, TensorRef<Element_, Layout_> >::value, _Magic>::type)0
|
||||
):
|
||||
ptr_(ref.data()), layout_(ref.layout()) { }
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user