2.9 fixes for nvrtc (#480)
* Use platform::is_same instead of std::is_same
* Don't hide cuComplex include from nvrtc
* Typo fixed
* Remove comment rename
This commit is contained in:
parent
21c1fa3849
commit
86ce09aed1
@ -30,11 +30,12 @@
|
|||||||
**************************************************************************************************/
|
**************************************************************************************************/
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <cuComplex.h>
|
||||||
|
|
||||||
#if defined(__CUDACC_RTC__)
|
#if defined(__CUDACC_RTC__)
|
||||||
#include <cuda/std/cstdint>
|
#include <cuda/std/cstdint>
|
||||||
#else
|
#else
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <cuComplex.h>
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#include "cutlass/cutlass.h"
|
#include "cutlass/cutlass.h"
|
||||||
@ -435,10 +436,10 @@ CUTLASS_HOST_DEVICE complex<T> conj(complex<T> const &z) {
|
|||||||
/// Identity transform for non-complex types
|
/// Identity transform for non-complex types
|
||||||
template <typename T>
|
template <typename T>
|
||||||
CUTLASS_HOST_DEVICE T conj(T const &z) {
|
CUTLASS_HOST_DEVICE T conj(T const &z) {
|
||||||
static_assert( !std::is_same<T, cuComplex>::value &&
|
static_assert( !platform::is_same<T, cuComplex>::value &&
|
||||||
!std::is_same<T, cuDoubleComplex>::value &&
|
!platform::is_same<T, cuDoubleComplex>::value &&
|
||||||
!std::is_same<T, cutlass::complex<double>>::value &&
|
!platform::is_same<T, cutlass::complex<double>>::value &&
|
||||||
!std::is_same<T, cutlass::complex<float>>::value, "May not be a complex data type");
|
!platform::is_same<T, cutlass::complex<float>>::value, "May not be a complex data type");
|
||||||
return z;
|
return z;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -121,7 +121,7 @@ struct ImplicitGemmConvolution {
|
|||||||
// Conv2d row-major matrix C (KxRSC)
|
// Conv2d row-major matrix C (KxRSC)
|
||||||
// Conv3d row-major matrix C (KxTRSC)
|
// Conv3d row-major matrix C (KxTRSC)
|
||||||
static int const kWgradCStrideIdx =
|
static int const kWgradCStrideIdx =
|
||||||
cutlass::platform::is_same<LayoutC, cutlass::layout::TensorNHWC>::value ? 2 : 3;
|
platform::is_same<LayoutC, cutlass::layout::TensorNHWC>::value ? 2 : 3;
|
||||||
|
|
||||||
/// This chooses the appropriate stride element of the C tensor.
|
/// This chooses the appropriate stride element of the C tensor.
|
||||||
static int const kTensorCStrideIdx =
|
static int const kTensorCStrideIdx =
|
||||||
|
@ -123,7 +123,7 @@ struct ImplicitGemmConvolutionFusion {
|
|||||||
// Conv2d row-major matrix C (KxRSC)
|
// Conv2d row-major matrix C (KxRSC)
|
||||||
// Conv3d row-major matrix C (KxTRSC)
|
// Conv3d row-major matrix C (KxTRSC)
|
||||||
static int const kWgradCStrideIdx =
|
static int const kWgradCStrideIdx =
|
||||||
cutlass::platform::is_same<LayoutC, cutlass::layout::TensorNHWC>::value ? 2 : 3;
|
platform::is_same<LayoutC, cutlass::layout::TensorNHWC>::value ? 2 : 3;
|
||||||
|
|
||||||
/// This chooses the appropriate stride element of the C tensor.
|
/// This chooses the appropriate stride element of the C tensor.
|
||||||
static int const kTensorCStrideIdx =
|
static int const kTensorCStrideIdx =
|
||||||
|
@ -121,20 +121,20 @@ struct ImplicitGemmConvolutionStridedDgrad {
|
|||||||
// Conv2d row-major matrix C (KxRSC)
|
// Conv2d row-major matrix C (KxRSC)
|
||||||
// Conv3d row-major matrix C (KxTRSC)
|
// Conv3d row-major matrix C (KxTRSC)
|
||||||
static int const kWgradCStrideIdx =
|
static int const kWgradCStrideIdx =
|
||||||
cutlass::platform::is_same<LayoutC, cutlass::layout::TensorNHWC>::value ? 2 : 3;
|
platform::is_same<LayoutC, cutlass::layout::TensorNHWC>::value ? 2 : 3;
|
||||||
|
|
||||||
/// This chooses the appropriate stride element of the C tensor.
|
/// This chooses the appropriate stride element of the C tensor.
|
||||||
static int const kTensorCStrideIdx =
|
static int const kTensorCStrideIdx =
|
||||||
(kConvolutionalOperator == conv::Operator::kWgrad ? kWgradCStrideIdx : 0);
|
(kConvolutionalOperator == conv::Operator::kWgrad ? kWgradCStrideIdx : 0);
|
||||||
|
|
||||||
// Strided dgrad uses a specialized threadblock swizzle for functionality and performance
|
// Strided dgrad uses a specialized threadblock swizzle for functionality and performance
|
||||||
static_assert((std::is_same<ThreadblockSwizzle,
|
static_assert((platform::is_same<ThreadblockSwizzle,
|
||||||
threadblock::StridedDgradHorizontalThreadblockSwizzle>::value) ||
|
threadblock::StridedDgradHorizontalThreadblockSwizzle>::value) ||
|
||||||
(std::is_same<ThreadblockSwizzle,
|
(platform::is_same<ThreadblockSwizzle,
|
||||||
threadblock::StridedDgradIdentityThreadblockSwizzle<1>>::value) ||
|
threadblock::StridedDgradIdentityThreadblockSwizzle<1>>::value) ||
|
||||||
(std::is_same<ThreadblockSwizzle,
|
(platform::is_same<ThreadblockSwizzle,
|
||||||
threadblock::StridedDgradIdentityThreadblockSwizzle<4>>::value) ||
|
threadblock::StridedDgradIdentityThreadblockSwizzle<4>>::value) ||
|
||||||
(std::is_same<ThreadblockSwizzle,
|
(platform::is_same<ThreadblockSwizzle,
|
||||||
threadblock::StridedDgradIdentityThreadblockSwizzle<8>>::value),
|
threadblock::StridedDgradIdentityThreadblockSwizzle<8>>::value),
|
||||||
"Needs ThreadblockSwizzle type specialized for strided dgrad");
|
"Needs ThreadblockSwizzle type specialized for strided dgrad");
|
||||||
|
|
||||||
|
@ -121,7 +121,7 @@ struct ImplicitGemmConvolutionWithFusedEpilogue {
|
|||||||
// Conv2d row-major matrix C (KxRSC)
|
// Conv2d row-major matrix C (KxRSC)
|
||||||
// Conv3d row-major matrix C (KxTRSC)
|
// Conv3d row-major matrix C (KxTRSC)
|
||||||
static int const kWgradCStrideIdx =
|
static int const kWgradCStrideIdx =
|
||||||
cutlass::platform::is_same<LayoutC, cutlass::layout::TensorNHWC>::value ? 2 : 3;
|
platform::is_same<LayoutC, cutlass::layout::TensorNHWC>::value ? 2 : 3;
|
||||||
|
|
||||||
/// This chooses the appropriate stride element of the C tensor.
|
/// This chooses the appropriate stride element of the C tensor.
|
||||||
static int const kTensorCStrideIdx =
|
static int const kTensorCStrideIdx =
|
||||||
|
@ -215,10 +215,10 @@ struct DefaultIteratorsTensorOp<
|
|||||||
InstructionShape,
|
InstructionShape,
|
||||||
ThreadMap> {
|
ThreadMap> {
|
||||||
|
|
||||||
static_assert(cutlass::platform::is_same<ElementOutput, cutlass::int4b_t>::value ||
|
static_assert(platform::is_same<ElementOutput, cutlass::int4b_t>::value ||
|
||||||
cutlass::platform::is_same<ElementOutput, cutlass::uint4b_t>::value ||
|
platform::is_same<ElementOutput, cutlass::uint4b_t>::value ||
|
||||||
cutlass::platform::is_same<ElementOutput, int8_t>::value ||
|
platform::is_same<ElementOutput, int8_t>::value ||
|
||||||
cutlass::platform::is_same<ElementOutput, uint8_t>::value,
|
platform::is_same<ElementOutput, uint8_t>::value,
|
||||||
"ElementOutput needs to be 4 or 8 bit (unsigned) int.");
|
"ElementOutput needs to be 4 or 8 bit (unsigned) int.");
|
||||||
|
|
||||||
static_assert((ElementsPerAccess == 16 || ElementsPerAccess == 8),
|
static_assert((ElementsPerAccess == 16 || ElementsPerAccess == 8),
|
||||||
|
@ -149,7 +149,7 @@ class Rank2K {
|
|||||||
static int const kUpdateRank = 2;
|
static int const kUpdateRank = 2;
|
||||||
|
|
||||||
// static asserts for rank 2k update kernel
|
// static asserts for rank 2k update kernel
|
||||||
static_assert(std::is_same<LayoutA, LayoutB>::value,
|
static_assert(platform::is_same<LayoutA, LayoutB>::value,
|
||||||
"Rank 2K update operator support same layouts for operandA and B");
|
"Rank 2K update operator support same layouts for operandA and B");
|
||||||
|
|
||||||
/// Define the kernel
|
/// Define the kernel
|
||||||
|
@ -153,7 +153,7 @@ class Symm {
|
|||||||
static BlasMode const kBlasMode = BlasMode_;
|
static BlasMode const kBlasMode = BlasMode_;
|
||||||
|
|
||||||
// static asserts for symm update kernel
|
// static asserts for symm update kernel
|
||||||
static_assert(std::is_same<LayoutA, LayoutB>::value,
|
static_assert(platform::is_same<LayoutA, LayoutB>::value,
|
||||||
"SYMM update operator support same layouts for operand A and B");
|
"SYMM update operator support same layouts for operand A and B");
|
||||||
|
|
||||||
/// Define the kernel
|
/// Define the kernel
|
||||||
|
@ -209,7 +209,7 @@ struct DefaultGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignment
|
|||||||
2, ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp,
|
2, ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp,
|
||||||
EpilogueOutputOp::kCount>::Epilogue;
|
EpilogueOutputOp::kCount>::Epilogue;
|
||||||
|
|
||||||
using Epilogue = typename cutlass::platform::conditional<cutlass::platform::is_same<LayoutC, layout::RowMajor>::value,
|
using Epilogue = typename cutlass::platform::conditional<platform::is_same<LayoutC, layout::RowMajor>::value,
|
||||||
RegularEpilogue,
|
RegularEpilogue,
|
||||||
Affine2Epilogue>::type;
|
Affine2Epilogue>::type;
|
||||||
|
|
||||||
@ -672,7 +672,7 @@ struct DefaultGemm<
|
|||||||
kEpilogueElementsPerAccess
|
kEpilogueElementsPerAccess
|
||||||
>::Epilogue;
|
>::Epilogue;
|
||||||
|
|
||||||
using Epilogue = typename cutlass::platform::conditional<cutlass::platform::is_same<LayoutC, layout::RowMajor>::value,
|
using Epilogue = typename cutlass::platform::conditional<platform::is_same<LayoutC, layout::RowMajor>::value,
|
||||||
RegularEpilogue,
|
RegularEpilogue,
|
||||||
Affine2Epilogue>::type;
|
Affine2Epilogue>::type;
|
||||||
|
|
||||||
@ -780,7 +780,7 @@ struct DefaultGemm<ElementA,
|
|||||||
kEpilogueElementsPerAccess
|
kEpilogueElementsPerAccess
|
||||||
>::Epilogue;
|
>::Epilogue;
|
||||||
|
|
||||||
using Epilogue = typename cutlass::platform::conditional<cutlass::platform::is_same<LayoutC, layout::RowMajor>::value,
|
using Epilogue = typename cutlass::platform::conditional<platform::is_same<LayoutC, layout::RowMajor>::value,
|
||||||
RegularEpilogue,
|
RegularEpilogue,
|
||||||
Affine2Epilogue>::type;
|
Affine2Epilogue>::type;
|
||||||
|
|
||||||
|
@ -183,7 +183,7 @@ struct DefaultGemmGrouped<
|
|||||||
> {
|
> {
|
||||||
|
|
||||||
// If true, we must construct a 'transposed-and-exchanged' Mma operator.
|
// If true, we must construct a 'transposed-and-exchanged' Mma operator.
|
||||||
static bool const kInternalTranspose = std::is_same<LayoutC, layout::ColumnMajor>::value;
|
static bool const kInternalTranspose = platform::is_same<LayoutC, layout::ColumnMajor>::value;
|
||||||
|
|
||||||
using MapArguments = kernel::detail::MapArguments<
|
using MapArguments = kernel::detail::MapArguments<
|
||||||
ElementA,
|
ElementA,
|
||||||
@ -307,7 +307,7 @@ struct DefaultGemmGrouped<
|
|||||||
> {
|
> {
|
||||||
|
|
||||||
// If true, we must construct a 'transposed-and-exchanged' Mma operator.
|
// If true, we must construct a 'transposed-and-exchanged' Mma operator.
|
||||||
static bool const kInternalTranspose = std::is_same<LayoutC, layout::ColumnMajor>::value;
|
static bool const kInternalTranspose = platform::is_same<LayoutC, layout::ColumnMajor>::value;
|
||||||
|
|
||||||
using MapArguments = kernel::detail::MapArguments<
|
using MapArguments = kernel::detail::MapArguments<
|
||||||
ElementA,
|
ElementA,
|
||||||
|
@ -67,7 +67,7 @@ public:
|
|||||||
using LayoutA = layout::ColumnMajor;
|
using LayoutA = layout::ColumnMajor;
|
||||||
using TensorRefA = TensorRef<ElementA, LayoutA>;
|
using TensorRefA = TensorRef<ElementA, LayoutA>;
|
||||||
|
|
||||||
static_assert(std::is_same<LayoutA, LayoutA_>::value,
|
static_assert(platform::is_same<LayoutA, LayoutA_>::value,
|
||||||
"Only supported for column-major A matrix");
|
"Only supported for column-major A matrix");
|
||||||
|
|
||||||
using ElementB = ElementB_;
|
using ElementB = ElementB_;
|
||||||
|
@ -632,8 +632,8 @@ struct DefaultMma<int8_t, LayoutA, kAlignmentA, int8_t, LayoutB, kAlignmentB,
|
|||||||
using ElementB = int8_t;
|
using ElementB = int8_t;
|
||||||
using OperatorClass = arch::OpClassSimt;
|
using OperatorClass = arch::OpClassSimt;
|
||||||
|
|
||||||
static const bool transposeA = cutlass::platform::is_same< LayoutA, layout::ColumnMajor >::value;
|
static const bool transposeA = platform::is_same< LayoutA, layout::ColumnMajor >::value;
|
||||||
static const bool transposeB = cutlass::platform::is_same< LayoutB, layout::RowMajor >::value;
|
static const bool transposeB = platform::is_same< LayoutB, layout::RowMajor >::value;
|
||||||
|
|
||||||
// Define the MmaCore components
|
// Define the MmaCore components
|
||||||
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||||
|
@ -54,19 +54,19 @@ class UnaryOp
|
|||||||
static FragmentOut execute(FragmentIn &in)
|
static FragmentOut execute(FragmentIn &in)
|
||||||
{
|
{
|
||||||
static_assert(FragmentIn::kElements == FragmentOut::kElements, "Number of elements must match.");
|
static_assert(FragmentIn::kElements == FragmentOut::kElements, "Number of elements must match.");
|
||||||
static_assert(std::is_same<Transform, UnaryTransform::Identity>::value ||
|
static_assert(platform::is_same<Transform, UnaryTransform::Identity>::value ||
|
||||||
std::is_same<Transform, UnaryTransform::Conjugate>::value,
|
platform::is_same<Transform, UnaryTransform::Conjugate>::value,
|
||||||
"Unary Operator not supported.");
|
"Unary Operator not supported.");
|
||||||
|
|
||||||
FragmentOut out;
|
FragmentOut out;
|
||||||
if( std::is_same<Transform, UnaryTransform::Identity>::value )
|
if( platform::is_same<Transform, UnaryTransform::Identity>::value )
|
||||||
{
|
{
|
||||||
CUTLASS_PRAGMA_UNROLL
|
CUTLASS_PRAGMA_UNROLL
|
||||||
for(int i=0; i < FragmentIn::kElements; ++i){
|
for(int i=0; i < FragmentIn::kElements; ++i){
|
||||||
out[i] = static_cast<typename FragmentOut::Element>(in[i]);
|
out[i] = static_cast<typename FragmentOut::Element>(in[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else if( std::is_same<Transform, UnaryTransform::Conjugate>::value )
|
else if( platform::is_same<Transform, UnaryTransform::Conjugate>::value )
|
||||||
{
|
{
|
||||||
for(int i=0; i < FragmentIn::kElements; ++i){
|
for(int i=0; i < FragmentIn::kElements; ++i){
|
||||||
out[i] = conj(static_cast<typename FragmentOut::Element>(in[i]));
|
out[i] = conj(static_cast<typename FragmentOut::Element>(in[i]));
|
||||||
@ -83,15 +83,15 @@ class UnaryOp<FragmentIn, FragmentIn, Transform>
|
|||||||
CUTLASS_DEVICE
|
CUTLASS_DEVICE
|
||||||
static FragmentIn execute(FragmentIn &in)
|
static FragmentIn execute(FragmentIn &in)
|
||||||
{
|
{
|
||||||
static_assert(std::is_same<Transform, UnaryTransform::Identity>::value ||
|
static_assert(platform::is_same<Transform, UnaryTransform::Identity>::value ||
|
||||||
std::is_same<Transform, UnaryTransform::Conjugate>::value,
|
platform::is_same<Transform, UnaryTransform::Conjugate>::value,
|
||||||
"Unary Operator not supported.");
|
"Unary Operator not supported.");
|
||||||
|
|
||||||
if( std::is_same<Transform, UnaryTransform::Identity>::value )
|
if( platform::is_same<Transform, UnaryTransform::Identity>::value )
|
||||||
{
|
{
|
||||||
return in;
|
return in;
|
||||||
}
|
}
|
||||||
else if( std::is_same<Transform, UnaryTransform::Conjugate>::value )
|
else if( platform::is_same<Transform, UnaryTransform::Conjugate>::value )
|
||||||
{
|
{
|
||||||
for(int i=0; i < FragmentIn::kElements; ++i){
|
for(int i=0; i < FragmentIn::kElements; ++i){
|
||||||
in[i] = conj(in[i]);
|
in[i] = conj(in[i]);
|
||||||
|
Loading…
Reference in New Issue
Block a user