Add definitions for tag structs. (#752)

This commit changes the declarations of MMA operator class (SIMT, Tensor Core, WMMA Tensor Core) and operator type (multiply-add and so on) to definitions. This is done so that these tag structs are no longer incomplete types, which allows the `typeid` operator to be used on these tag structs. This is necessary for these tag structs to be used as type parameters in [GoogleTest typed tests](https://google.github.io/googletest/advanced.html#typed-tests).
This commit is contained in:
Gregory Meyer (gregjm) 2023-01-06 06:46:52 -08:00 committed by GitHub
parent c54ede3a9e
commit 7bdba07310
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -49,61 +49,61 @@ namespace arch {
///////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////
/// Tag indicating the operation implied by MMA. /// Tag indicating the operation implied by MMA.
struct OpMultiplyAdd; struct OpMultiplyAdd {};
///////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////
/// Tag indicating the result is saturated to MAX_FLOAT|MIN_FLOAT or MAX_INT|MIN_INT /// Tag indicating the result is saturated to MAX_FLOAT|MIN_FLOAT or MAX_INT|MIN_INT
struct OpMultiplyAddSaturate; struct OpMultiplyAddSaturate {};
///////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////
/// Tag indicating the input is converted to a narrower type (BF16) /// Tag indicating the input is converted to a narrower type (BF16)
struct OpMultiplyAddFastBF16; struct OpMultiplyAddFastBF16 {};
///////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////
/// Tag indicating the input is converted to a narrower type (F16) /// Tag indicating the input is converted to a narrower type (F16)
struct OpMultiplyAddFastF16; struct OpMultiplyAddFastF16 {};
///////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////
/// Tag indicating the input is converted to 2 (big and small) TF32 components /// Tag indicating the input is converted to 2 (big and small) TF32 components
// Perform 3xTF32 or 4xTF32 for every F32 output element // Perform 3xTF32 or 4xTF32 for every F32 output element
struct OpMultiplyAddFastF32; struct OpMultiplyAddFastF32 {}
/// Tag indicating the input is converted to 2 (big and small) TF32 components /// Tag indicating the input is converted to 2 (big and small) TF32 components
// Perform 3xTF32 or 4xTF32 for every complex<F32> output element // Perform 3xTF32 or 4xTF32 for every complex<F32> output element
struct OpMultiplyAddComplexFastF32; struct OpMultiplyAddComplexFastF32 {};
///////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////
/// Tag indicating the complex multiply-add operation /// Tag indicating the complex multiply-add operation
struct OpMultiplyAddComplex; struct OpMultiplyAddComplex {};
///////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////
/// Tag indicating the gaussian complex multiply-add operation /// Tag indicating the gaussian complex multiply-add operation
struct OpMultiplyAddGaussianComplex; struct OpMultiplyAddGaussianComplex {};
///////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////
/// Tag indicating the inner product is defined by (XOR, POPC) /// Tag indicating the inner product is defined by (XOR, POPC)
struct OpXorPopc; struct OpXorPopc {};
///////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////
/// Tag classifying math operators as thread-level operations. /// Tag classifying math operators as thread-level operations.
struct OpClassSimt; struct OpClassSimt {};
///////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////
/// Tag classifing operators as Tensor Core operations. /// Tag classifing operators as Tensor Core operations.
struct OpClassTensorOp; struct OpClassTensorOp {};
///////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////
/// Tag classifing operators as WMMA Tensor Core operations /// Tag classifing operators as WMMA Tensor Core operations
struct OpClassWmmaTensorOp; struct OpClassWmmaTensorOp {};
///////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////