2.9 fixes for nvrtc (#480)
* Use platform::is_same instead of std::is_same * Don't hide cuComplex include from nvrtc * Typo fixed * Remove comment rename
This commit is contained in:
parent
21c1fa3849
commit
86ce09aed1
@ -30,11 +30,12 @@
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cuComplex.h>
|
||||
|
||||
#if defined(__CUDACC_RTC__)
|
||||
#include <cuda/std/cstdint>
|
||||
#else
|
||||
#include <cstdint>
|
||||
#include <cuComplex.h>
|
||||
#endif
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
@ -435,10 +436,10 @@ CUTLASS_HOST_DEVICE complex<T> conj(complex<T> const &z) {
|
||||
/// Indentity transform for non-complex types
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE T conj(T const &z) {
|
||||
static_assert( !std::is_same<T, cuComplex>::value &&
|
||||
!std::is_same<T, cuDoubleComplex>::value &&
|
||||
!std::is_same<T, cutlass::complex<double>>::value &&
|
||||
!std::is_same<T, cutlass::complex<float>>::value, "May not be a complex data type");
|
||||
static_assert( !platform::is_same<T, cuComplex>::value &&
|
||||
!platform::is_same<T, cuDoubleComplex>::value &&
|
||||
!platform::is_same<T, cutlass::complex<double>>::value &&
|
||||
!platform::is_same<T, cutlass::complex<float>>::value, "May not be a complex data type");
|
||||
return z;
|
||||
}
|
||||
|
||||
|
@ -121,7 +121,7 @@ struct ImplicitGemmConvolution {
|
||||
// Conv2d row-major matrix C (KxRSC)
|
||||
// Conv3d row-major matrix C (KxTRSC)
|
||||
static int const kWgradCStrideIdx =
|
||||
cutlass::platform::is_same<LayoutC, cutlass::layout::TensorNHWC>::value ? 2 : 3;
|
||||
platform::is_same<LayoutC, cutlass::layout::TensorNHWC>::value ? 2 : 3;
|
||||
|
||||
/// This chooses the appropriate stride element of the C tensor.
|
||||
static int const kTensorCStrideIdx =
|
||||
|
@ -123,7 +123,7 @@ struct ImplicitGemmConvolutionFusion {
|
||||
// Conv2d row-major matrix C (KxRSC)
|
||||
// Conv3d row-major matrix C (KxTRSC)
|
||||
static int const kWgradCStrideIdx =
|
||||
cutlass::platform::is_same<LayoutC, cutlass::layout::TensorNHWC>::value ? 2 : 3;
|
||||
platform::is_same<LayoutC, cutlass::layout::TensorNHWC>::value ? 2 : 3;
|
||||
|
||||
/// This chooses the appropriate stride element of the C tensor.
|
||||
static int const kTensorCStrideIdx =
|
||||
|
@ -121,20 +121,20 @@ struct ImplicitGemmConvolutionStridedDgrad {
|
||||
// Conv2d row-major matrix C (KxRSC)
|
||||
// Conv3d row-major matrix C (KxTRSC)
|
||||
static int const kWgradCStrideIdx =
|
||||
cutlass::platform::is_same<LayoutC, cutlass::layout::TensorNHWC>::value ? 2 : 3;
|
||||
platform::is_same<LayoutC, cutlass::layout::TensorNHWC>::value ? 2 : 3;
|
||||
|
||||
/// This chooses the appropriate stride element of the C tensor.
|
||||
static int const kTensorCStrideIdx =
|
||||
(kConvolutionalOperator == conv::Operator::kWgrad ? kWgradCStrideIdx : 0);
|
||||
|
||||
// Strided dgrad uses a specialized threadblock swizzle for functionality and performance
|
||||
static_assert((std::is_same<ThreadblockSwizzle,
|
||||
static_assert((platform::is_same<ThreadblockSwizzle,
|
||||
threadblock::StridedDgradHorizontalThreadblockSwizzle>::value) ||
|
||||
(std::is_same<ThreadblockSwizzle,
|
||||
(platform::is_same<ThreadblockSwizzle,
|
||||
threadblock::StridedDgradIdentityThreadblockSwizzle<1>>::value) ||
|
||||
(std::is_same<ThreadblockSwizzle,
|
||||
(platform::is_same<ThreadblockSwizzle,
|
||||
threadblock::StridedDgradIdentityThreadblockSwizzle<4>>::value) ||
|
||||
(std::is_same<ThreadblockSwizzle,
|
||||
(platform::is_same<ThreadblockSwizzle,
|
||||
threadblock::StridedDgradIdentityThreadblockSwizzle<8>>::value),
|
||||
"Needs ThreadblockSwizzle type specialized for strided dgrad");
|
||||
|
||||
|
@ -121,7 +121,7 @@ struct ImplicitGemmConvolutionWithFusedEpilogue {
|
||||
// Conv2d row-major matrix C (KxRSC)
|
||||
// Conv3d row-major matrix C (KxTRSC)
|
||||
static int const kWgradCStrideIdx =
|
||||
cutlass::platform::is_same<LayoutC, cutlass::layout::TensorNHWC>::value ? 2 : 3;
|
||||
platform::is_same<LayoutC, cutlass::layout::TensorNHWC>::value ? 2 : 3;
|
||||
|
||||
/// This chooses the appropriate stride element of the C tensor.
|
||||
static int const kTensorCStrideIdx =
|
||||
|
@ -215,10 +215,10 @@ struct DefaultIteratorsTensorOp<
|
||||
InstructionShape,
|
||||
ThreadMap> {
|
||||
|
||||
static_assert(cutlass::platform::is_same<ElementOutput, cutlass::int4b_t>::value ||
|
||||
cutlass::platform::is_same<ElementOutput, cutlass::uint4b_t>::value ||
|
||||
cutlass::platform::is_same<ElementOutput, int8_t>::value ||
|
||||
cutlass::platform::is_same<ElementOutput, uint8_t>::value,
|
||||
static_assert(platform::is_same<ElementOutput, cutlass::int4b_t>::value ||
|
||||
platform::is_same<ElementOutput, cutlass::uint4b_t>::value ||
|
||||
platform::is_same<ElementOutput, int8_t>::value ||
|
||||
platform::is_same<ElementOutput, uint8_t>::value,
|
||||
"ElementOutput needs to be 4 or 8 bit (unsigned) int.");
|
||||
|
||||
static_assert((ElementsPerAccess == 16 || ElementsPerAccess == 8),
|
||||
|
@ -149,7 +149,7 @@ class Rank2K {
|
||||
static int const kUpdateRank = 2;
|
||||
|
||||
// static asserts for rank 2k update kernel
|
||||
static_assert(std::is_same<LayoutA, LayoutB>::value,
|
||||
static_assert(platform::is_same<LayoutA, LayoutB>::value,
|
||||
"Rank 2K update operator support same layouts for operandA and B");
|
||||
|
||||
/// Define the kernel
|
||||
|
@ -153,7 +153,7 @@ class Symm {
|
||||
static BlasMode const kBlasMode = BlasMode_;
|
||||
|
||||
// static asserts for symm update kernel
|
||||
static_assert(std::is_same<LayoutA, LayoutB>::value,
|
||||
static_assert(platform::is_same<LayoutA, LayoutB>::value,
|
||||
"SYMM update operator support same layouts for operand A and B");
|
||||
|
||||
/// Define the kernel
|
||||
|
@ -209,7 +209,7 @@ struct DefaultGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignment
|
||||
2, ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp,
|
||||
EpilogueOutputOp::kCount>::Epilogue;
|
||||
|
||||
using Epilogue = typename cutlass::platform::conditional<cutlass::platform::is_same<LayoutC, layout::RowMajor>::value,
|
||||
using Epilogue = typename cutlass::platform::conditional<platform::is_same<LayoutC, layout::RowMajor>::value,
|
||||
RegularEpilogue,
|
||||
Affine2Epilogue>::type;
|
||||
|
||||
@ -672,7 +672,7 @@ struct DefaultGemm<
|
||||
kEpilogueElementsPerAccess
|
||||
>::Epilogue;
|
||||
|
||||
using Epilogue = typename cutlass::platform::conditional<cutlass::platform::is_same<LayoutC, layout::RowMajor>::value,
|
||||
using Epilogue = typename cutlass::platform::conditional<platform::is_same<LayoutC, layout::RowMajor>::value,
|
||||
RegularEpilogue,
|
||||
Affine2Epilogue>::type;
|
||||
|
||||
@ -780,7 +780,7 @@ struct DefaultGemm<ElementA,
|
||||
kEpilogueElementsPerAccess
|
||||
>::Epilogue;
|
||||
|
||||
using Epilogue = typename cutlass::platform::conditional<cutlass::platform::is_same<LayoutC, layout::RowMajor>::value,
|
||||
using Epilogue = typename cutlass::platform::conditional<platform::is_same<LayoutC, layout::RowMajor>::value,
|
||||
RegularEpilogue,
|
||||
Affine2Epilogue>::type;
|
||||
|
||||
|
@ -183,7 +183,7 @@ struct DefaultGemmGrouped<
|
||||
> {
|
||||
|
||||
// If true, we must construct a 'transposed-and-exchanged' Mma operator.
|
||||
static bool const kInternalTranspose = std::is_same<LayoutC, layout::ColumnMajor>::value;
|
||||
static bool const kInternalTranspose = platform::is_same<LayoutC, layout::ColumnMajor>::value;
|
||||
|
||||
using MapArguments = kernel::detail::MapArguments<
|
||||
ElementA,
|
||||
@ -307,7 +307,7 @@ struct DefaultGemmGrouped<
|
||||
> {
|
||||
|
||||
// If true, we must construct a 'transposed-and-exchanged' Mma operator.
|
||||
static bool const kInternalTranspose = std::is_same<LayoutC, layout::ColumnMajor>::value;
|
||||
static bool const kInternalTranspose = platform::is_same<LayoutC, layout::ColumnMajor>::value;
|
||||
|
||||
using MapArguments = kernel::detail::MapArguments<
|
||||
ElementA,
|
||||
|
@ -67,7 +67,7 @@ public:
|
||||
using LayoutA = layout::ColumnMajor;
|
||||
using TensorRefA = TensorRef<ElementA, LayoutA>;
|
||||
|
||||
static_assert(std::is_same<LayoutA, LayoutA_>::value,
|
||||
static_assert(platform::is_same<LayoutA, LayoutA_>::value,
|
||||
"Only supported for column-major A matrix");
|
||||
|
||||
using ElementB = ElementB_;
|
||||
|
@ -632,8 +632,8 @@ struct DefaultMma<int8_t, LayoutA, kAlignmentA, int8_t, LayoutB, kAlignmentB,
|
||||
using ElementB = int8_t;
|
||||
using OperatorClass = arch::OpClassSimt;
|
||||
|
||||
static const bool transposeA = cutlass::platform::is_same< LayoutA, layout::ColumnMajor >::value;
|
||||
static const bool transposeB = cutlass::platform::is_same< LayoutB, layout::RowMajor >::value;
|
||||
static const bool transposeA = platform::is_same< LayoutA, layout::ColumnMajor >::value;
|
||||
static const bool transposeB = platform::is_same< LayoutB, layout::RowMajor >::value;
|
||||
|
||||
// Define the MmaCore components
|
||||
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
|
@ -54,19 +54,19 @@ class UnaryOp
|
||||
static FragmentOut execute(FragmentIn &in)
|
||||
{
|
||||
static_assert(FragmentIn::kElements == FragmentOut::kElements, "Number of elements must match.");
|
||||
static_assert(std::is_same<Transform, UnaryTransform::Identity>::value ||
|
||||
std::is_same<Transform, UnaryTransform::Conjugate>::value,
|
||||
static_assert(platform::is_same<Transform, UnaryTransform::Identity>::value ||
|
||||
platform::is_same<Transform, UnaryTransform::Conjugate>::value,
|
||||
"Unary Operator not supported.");
|
||||
|
||||
FragmentOut out;
|
||||
if( std::is_same<Transform, UnaryTransform::Identity>::value )
|
||||
if( platform::is_same<Transform, UnaryTransform::Identity>::value )
|
||||
{
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for(int i=0; i < FragmentIn::kElements; ++i){
|
||||
out[i] = static_cast<typename FragmentOut::Element>(in[i]);
|
||||
}
|
||||
}
|
||||
else if( std::is_same<Transform, UnaryTransform::Conjugate>::value )
|
||||
else if( platform::is_same<Transform, UnaryTransform::Conjugate>::value )
|
||||
{
|
||||
for(int i=0; i < FragmentIn::kElements; ++i){
|
||||
out[i] = conj(static_cast<typename FragmentOut::Element>(in[i]));
|
||||
@ -83,15 +83,15 @@ class UnaryOp<FragmentIn, FragmentIn, Transform>
|
||||
CUTLASS_DEVICE
|
||||
static FragmentIn execute(FragmentIn &in)
|
||||
{
|
||||
static_assert(std::is_same<Transform, UnaryTransform::Identity>::value ||
|
||||
std::is_same<Transform, UnaryTransform::Conjugate>::value,
|
||||
static_assert(platform::is_same<Transform, UnaryTransform::Identity>::value ||
|
||||
platform::is_same<Transform, UnaryTransform::Conjugate>::value,
|
||||
"Unary Operator not supported.");
|
||||
|
||||
if( std::is_same<Transform, UnaryTransform::Identity>::value )
|
||||
if( platform::is_same<Transform, UnaryTransform::Identity>::value )
|
||||
{
|
||||
return in;
|
||||
}
|
||||
else if( std::is_same<Transform, UnaryTransform::Conjugate>::value )
|
||||
else if( platform::is_same<Transform, UnaryTransform::Conjugate>::value )
|
||||
{
|
||||
for(int i=0; i < FragmentIn::kElements; ++i){
|
||||
in[i] = conj(in[i]);
|
||||
|
Loading…
Reference in New Issue
Block a user