Updates for 3.0 (#857)
Co-authored-by: Aniket Shivam <ashivam@nvidia.com>
commit c4f6b8c6bc
parent a68e2f95f0
@@ -311,7 +311,9 @@ $ make cutlass_profiler -j16
 
 By default, only one tile size is instantiated for each data type, math instruction, and layout.
 To instantiate all, set the following environment variable when running CMake from an empty `build/` directory.
-Beware, this results in *thousands* of kernels and long build times.
+Beware, this results in *tens of thousands* of kernels and long build times.
+This would also result in a large binary size and, on some platforms, cause the linker to fail when building the library.
+Therefore, it's highly recommended to generate only a subset of kernels, as demonstrated in the sub-section below.
 ```bash
 $ cmake .. -DCUTLASS_NVCC_ARCHS=90a -DCUTLASS_LIBRARY_KERNELS=all
 ...
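For context, the subset workflow referred to above passes a kernel-name filter instead of `all`. A minimal sketch, assuming wildcard filters are accepted for the 3.x kernel names the same way they are for 2.x names (the filter string below is purely illustrative, not a guaranteed kernel name):

```bash
# Generate only the SM90 GEMM kernels whose names match the (illustrative)
# pattern, rather than instantiating everything with -DCUTLASS_LIBRARY_KERNELS=all.
$ cmake .. -DCUTLASS_NVCC_ARCHS=90a -DCUTLASS_LIBRARY_KERNELS="cutlass3x_sm90*gemm*"
```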
@@ -399,7 +399,7 @@ struct alignas(1) float_e4m3_t : float8_base<FloatEncoding::E4M3> {
 
     return *reinterpret_cast<float_e4m3_t *>(&tmp);
 #else
-    return bitcast(Base::convert_float_to_fp8(float(flt)));
+    return bitcast(Base::convert_float_to_fp8(__half2float(flt)));
 #endif
   }
 
@@ -413,7 +413,7 @@ struct alignas(1) float_e4m3_t : float8_base<FloatEncoding::E4M3> {
 
     return reinterpret_cast<half2 const &>(packed).x;
 #else
-    return half(Base::convert_fp8_to_float(x.storage));
+    return __float2half(Base::convert_fp8_to_float(x.storage));
 #endif
   }
 
@@ -425,7 +425,7 @@ struct alignas(1) float_e4m3_t : float8_base<FloatEncoding::E4M3> {
     uint32_t packed;
     asm volatile("cvt.rn.f16x2.e4m3x2 %0, %1;\n" : "=r"(packed) : "h"(bits));
 
-    return float(reinterpret_cast<half2 const &>(packed).x);
+    return __half2float(reinterpret_cast<half2 const &>(packed).x);
 #else
     return Base::convert_fp8_to_float(x.storage);
 #endif
@@ -609,7 +609,7 @@ struct alignas(1) float_e5m2_t : float8_base<FloatEncoding::E5M2> {
 
     return *reinterpret_cast<float_e5m2_t *>(&tmp);
 #else
-    return bitcast(Base::convert_float_to_fp8(float(flt)));
+    return bitcast(Base::convert_float_to_fp8(__half2float(flt)));
 #endif
   }
 
@@ -623,7 +623,7 @@ struct alignas(1) float_e5m2_t : float8_base<FloatEncoding::E5M2> {
 
     return reinterpret_cast<half2 const &>(packed).x;
 #else
-    return half(Base::convert_fp8_to_float(x.storage));
+    return __float2half(Base::convert_fp8_to_float(x.storage));
 #endif
   }
 
@@ -635,7 +635,7 @@ struct alignas(1) float_e5m2_t : float8_base<FloatEncoding::E5M2> {
     uint32_t packed;
     asm volatile("cvt.rn.f16x2.e5m2x2 %0, %1;\n" : "=r"(packed) : "h"(bits));
 
-    return float(reinterpret_cast<half2 const &>(packed).x);
+    return __half2float(reinterpret_cast<half2 const &>(packed).x);
 #else
     return Base::convert_fp8_to_float(x.storage);
 #endif
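The common thread in the float8 hunks above is replacing constructor-style casts (`float(flt)`, `half(...)`) with the CUDA intrinsics `__half2float`/`__float2half`. A minimal sketch of why, assuming the translation unit may be built with `-D__CUDA_NO_HALF_CONVERSIONS__`, which removes the implicit and constructor conversions between `__half` and `float`:

```cpp
#include <cuda_fp16.h>

// With __CUDA_NO_HALF_CONVERSIONS__ defined, float(h) and __half(f) no longer
// compile, because the converting constructors/operators of __half are removed.
// The explicit intrinsics below are always available, which is why the diff
// routes every half<->float conversion through them.
__device__ float widen(__half h) {
  return __half2float(h);   // OK with or without __CUDA_NO_HALF_CONVERSIONS__
}

__device__ __half narrow(float f) {
  return __float2half(f);   // likewise unaffected by the macro
}
```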
@@ -89,6 +89,59 @@ struct multiplies {
   }
 };
 
+#if defined(__CUDA_ARCH__)
+/// Partial specializations needed when __CUDA_NO_HALF2_OPERATORS__ is set
+template<>
+struct plus<__half2> {
+  CUTLASS_HOST_DEVICE
+  __half2 operator()(__half2 lhs, __half2 const &rhs) const {
+    return __hadd2(lhs, rhs);
+  }
+};
+
+template<>
+struct minus<__half2> {
+  CUTLASS_HOST_DEVICE
+  __half2 operator()(__half2 lhs, __half2 const &rhs) const {
+    return __hsub2(lhs, rhs);
+  }
+};
+
+template<>
+struct multiplies<__half2> {
+  CUTLASS_HOST_DEVICE
+  __half2 operator()(__half2 lhs, __half2 const &rhs) const {
+    return __hmul2(lhs, rhs);
+  }
+};
+
+/// Partial specializations needed when __CUDA_NO_HALF_OPERATORS__ is set
+template<>
+struct plus<__half> {
+  CUTLASS_HOST_DEVICE
+  __half operator()(__half lhs, __half const &rhs) const {
+    return __hadd(lhs, rhs);
+  }
+};
+
+template<>
+struct minus<__half> {
+  CUTLASS_HOST_DEVICE
+  __half operator()(__half lhs, __half const &rhs) const {
+    return __hsub(lhs, rhs);
+  }
+};
+
+template<>
+struct multiplies<__half> {
+  CUTLASS_HOST_DEVICE
+  __half operator()(__half lhs, __half const &rhs) const {
+    return __hmul(lhs, rhs);
+  }
+};
+#endif // defined(__CUDA_ARCH__)
+
+
 // Maximum with NaN propagation
 // To propagate NaNs, the "max" of two elements that contain a NaN should also return a NaN
 template <typename T>
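These explicit specializations keep CUTLASS's elementwise functors usable when a project defines `__CUDA_NO_HALF_OPERATORS__` / `__CUDA_NO_HALF2_OPERATORS__`, which remove the built-in `+`, `-`, `*` overloads on `__half`/`__half2`. A small usage sketch under that assumption (device-side only, since the specializations above are guarded by `__CUDA_ARCH__`):

```cpp
#include <cuda_fp16.h>
#include "cutlass/functional.h"

// Even if `a + b` on __half2 is unavailable because __CUDA_NO_HALF2_OPERATORS__
// is defined, the functor still compiles: the specialization calls __hadd2
// directly rather than relying on the removed operator overload.
__device__ __half2 accumulate(__half2 a, __half2 b) {
  cutlass::plus<__half2> add;
  return add(a, b);
}
```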
@@ -411,36 +464,15 @@ struct red<half2>
   CUTLASS_DEVICE
   void operator()(half2 *ptr, const half2 &data)
   {
-#if !defined(__CUDA_ARCH__)
+#if !defined(__CUDA_ARCH__) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600))
     CUTLASS_UNUSED(ptr);
     CUTLASS_UNUSED(data);
-#elif (__CUDA_ARCH__ >= 600)
+#else
 
     // Vector-2 atomic reduction requires .target sm_60 or higher
     uint32_t word = reinterpret_cast<const uint32_t&>(data);
     asm volatile ("red.gpu.global.add.noftz.f16x2 [%0], %1;\n" : : "l"(ptr), "r"(word));
 
-#else
-
-    // Use CAS loop
-    uint32_t *ptr_int = reinterpret_cast<uint32_t *>(ptr);
-    uint32_t old_int = *ptr_int;
-    uint32_t assumed_int;
-
-    do
-    {
-      half2 old = reinterpret_cast<half2&>(old_int);
-
-      half hi = __hadd(__high2half(old), __high2half(data));
-      half lo = __hadd(__low2half(old), __low2half(data));
-      half2 update = __halves2half2(hi, lo);
-      uint32_t update_int = reinterpret_cast<const uint32_t&>(update);
-
-      assumed_int = old_int;
-      old_int = atomicCAS(ptr_int, assumed_int, update_int);
-
-    } while (assumed_int != old_int);
-
 #endif // (__CUDA_ARCH__ >= 600)
   }
 };
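The effect of the new guard is that the vectorized `red.gpu.global.add.noftz.f16x2` path is used on SM60 and newer, while host compilation and pre-SM60 targets now compile to a no-op instead of the old `atomicCAS` loop. A usage sketch, assuming `red<half2>` remains reachable through `cutlass/functional.h` as in the hunk above:

```cpp
#include <cuda_fp16.h>
#include "cutlass/functional.h"

// Each thread atomically accumulates a packed pair of halves into *sum.
// On sm_60+ this lowers to the f16x2 red (reduction) instruction shown in
// the diff; on older targets the functor body is now empty.
__global__ void accumulate_pairs(half2 *sum, half2 const *in, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    cutlass::red<half2> reduce;
    reduce(sum, in[i]);
  }
}
```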
@@ -52,7 +52,10 @@ template <
   /// Element type
   typename T,
   /// Number of elements in the array
-  int N
+  int N,
+  /// Whether the element type of T is half_t or __half
+  bool IsHalfType = (platform::is_same<typename T::element_type, cutlass::half_t>::value ||
+                     platform::is_same<typename T::element_type, __half>::value)
 >
 class WmmaFragmentArray: public Array<T, N, true> {
 public:
@@ -80,7 +83,44 @@ public:
 
     return *this;
   }
+};
+
+/// Partial specialization for the case in which T::element_type is
+/// half_t or __half. This is needed because the cast (typename T::element_type)0
+/// in the primary template flags as an error when __CUDA_NO_HALF_CONVERSIONS__
+/// is set.
+template <
+  /// Element type
+  typename T,
+  /// Number of elements in the array
+  int N
+>
+class WmmaFragmentArray<T, N, true>: public Array<T, N, true> {
+public:
+
+  /// Efficient clear method (override Array::clear())
+  CUTLASS_HOST_DEVICE
+  void clear()
+  {
+    for(int i = 0; i < Array<T, N, true>::kElements; i++)
+    {
+      nvcuda::wmma::fill_fragment((*this)[i], __float2half(0.f));
+    }
+  }
+
+  CUTLASS_HOST_DEVICE
+  WmmaFragmentArray<T, N>& operator+=(const WmmaFragmentArray<T, N>& rhs)
+  {
+    using element_type = typename T::element_type;
+    plus<T> add;
+
+    for (int i = 0; i < Array<T, N, true>::kElements; i++)
+    {
+      (*this)[i] = add((*this)[i], rhs[i]);
+    }
+
+    return *this;
+  }
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
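The defaulted `IsHalfType` parameter added in the previous hunk is what routes half-valued fragments to this specialization. A plain sketch of the same dispatch pattern with illustrative names (not the CUTLASS ones), showing why the half path builds its zero with `__float2half(0.f)` rather than a cast:

```cpp
#include <cuda_fp16.h>
#include <type_traits>

// Primary template: clear() zeroes elements with a cast, which is ill-formed
// for __half when -D__CUDA_NO_HALF_CONVERSIONS__ is in effect.
template <typename Element, int N,
          bool IsHalf = std::is_same<Element, __half>::value>
struct FragmentSketch {
  Element data[N];
  __device__ void clear() {
    for (int i = 0; i < N; ++i) {
      data[i] = Element(0);
    }
  }
};

// Specialization selected automatically for __half: it uses the explicit
// intrinsic instead of a conversion, so it compiles either way.
template <typename Element, int N>
struct FragmentSketch<Element, N, true> {
  Element data[N];
  __device__ void clear() {
    for (int i = 0; i < N; ++i) {
      data[i] = __float2half(0.f);
    }
  }
};

// FragmentSketch<float, 8> uses the primary template;
// FragmentSketch<__half, 8> picks the specialization.
```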
@@ -587,7 +587,7 @@ To instantiate all operations supporting all tile sizes, data types, and alignments
 ```bash
 $ cmake .. -DCUTLASS_NVCC_ARCHS='70;75;80' -DCUTLASS_LIBRARY_KERNELS=all
 ```
-The above command line generates about seven thousand kernels targeting NVIDIA Ampere, Turing, and Volta architectures.
+The above command line generates about twenty thousand kernels targeting NVIDIA Ampere, Turing, and Volta architectures.
 Compiling thousands of kernels for three different architectures is time consuming. Additionally, this would also result
 in a large binary size and, on some platforms, cause the linker to fail when building the library.
 
@@ -100,6 +100,7 @@ execute_process(
     --kernels "${CUTLASS_LIBRARY_KERNELS}"
     --ignore-kernels "${CUTLASS_LIBRARY_IGNORE_KERNELS}"
     --cuda-version "${CUTLASS_GENERATOR_CUDA_COMPILER_VERSION}"
+    --log-level DEBUG
   RESULT_VARIABLE cutlass_lib_INSTANCE_GENERATION_RESULT
   OUTPUT_VARIABLE cutlass_lib_INSTANCE_GENERATION_OUTPUT
   OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/library_instance_generation.log
@@ -8,6 +8,7 @@ import enum
 import os.path
 import shutil
 import argparse
+import logging
 
 from library import *
 from manifest import *
@@ -4838,8 +4839,30 @@ if __name__ == "__main__":
   parser.add_argument("--interface-dir", default=None, required=False, help="Interface header to kernels")
   parser.add_argument("--disable-full-archs-compilation", action="store_true", required=False, help="Disable compilation for every archs in --architectures")
+
+  def numeric_log_level(log_level: str) -> int:
+    """
+    Converts the string identifier of the log level into the numeric identifier used
+    in setting the log level
+
+    :param log_level: string representation of log level (e.g., 'INFO', 'DEBUG')
+    :type log_level: str
+
+    :return: numeric representation of log level
+    :rtype: int
+    """
+    numeric_level = getattr(logging, log_level.upper(), None)
+    if not isinstance(numeric_level, int):
+      raise ValueError(f'Invalid log level: {log_level}')
+    return numeric_level
+
+  parser.add_argument("--log-level", default='info', type=numeric_log_level, required=False,
+                      help='Logging level to be used by the generator script')
 
   args = parser.parse_args()
+
+  # Set the logging level based on the user-provided `--log-level` command-line option
+  logging.basicConfig(level=args.log_level)
 
   manifest = Manifest(args)
 
   GenerateSM50(manifest, args.cuda_version)
@@ -4849,6 +4872,7 @@ if __name__ == "__main__":
   GenerateSM75(manifest, args.cuda_version)
   GenerateSM80(manifest, args.cuda_version)
   GenerateSM90(manifest, args.cuda_version)
 
   if 'library' in args.generator_target.split(','):
     manifest.emit(GeneratorTarget.Library)
 