2.9 fixes for nvrtc (#480)

* Use platform::is_same instead of std::is_same

* Don't hide cuComplex include from nvrtc

* Typo fixed

* Remove comment rename
This commit is contained in:
Stepan Tezyunichev 2022-04-29 16:06:52 +03:00 committed by GitHub
parent 21c1fa3849
commit 86ce09aed1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 36 additions and 35 deletions

View File

@ -30,11 +30,12 @@
**************************************************************************************************/
#pragma once
#include <cuComplex.h>
#if defined(__CUDACC_RTC__)
#include <cuda/std/cstdint>
#else
#include <cstdint>
#include <cuComplex.h>
#endif
#include "cutlass/cutlass.h"
@ -435,10 +436,10 @@ CUTLASS_HOST_DEVICE complex<T> conj(complex<T> const &z) {
/// Indentity transform for non-complex types
template <typename T>
CUTLASS_HOST_DEVICE T conj(T const &z) {
static_assert( !std::is_same<T, cuComplex>::value &&
!std::is_same<T, cuDoubleComplex>::value &&
!std::is_same<T, cutlass::complex<double>>::value &&
!std::is_same<T, cutlass::complex<float>>::value, "May not be a complex data type");
static_assert( !platform::is_same<T, cuComplex>::value &&
!platform::is_same<T, cuDoubleComplex>::value &&
!platform::is_same<T, cutlass::complex<double>>::value &&
!platform::is_same<T, cutlass::complex<float>>::value, "May not be a complex data type");
return z;
}

View File

@ -121,7 +121,7 @@ struct ImplicitGemmConvolution {
// Conv2d row-major matrix C (KxRSC)
// Conv3d row-major matrix C (KxTRSC)
static int const kWgradCStrideIdx =
cutlass::platform::is_same<LayoutC, cutlass::layout::TensorNHWC>::value ? 2 : 3;
platform::is_same<LayoutC, cutlass::layout::TensorNHWC>::value ? 2 : 3;
/// This chooses the appropriate stride element of the C tensor.
static int const kTensorCStrideIdx =

View File

@ -123,7 +123,7 @@ struct ImplicitGemmConvolutionFusion {
// Conv2d row-major matrix C (KxRSC)
// Conv3d row-major matrix C (KxTRSC)
static int const kWgradCStrideIdx =
cutlass::platform::is_same<LayoutC, cutlass::layout::TensorNHWC>::value ? 2 : 3;
platform::is_same<LayoutC, cutlass::layout::TensorNHWC>::value ? 2 : 3;
/// This chooses the appropriate stride element of the C tensor.
static int const kTensorCStrideIdx =

View File

@ -121,20 +121,20 @@ struct ImplicitGemmConvolutionStridedDgrad {
// Conv2d row-major matrix C (KxRSC)
// Conv3d row-major matrix C (KxTRSC)
static int const kWgradCStrideIdx =
cutlass::platform::is_same<LayoutC, cutlass::layout::TensorNHWC>::value ? 2 : 3;
platform::is_same<LayoutC, cutlass::layout::TensorNHWC>::value ? 2 : 3;
/// This chooses the appropriate stride element of the C tensor.
static int const kTensorCStrideIdx =
(kConvolutionalOperator == conv::Operator::kWgrad ? kWgradCStrideIdx : 0);
// Strided dgrad uses a specialized threadblock swizzle for functionality and performance
static_assert((std::is_same<ThreadblockSwizzle,
static_assert((platform::is_same<ThreadblockSwizzle,
threadblock::StridedDgradHorizontalThreadblockSwizzle>::value) ||
(std::is_same<ThreadblockSwizzle,
(platform::is_same<ThreadblockSwizzle,
threadblock::StridedDgradIdentityThreadblockSwizzle<1>>::value) ||
(std::is_same<ThreadblockSwizzle,
(platform::is_same<ThreadblockSwizzle,
threadblock::StridedDgradIdentityThreadblockSwizzle<4>>::value) ||
(std::is_same<ThreadblockSwizzle,
(platform::is_same<ThreadblockSwizzle,
threadblock::StridedDgradIdentityThreadblockSwizzle<8>>::value),
"Needs ThreadblockSwizzle type specialized for strided dgrad");

View File

@ -121,7 +121,7 @@ struct ImplicitGemmConvolutionWithFusedEpilogue {
// Conv2d row-major matrix C (KxRSC)
// Conv3d row-major matrix C (KxTRSC)
static int const kWgradCStrideIdx =
cutlass::platform::is_same<LayoutC, cutlass::layout::TensorNHWC>::value ? 2 : 3;
platform::is_same<LayoutC, cutlass::layout::TensorNHWC>::value ? 2 : 3;
/// This chooses the appropriate stride element of the C tensor.
static int const kTensorCStrideIdx =

View File

@ -215,10 +215,10 @@ struct DefaultIteratorsTensorOp<
InstructionShape,
ThreadMap> {
static_assert(cutlass::platform::is_same<ElementOutput, cutlass::int4b_t>::value ||
cutlass::platform::is_same<ElementOutput, cutlass::uint4b_t>::value ||
cutlass::platform::is_same<ElementOutput, int8_t>::value ||
cutlass::platform::is_same<ElementOutput, uint8_t>::value,
static_assert(platform::is_same<ElementOutput, cutlass::int4b_t>::value ||
platform::is_same<ElementOutput, cutlass::uint4b_t>::value ||
platform::is_same<ElementOutput, int8_t>::value ||
platform::is_same<ElementOutput, uint8_t>::value,
"ElementOutput needs to be 4 or 8 bit (unsigned) int.");
static_assert((ElementsPerAccess == 16 || ElementsPerAccess == 8),

View File

@ -149,7 +149,7 @@ class Rank2K {
static int const kUpdateRank = 2;
// static asserts for rank 2k update kernel
static_assert(std::is_same<LayoutA, LayoutB>::value,
static_assert(platform::is_same<LayoutA, LayoutB>::value,
"Rank 2K update operator support same layouts for operandA and B");
/// Define the kernel

View File

@ -153,7 +153,7 @@ class Symm {
static BlasMode const kBlasMode = BlasMode_;
// static asserts for symm update kernel
static_assert(std::is_same<LayoutA, LayoutB>::value,
static_assert(platform::is_same<LayoutA, LayoutB>::value,
"SYMM update operator support same layouts for operand A and B");
/// Define the kernel

View File

@ -209,7 +209,7 @@ struct DefaultGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignment
2, ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp,
EpilogueOutputOp::kCount>::Epilogue;
using Epilogue = typename cutlass::platform::conditional<cutlass::platform::is_same<LayoutC, layout::RowMajor>::value,
using Epilogue = typename cutlass::platform::conditional<platform::is_same<LayoutC, layout::RowMajor>::value,
RegularEpilogue,
Affine2Epilogue>::type;
@ -672,7 +672,7 @@ struct DefaultGemm<
kEpilogueElementsPerAccess
>::Epilogue;
using Epilogue = typename cutlass::platform::conditional<cutlass::platform::is_same<LayoutC, layout::RowMajor>::value,
using Epilogue = typename cutlass::platform::conditional<platform::is_same<LayoutC, layout::RowMajor>::value,
RegularEpilogue,
Affine2Epilogue>::type;
@ -780,7 +780,7 @@ struct DefaultGemm<ElementA,
kEpilogueElementsPerAccess
>::Epilogue;
using Epilogue = typename cutlass::platform::conditional<cutlass::platform::is_same<LayoutC, layout::RowMajor>::value,
using Epilogue = typename cutlass::platform::conditional<platform::is_same<LayoutC, layout::RowMajor>::value,
RegularEpilogue,
Affine2Epilogue>::type;

View File

@ -183,7 +183,7 @@ struct DefaultGemmGrouped<
> {
// If true, we must construct a 'transposed-and-exchanged' Mma operator.
static bool const kInternalTranspose = std::is_same<LayoutC, layout::ColumnMajor>::value;
static bool const kInternalTranspose = platform::is_same<LayoutC, layout::ColumnMajor>::value;
using MapArguments = kernel::detail::MapArguments<
ElementA,
@ -307,7 +307,7 @@ struct DefaultGemmGrouped<
> {
// If true, we must construct a 'transposed-and-exchanged' Mma operator.
static bool const kInternalTranspose = std::is_same<LayoutC, layout::ColumnMajor>::value;
static bool const kInternalTranspose = platform::is_same<LayoutC, layout::ColumnMajor>::value;
using MapArguments = kernel::detail::MapArguments<
ElementA,

View File

@ -67,7 +67,7 @@ public:
using LayoutA = layout::ColumnMajor;
using TensorRefA = TensorRef<ElementA, LayoutA>;
static_assert(std::is_same<LayoutA, LayoutA_>::value,
static_assert(platform::is_same<LayoutA, LayoutA_>::value,
"Only supported for column-major A matrix");
using ElementB = ElementB_;

View File

@ -632,8 +632,8 @@ struct DefaultMma<int8_t, LayoutA, kAlignmentA, int8_t, LayoutB, kAlignmentB,
using ElementB = int8_t;
using OperatorClass = arch::OpClassSimt;
static const bool transposeA = cutlass::platform::is_same< LayoutA, layout::ColumnMajor >::value;
static const bool transposeB = cutlass::platform::is_same< LayoutB, layout::RowMajor >::value;
static const bool transposeA = platform::is_same< LayoutA, layout::ColumnMajor >::value;
static const bool transposeB = platform::is_same< LayoutB, layout::RowMajor >::value;
// Define the MmaCore components
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<

View File

@ -54,19 +54,19 @@ class UnaryOp
static FragmentOut execute(FragmentIn &in)
{
static_assert(FragmentIn::kElements == FragmentOut::kElements, "Number of elements must match.");
static_assert(std::is_same<Transform, UnaryTransform::Identity>::value ||
std::is_same<Transform, UnaryTransform::Conjugate>::value,
static_assert(platform::is_same<Transform, UnaryTransform::Identity>::value ||
platform::is_same<Transform, UnaryTransform::Conjugate>::value,
"Unary Operator not supported.");
FragmentOut out;
if( std::is_same<Transform, UnaryTransform::Identity>::value )
if( platform::is_same<Transform, UnaryTransform::Identity>::value )
{
CUTLASS_PRAGMA_UNROLL
for(int i=0; i < FragmentIn::kElements; ++i){
out[i] = static_cast<typename FragmentOut::Element>(in[i]);
}
}
else if( std::is_same<Transform, UnaryTransform::Conjugate>::value )
else if( platform::is_same<Transform, UnaryTransform::Conjugate>::value )
{
for(int i=0; i < FragmentIn::kElements; ++i){
out[i] = conj(static_cast<typename FragmentOut::Element>(in[i]));
@ -83,15 +83,15 @@ class UnaryOp<FragmentIn, FragmentIn, Transform>
CUTLASS_DEVICE
static FragmentIn execute(FragmentIn &in)
{
static_assert(std::is_same<Transform, UnaryTransform::Identity>::value ||
std::is_same<Transform, UnaryTransform::Conjugate>::value,
static_assert(platform::is_same<Transform, UnaryTransform::Identity>::value ||
platform::is_same<Transform, UnaryTransform::Conjugate>::value,
"Unary Operator not supported.");
if( std::is_same<Transform, UnaryTransform::Identity>::value )
if( platform::is_same<Transform, UnaryTransform::Identity>::value )
{
return in;
}
else if( std::is_same<Transform, UnaryTransform::Conjugate>::value )
else if( platform::is_same<Transform, UnaryTransform::Conjugate>::value )
{
for(int i=0; i < FragmentIn::kElements; ++i){
in[i] = conj(in[i]);