Removed trivial copy constructors on parameter classes to enable device-side launch of CUTLASS kernels (#366)
* Removed trivial copy constructors on parameter classes to enable device-side launch of CUTLASS kernels
* Added SFINAE to the `TensorRef(NonConstTensorRef const&)` constructor to avoid making it a copy-constructor for device code
* std => platform
* fix affine2
* really fix affine2

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
parent
e96f00586c
commit
96a11a1ef3
@ -94,7 +94,7 @@ using SmArch = cutlass::arch::Sm80;
|
|||||||
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 16>; // Threadblock tile shape
|
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 16>; // Threadblock tile shape
|
||||||
|
|
||||||
// This code section describes tile size a warp will compute
|
// This code section describes tile size a warp will compute
|
||||||
using WarpShape = cutlass::gemm::GemmShape<64, 32, 16>; // Warp tile shape
|
using WarpShape = cutlass::gemm::GemmShape<32, 64, 16>; // Warp tile shape
|
||||||
|
|
||||||
// This code section describes the size of MMA op
|
// This code section describes the size of MMA op
|
||||||
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; // TensorCore instruction shape
|
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; // TensorCore instruction shape
|
||||||
@ -111,6 +111,7 @@ using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
|
|||||||
1, // The number of elements per memory
|
1, // The number of elements per memory
|
||||||
// access has. It has to be 1 for
|
// access has. It has to be 1 for
|
||||||
// affine2.
|
// affine2.
|
||||||
|
ElementAccumulator,
|
||||||
ElementComputeEpilogue>;
|
ElementComputeEpilogue>;
|
||||||
|
|
||||||
using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmUniversal<
|
using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmUniversal<
|
||||||
@ -226,7 +227,7 @@ int run() {
|
|||||||
tensor_b.device_ref().data(), // <- reference to matrix B on device
|
tensor_b.device_ref().data(), // <- reference to matrix B on device
|
||||||
tensor_c.device_ref().data(), // <- reference to matrix C on device
|
tensor_c.device_ref().data(), // <- reference to matrix C on device
|
||||||
tensor_d.device_ref().data(), // <- reference to matrix D on device
|
tensor_d.device_ref().data(), // <- reference to matrix D on device
|
||||||
tensor_a.layout().capacity(problem_size.mn()),
|
tensor_a.layout().capacity(problem_size.mk()),
|
||||||
tensor_b.layout().capacity(problem_size.kn()),
|
tensor_b.layout().capacity(problem_size.kn()),
|
||||||
tensor_c.layout().capacity(problem_size.mn()),
|
tensor_c.layout().capacity(problem_size.mn()),
|
||||||
tensor_d.layout().capacity(problem_size.mn()),
|
tensor_d.layout().capacity(problem_size.mn()),
|
||||||
@ -302,7 +303,7 @@ int run() {
|
|||||||
|
|
||||||
CUTLASS_CHECK(status);
|
CUTLASS_CHECK(status);
|
||||||
|
|
||||||
return 0;
|
return (pass ? 0 : -1);
|
||||||
}
|
}
|
||||||
|
|
||||||
int main(int argc, char const **args) {
|
int main(int argc, char const **args) {
|
||||||
|
@ -94,14 +94,6 @@ public:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Copy constructor
|
|
||||||
CUTLASS_HOST_DEVICE
|
|
||||||
Coord(Coord<kRank, Index, LongIndex> const &coord) {
|
|
||||||
for (int i = 0; i < kRank; ++i) {
|
|
||||||
idx[i] = coord[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns a slice of the Coord which may be larger or smaller in rank
|
/// Returns a slice of the Coord which may be larger or smaller in rank
|
||||||
/// than this.
|
/// than this.
|
||||||
template <int Slice>
|
template <int Slice>
|
||||||
|
@ -162,6 +162,10 @@ public:
|
|||||||
ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D),
|
ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D),
|
||||||
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D),
|
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D),
|
||||||
stride_a(stride_a), stride_b(stride_b), stride_c(stride_c), stride_d(stride_d) {
|
stride_a(stride_a), stride_b(stride_b), stride_c(stride_c), stride_d(stride_d) {
|
||||||
|
lda = 0;
|
||||||
|
ldb = 0;
|
||||||
|
ldc = 0;
|
||||||
|
ldd = 0;
|
||||||
|
|
||||||
CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size);
|
CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size);
|
||||||
}
|
}
|
||||||
|
@ -219,9 +219,12 @@ class TensorRef {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Converting constructor from TensorRef to non-constant data.
|
/// Converting constructor from TensorRef to non-constant data.
|
||||||
|
template<typename _Magic = int>
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
TensorRef(
|
TensorRef(
|
||||||
NonConstTensorRef const &ref ///< TensorRef to non-const data
|
NonConstTensorRef const &ref, ///< TensorRef to non-const data
|
||||||
|
///SFINAE trick to avoid creating a copy-constructor when Element_ is already non-const
|
||||||
|
_Magic magic = (typename platform::enable_if< ! platform::is_same<NonConstTensorRef, TensorRef<Element_, Layout_> >::value, _Magic>::type)0
|
||||||
):
|
):
|
||||||
ptr_(ref.data()), layout_(ref.layout()) { }
|
ptr_(ref.data()), layout_(ref.layout()) { }
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user