/*************************************************************************************************** * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ #pragma once #include #include // Config #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) # define CUTE_ARCH_MMA_SM90_ENABLED # define CUTE_ARCH_MMA_F64_SM90_ENABLED #endif //////////////////////////////////////////////////////////////////////////////////////////////////// namespace cute { //////////////////////////////////////////////////////////////////////////////////////////////////// // MMA 16x8x4 TN struct SM90_16x8x4_F64F64F64F64_TN { using DRegisters = double[4]; using ARegisters = double[2]; using BRegisters = double[1]; using CRegisters = double[4]; CUTE_HOST_DEVICE static void fma(double & d0, double & d1, double & d2, double & d3, double const& a0, double const& a1, double const& b0, double const& c0, double const& c1, double const& c2, double const& c3) { #if defined(CUTE_ARCH_MMA_F64_SM90_ENABLED) asm volatile( "mma.sync.aligned.m16n8k4.row.col.f64.f64.f64.f64" "{%0, %1, %2, %3}," "{%4, %5}," "{%6}," "{%7, %8, %9, %10};\n" : "=d"(d0), "=d"(d1), "=d"(d2), "=d"(d3) : "d"(a0), "d"(a1), "d"(b0), "d"(c0), "d"(c1), "d"(c2), "d"(c3)); #else CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_16x8x4_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// // MMA 16x8x8 TN struct SM90_16x8x8_F64F64F64F64_TN { using DRegisters = double[4]; using ARegisters = double[4]; using BRegisters = double[2]; using CRegisters = double[4]; CUTE_HOST_DEVICE static void fma(double & d0, double & d1, double & d2, double & d3, double const& a0, double const& a1, double const& a2, double const& a3, double const& b0, double const& b1, double const& c0, double const& c1, double const& c2, double const& c3) { #if defined(CUTE_ARCH_MMA_F64_SM90_ENABLED) asm volatile( "mma.sync.aligned.m16n8k8.row.col.f64.f64.f64.f64" "{%0, %1, %2, %3}," "{%4, %5, %6, %7}," "{%8, %9}," "{%10, %11, %12, %13};\n" : "=d"(d0), "=d"(d1), "=d"(d2), "=d"(d3) : "d"(a0), "d"(a1), "d"(a2), "d"(a3), "d"(b0), "d"(b1), "d"(c0), "d"(c1), "d"(c2), "d"(c3)); #else CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_16x8x8_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// // MMA 16x8x16 TN struct SM90_16x8x16_F64F64F64F64_TN { using DRegisters = double[4]; using ARegisters = double[8]; using BRegisters = double[4]; using CRegisters = double[4]; CUTE_HOST_DEVICE static void fma(double & d0, double & d1, double & d2, double & d3, double const& a0, double const& a1, double const& a2, double const& a3, double const& a4, double const& a5, double const& a6, double const& a7, double const& b0, double const& b1, double const& b2, double const& b3, double const& c0, double const& c1, double const& c2, double const& c3) { #if defined(CUTE_ARCH_MMA_F64_SM90_ENABLED) asm volatile( "mma.sync.aligned.m16n8k16.row.col.f64.f64.f64.f64" "{%0, %1, %2, %3}," "{%4, %5, %6, %7, %8, %9, %10, %11}," "{%12, %13, %14, %15}," "{%16, %17, %18, %19};\n" : "=d"(d0), "=d"(d1), "=d"(d2), "=d"(d3) : "d"(a0), "d"(a1), "d"(a2), "d"(a3), "d"(a4), "d"(a5), "d"(a6), "d"(a7), "d"(b0), "d"(b1), "d"(b2), "d"(b3), "d"(c0), "d"(c1), "d"(c2), "d"(c3)); #else CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_16x8x16_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// // MMA 16x8x4 TN struct SM90_16x8x4_C64C64C64C64_TN { using DRegisters = complex[4]; using ARegisters = complex[2]; using BRegisters = complex[1]; using CRegisters = complex[4]; CUTE_HOST_DEVICE static void fma(complex & d0, complex & d1, complex & d2, complex & d3, complex const& a0, complex const& a1, complex const& b0, complex const& c0, complex const& c1, complex const& c2, complex const& c3) { // Because thrust::complex does not provide a mutable ref double& rd0 = reinterpret_cast(d0)[0]; double& id0 = reinterpret_cast(d0)[1]; double& rd1 = reinterpret_cast(d1)[0]; double& id1 = reinterpret_cast(d1)[1]; double& rd2 = reinterpret_cast(d2)[0]; double& id2 = reinterpret_cast(d2)[1]; double& rd3 = reinterpret_cast(d3)[0]; double& id3 = reinterpret_cast(d3)[1]; // d.real() = a.real() * b.real() + c.real(); SM90_16x8x4_F64F64F64F64_TN::fma( rd0, rd1, rd2, rd3, a0.real(), a1.real(), b0.real(), c0.real(), c1.real(), c2.real(), c3.real()); // d.imag() = a.imag() * b.real() + c.imag(); SM90_16x8x4_F64F64F64F64_TN::fma( id0, id1, id2, id3, a0.imag(), a1.imag(), b0.real(), c0.imag(), c1.imag(), c2.imag(), c3.imag()); // d.real() = -a.imag() * b.imag() + d.real(); SM90_16x8x4_F64F64F64F64_TN::fma( rd0, rd1, rd2, rd3, -a0.imag(), -a1.imag(), b0.imag(), d0.real(), d1.real(), d2.real(), d3.real()); // d.imag() = a.real() * b.imag() + d.imag(); SM90_16x8x4_F64F64F64F64_TN::fma( id0, id1, id2, id3, a0.real(), a1.real(), b0.imag(), d0.imag(), d1.imag(), d2.imag(), d3.imag()); } }; //////////////////////////////////////////////////////////////////////////////////////////////////// // MMA 16x8x8 TN struct SM90_16x8x8_C64C64C64C64_TN { using DRegisters = complex[4]; using ARegisters = complex[4]; using BRegisters = complex[2]; using CRegisters = complex[4]; CUTE_HOST_DEVICE static void fma(complex & d0, complex & d1, complex & d2, complex & d3, complex const& a0, complex const& a1, complex const& a2, complex const& a3, complex const& b0, complex const& b1, complex const& c0, complex const& c1, complex const& c2, complex const& c3) { // Because thrust::complex does not provide a mutable ref double& rd0 = reinterpret_cast(d0)[0]; double& id0 = reinterpret_cast(d0)[1]; double& rd1 = reinterpret_cast(d1)[0]; double& id1 = reinterpret_cast(d1)[1]; double& rd2 = reinterpret_cast(d2)[0]; double& id2 = reinterpret_cast(d2)[1]; double& rd3 = reinterpret_cast(d3)[0]; double& id3 = reinterpret_cast(d3)[1]; // d.real() = a.real() * b.real() + c.real(); SM90_16x8x8_F64F64F64F64_TN::fma( rd0, rd1, rd2, rd3, a0.real(), a1.real(), a2.real(), a3.real(), b0.real(), b1.real(), c0.real(), c1.real(), c2.real(), c3.real()); // d.imag() = a.imag() * b.real() + c.imag(); SM90_16x8x8_F64F64F64F64_TN::fma( id0, id1, id2, id3, a0.imag(), a1.imag(), a2.imag(), a3.imag(), b0.real(), b1.real(), c0.imag(), c1.imag(), c2.imag(), c3.imag()); // d.real() = -a.imag() * b.imag() + d.real(); SM90_16x8x8_F64F64F64F64_TN::fma( rd0, rd1, rd2, rd3, -a0.imag(), -a1.imag(), -a2.imag(), -a3.imag(), b0.imag(), b1.imag(), d0.real(), d1.real(), d2.real(), d3.real()); // d.imag() = a.real() * b.imag() + d.imag(); SM90_16x8x8_F64F64F64F64_TN::fma( id0, id1, id2, id3, a0.real(), a1.real(), a2.real(), a3.real(), b0.imag(), b1.imag(), d0.imag(), d1.imag(), d2.imag(), d3.imag()); } }; //////////////////////////////////////////////////////////////////////////////////////////////////// // MMA 16x8x16 TN struct SM90_16x8x16_C64C64C64C64_TN { using DRegisters = complex[4]; using ARegisters = complex[8]; using BRegisters = complex[4]; using CRegisters = complex[4]; CUTE_HOST_DEVICE static void fma(complex & d0, complex & d1, complex & d2, complex & d3, complex const& a0, complex const& a1, complex const& a2, complex const& a3, complex const& a4, complex const& a5, complex const& a6, complex const& a7, complex const& b0, complex const& b1, complex const& b2, complex const& b3, complex const& c0, complex const& c1, complex const& c2, complex const& c3) { // Because thrust::complex does not provide a mutable ref double& rd0 = reinterpret_cast(d0)[0]; double& id0 = reinterpret_cast(d0)[1]; double& rd1 = reinterpret_cast(d1)[0]; double& id1 = reinterpret_cast(d1)[1]; double& rd2 = reinterpret_cast(d2)[0]; double& id2 = reinterpret_cast(d2)[1]; double& rd3 = reinterpret_cast(d3)[0]; double& id3 = reinterpret_cast(d3)[1]; // d.real() = a.real() * b.real() + c.real(); SM90_16x8x16_F64F64F64F64_TN::fma( rd0, rd1, rd2, rd3, a0.real(), a1.real(), a2.real(), a3.real(), a4.real(), a5.real(), a6.real(), a7.real(), b0.real(), b1.real(), b2.real(), b3.real(), c0.real(), c1.real(), c2.real(), c3.real()); // d.imag() = a.imag() * b.real() + c.imag(); SM90_16x8x16_F64F64F64F64_TN::fma( id0, id1, id2, id3, a0.imag(), a1.imag(), a2.imag(), a3.imag(), a4.imag(), a5.imag(), a6.imag(), a7.imag(), b0.real(), b1.real(), b2.real(), b3.real(), c0.imag(), c1.imag(), c2.imag(), c3.imag()); // d.real() = -a.imag() * b.imag() + d.real(); SM90_16x8x16_F64F64F64F64_TN::fma( rd0, rd1, rd2, rd3, -a0.imag(), -a1.imag(), -a2.imag(), -a3.imag(), -a4.imag(), -a5.imag(), -a6.imag(), -a7.imag(), b0.imag(), b1.imag(), b2.imag(), b3.imag(), d0.real(), d1.real(), d2.real(), d3.real()); // d.imag() = a.real() * b.imag() + d.imag(); SM90_16x8x16_F64F64F64F64_TN::fma( id0, id1, id2, id3, a0.real(), a1.real(), a2.real(), a3.real(), a4.real(), a5.real(), a6.real(), a7.real(), b0.imag(), b1.imag(), b2.imag(), b3.imag(), d0.imag(), d1.imag(), d2.imag(), d3.imag()); } }; //////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cute //////////////////////////////////////////////////////////////////////////////////////////////////// #include #include //////////////////////////////////////////////////////////////////////////////////////////////////// namespace cute { namespace GMMA { template < class ElementA, class ElementB, class ElementC, class TileShape_MNK, GMMA::Major MajorA = GMMA::Major::K, GMMA::Major MajorB = GMMA::Major::K, auto... Args // e.g. GMMA::ScaleOut::One, [GMMA::ScaleIn::One, GMMA::ScaleIn::One] // But most commonly leave empty for defaults > CUTE_HOST_DEVICE constexpr auto ss_op_selector() { static_assert(is_static::value, "TileShape_MNK must be static."); static_assert(rank(TileShape_MNK{}) == 3, "TileShape_MNK must be rank 3."); static_assert(size<0>(TileShape_MNK{}) % 64 == 0, "Tile_M must be a multiple of 64."); auto Tile_N = size<1>(TileShape_MNK{}); // FP16 accumulator if constexpr (is_same_v) { if constexpr (is_same_v && is_same_v) { static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); // Dispatch against the Tile N mode size if constexpr (Tile_N % 256 == 0) { return SM90_64x256x16_F16F16F16_SS{}; } else if constexpr (Tile_N % 192 == 0) { return SM90_64x192x16_F16F16F16_SS{}; } else if constexpr (Tile_N % 128 == 0) { return SM90_64x128x16_F16F16F16_SS{}; } else if constexpr (Tile_N % 96 == 0) { return SM90_64x96x16_F16F16F16_SS{}; } else if constexpr (Tile_N % 64 == 0) { return SM90_64x64x16_F16F16F16_SS{}; } else if constexpr (Tile_N % 32 == 0) { return SM90_64x32x16_F16F16F16_SS{}; } else if constexpr (Tile_N % 16 == 0) { return SM90_64x16x16_F16F16F16_SS{}; } else if constexpr (Tile_N % 8 == 0) { return SM90_64x8x16_F16F16F16_SS{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); } } // FP8 // Input A: float_e4m3_t ; Input B: float_e4m3_t else if constexpr (is_same_v && is_same_v) { static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { return SM90_64x256x32_F16E4M3E4M3_SS_TN{}; } else if constexpr (Tile_N % 192 == 0) { return SM90_64x192x32_F16E4M3E4M3_SS_TN{}; } else if constexpr (Tile_N % 128 == 0) { return SM90_64x128x32_F16E4M3E4M3_SS_TN{}; } else if constexpr (Tile_N % 96 == 0) { return SM90_64x96x32_F16E4M3E4M3_SS_TN{}; } else if constexpr (Tile_N % 64 == 0) { return SM90_64x64x32_F16E4M3E4M3_SS_TN{}; } else if constexpr (Tile_N % 32 == 0) { return SM90_64x32x32_F16E4M3E4M3_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { return SM90_64x16x32_F16E4M3E4M3_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { return SM90_64x8x32_F16E4M3E4M3_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); } } // FP8 // Input A: float_e4m3_t ; Input B: float_e5m2_t else if constexpr (is_same_v && is_same_v) { static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { return SM90_64x256x32_F16E4M3E5M2_SS_TN{}; } else if constexpr (Tile_N % 192 == 0) { return SM90_64x192x32_F16E4M3E5M2_SS_TN{}; } else if constexpr (Tile_N % 128 == 0) { return SM90_64x128x32_F16E4M3E5M2_SS_TN{}; } else if constexpr (Tile_N % 96 == 0) { return SM90_64x96x32_F16E4M3E5M2_SS_TN{}; } else if constexpr (Tile_N % 64 == 0) { return SM90_64x64x32_F16E4M3E5M2_SS_TN{}; } else if constexpr (Tile_N % 32 == 0) { return SM90_64x32x32_F16E4M3E5M2_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { return SM90_64x16x32_F16E4M3E5M2_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { return SM90_64x8x32_F16E4M3E5M2_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); } } // FP8 // Input A: float_e5m2_t ; Input B: float_e5m2_t else if constexpr (is_same_v && is_same_v) { static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { return SM90_64x256x32_F16E5M2E5M2_SS_TN{}; } else if constexpr (Tile_N % 192 == 0) { return SM90_64x192x32_F16E5M2E5M2_SS_TN{}; } else if constexpr (Tile_N % 128 == 0) { return SM90_64x128x32_F16E5M2E5M2_SS_TN{}; } else if constexpr (Tile_N % 96 == 0) { return SM90_64x96x32_F16E5M2E5M2_SS_TN{}; } else if constexpr (Tile_N % 64 == 0) { return SM90_64x64x32_F16E5M2E5M2_SS_TN{}; } else if constexpr (Tile_N % 32 == 0) { return SM90_64x32x32_F16E5M2E5M2_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { return SM90_64x16x32_F16E5M2E5M2_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { return SM90_64x8x32_F16E5M2E5M2_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); } } // FP8 // Input A: float_e5m2_t ; Input B: float_e4m3_t else if constexpr (is_same_v && is_same_v) { static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { return SM90_64x256x32_F16E5M2E4M3_SS_TN{}; } else if constexpr (Tile_N % 192 == 0) { return SM90_64x192x32_F16E5M2E4M3_SS_TN{}; } else if constexpr (Tile_N % 128 == 0) { return SM90_64x128x32_F16E5M2E4M3_SS_TN{}; } else if constexpr (Tile_N % 96 == 0) { return SM90_64x96x32_F16E5M2E4M3_SS_TN{}; } else if constexpr (Tile_N % 64 == 0) { return SM90_64x64x32_F16E5M2E4M3_SS_TN{}; } else if constexpr (Tile_N % 32 == 0) { return SM90_64x32x32_F16E5M2E4M3_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { return SM90_64x16x32_F16E5M2E4M3_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { return SM90_64x8x32_F16E5M2E4M3_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); } } else { static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); } } // FP32 accumulator else if constexpr (is_same_v) { // FP16 inputs if constexpr (is_same_v) { static_assert(is_same_v, "ElementA and ElementB must be the same type for this config."); static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); if constexpr (Tile_N % 256 == 0) { return SM90_64x256x16_F32F16F16_SS{}; } else if constexpr (Tile_N % 240 == 0) { return SM90_64x240x16_F32F16F16_SS{}; } else if constexpr (Tile_N % 224 == 0) { return SM90_64x224x16_F32F16F16_SS{}; } else if constexpr (Tile_N % 192 == 0) { return SM90_64x192x16_F32F16F16_SS{}; } else if constexpr (Tile_N % 176 == 0) { return SM90_64x176x16_F32F16F16_SS{}; } else if constexpr (Tile_N % 160 == 0) { return SM90_64x160x16_F32F16F16_SS{}; } else if constexpr (Tile_N % 144 == 0) { return SM90_64x144x16_F32F16F16_SS{}; } else if constexpr (Tile_N % 128 == 0) { return SM90_64x128x16_F32F16F16_SS{}; } else if constexpr (Tile_N % 112 == 0) { return SM90_64x112x16_F32F16F16_SS{}; } else if constexpr (Tile_N % 96 == 0) { return SM90_64x96x16_F32F16F16_SS{}; } else if constexpr (Tile_N % 80 == 0) { return SM90_64x80x16_F32F16F16_SS{}; } else if constexpr (Tile_N % 64 == 0) { return SM90_64x64x16_F32F16F16_SS{}; } else if constexpr (Tile_N % 48 == 0) { return SM90_64x48x16_F32F16F16_SS{}; } else if constexpr (Tile_N % 32 == 0) { return SM90_64x32x16_F32F16F16_SS{}; } else if constexpr (Tile_N % 16 == 0) { return SM90_64x16x16_F32F16F16_SS{}; } else if constexpr (Tile_N % 8 == 0) { return SM90_64x8x16_F32F16F16_SS{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); } } // BF16 inputs else if constexpr (is_same_v) { static_assert(is_same_v, "ElementA and ElementB must be the same type for this config."); static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); if constexpr (Tile_N % 256 == 0) { return SM90_64x256x16_F32BF16BF16_SS{}; } else if constexpr (Tile_N % 240 == 0) { return SM90_64x240x16_F32BF16BF16_SS{}; } else if constexpr (Tile_N % 224 == 0) { return SM90_64x224x16_F32BF16BF16_SS{}; } else if constexpr (Tile_N % 192 == 0) { return SM90_64x192x16_F32BF16BF16_SS{}; } else if constexpr (Tile_N % 176 == 0) { return SM90_64x176x16_F32BF16BF16_SS{}; } else if constexpr (Tile_N % 160 == 0) { return SM90_64x160x16_F32BF16BF16_SS{}; } else if constexpr (Tile_N % 144 == 0) { return SM90_64x144x16_F32BF16BF16_SS{}; } else if constexpr (Tile_N % 128 == 0) { return SM90_64x128x16_F32BF16BF16_SS{}; } else if constexpr (Tile_N % 112 == 0) { return SM90_64x112x16_F32BF16BF16_SS{}; } else if constexpr (Tile_N % 96 == 0) { return SM90_64x96x16_F32BF16BF16_SS{}; } else if constexpr (Tile_N % 80 == 0) { return SM90_64x80x16_F32BF16BF16_SS{}; } else if constexpr (Tile_N % 64 == 0) { return SM90_64x64x16_F32BF16BF16_SS{}; } else if constexpr (Tile_N % 48 == 0) { return SM90_64x48x16_F32BF16BF16_SS{}; } else if constexpr (Tile_N % 32 == 0) { return SM90_64x32x16_F32BF16BF16_SS{}; } else if constexpr (Tile_N % 16 == 0) { return SM90_64x16x16_F32BF16BF16_SS{}; } else if constexpr (Tile_N % 8 == 0) { return SM90_64x8x16_F32BF16BF16_SS{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); } } // TF32 inputs else if constexpr (is_same_v) { static_assert(is_same_v, "ElementA and ElementB must be the same type for this config."); static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); static_assert(size<2>(TileShape_MNK{}) % 8 == 0, "Tile_K must be a multiple of 8."); if constexpr (Tile_N % 256 == 0) { return SM90_64x256x8_F32TF32TF32_SS_TN{}; } else if constexpr (Tile_N % 192 == 0) { return SM90_64x192x8_F32TF32TF32_SS_TN{}; } else if constexpr (Tile_N % 128 == 0) { return SM90_64x128x8_F32TF32TF32_SS_TN{}; } else if constexpr (Tile_N % 96 == 0) { return SM90_64x96x8_F32TF32TF32_SS_TN{}; } else if constexpr (Tile_N % 64 == 0) { return SM90_64x64x8_F32TF32TF32_SS_TN{}; } else if constexpr (Tile_N % 32 == 0) { return SM90_64x32x8_F32TF32TF32_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { return SM90_64x16x8_F32TF32TF32_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { return SM90_64x8x8_F32TF32TF32_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); } } // FP8 // Input A: float_e4m3_t ; Input B: float_e4m3_t else if constexpr (is_same_v && is_same_v) { static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { return SM90_64x256x32_F32E4M3E4M3_SS_TN{}; } else if constexpr (Tile_N % 192 == 0) { return SM90_64x192x32_F32E4M3E4M3_SS_TN{}; } else if constexpr (Tile_N % 128 == 0) { return SM90_64x128x32_F32E4M3E4M3_SS_TN{}; } else if constexpr (Tile_N % 96 == 0) { return SM90_64x96x32_F32E4M3E4M3_SS_TN{}; } else if constexpr (Tile_N % 64 == 0) { return SM90_64x64x32_F32E4M3E4M3_SS_TN{}; } else if constexpr (Tile_N % 32 == 0) { return SM90_64x32x32_F32E4M3E4M3_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { return SM90_64x16x32_F32E4M3E4M3_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { return SM90_64x8x32_F32E4M3E4M3_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); } } // FP8 // Input A: float_e4m3_t ; Input B: float_e5m2_t else if constexpr (is_same_v && is_same_v) { static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { return SM90_64x256x32_F32E4M3E5M2_SS_TN{}; } else if constexpr (Tile_N % 192 == 0) { return SM90_64x192x32_F32E4M3E5M2_SS_TN{}; } else if constexpr (Tile_N % 128 == 0) { return SM90_64x128x32_F32E4M3E5M2_SS_TN{}; } else if constexpr (Tile_N % 96 == 0) { return SM90_64x96x32_F32E4M3E5M2_SS_TN{}; } else if constexpr (Tile_N % 64 == 0) { return SM90_64x64x32_F32E4M3E5M2_SS_TN{}; } else if constexpr (Tile_N % 32 == 0) { return SM90_64x32x32_F32E4M3E5M2_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { return SM90_64x16x32_F32E4M3E5M2_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { return SM90_64x8x32_F32E4M3E5M2_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); } } // FP8 // Input A: float_e5m2_t ; Input B: float_e5m2_t else if constexpr (is_same_v && is_same_v) { static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { return SM90_64x256x32_F32E5M2E5M2_SS_TN{}; } else if constexpr (Tile_N % 192 == 0) { return SM90_64x192x32_F32E5M2E5M2_SS_TN{}; } else if constexpr (Tile_N % 128 == 0) { return SM90_64x128x32_F32E5M2E5M2_SS_TN{}; } else if constexpr (Tile_N % 96 == 0) { return SM90_64x96x32_F32E5M2E5M2_SS_TN{}; } else if constexpr (Tile_N % 64 == 0) { return SM90_64x64x32_F32E5M2E5M2_SS_TN{}; } else if constexpr (Tile_N % 32 == 0) { return SM90_64x32x32_F32E5M2E5M2_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { return SM90_64x16x32_F32E5M2E5M2_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { return SM90_64x8x32_F32E5M2E5M2_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); } } // FP8 // Input A: float_e5m2_t ; Input B: float_e4m3_t else if constexpr (is_same_v && is_same_v) { static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { return SM90_64x256x32_F32E5M2E4M3_SS_TN{}; } else if constexpr (Tile_N % 192 == 0) { return SM90_64x192x32_F32E5M2E4M3_SS_TN{}; } else if constexpr (Tile_N % 128 == 0) { return SM90_64x128x32_F32E5M2E4M3_SS_TN{}; } else if constexpr (Tile_N % 96 == 0) { return SM90_64x96x32_F32E5M2E4M3_SS_TN{}; } else if constexpr (Tile_N % 64 == 0) { return SM90_64x64x32_F32E5M2E4M3_SS_TN{}; } else if constexpr (Tile_N % 32 == 0) { return SM90_64x32x32_F32E5M2E4M3_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { return SM90_64x16x32_F32E5M2E4M3_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { return SM90_64x8x32_F32E5M2E4M3_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); } } else { static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); } } // S32 accumulator else if constexpr (is_same_v) { static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); // ElementA == int8_t && ElementB == int8_t if constexpr (is_same_v && is_same_v) { if constexpr (Tile_N % 256 == 0) { return SM90_64x256x32_S32S8S8_SS_TN{}; } else if constexpr (Tile_N % 192 == 0) { return SM90_64x192x32_S32S8S8_SS_TN{}; } else if constexpr (Tile_N % 128 == 0) { return SM90_64x128x32_S32S8S8_SS_TN{}; } else if constexpr (Tile_N % 96 == 0) { return SM90_64x96x32_S32S8S8_SS_TN{}; } else if constexpr (Tile_N % 64 == 0) { return SM90_64x64x32_S32S8S8_SS_TN{}; } else if constexpr (Tile_N % 32 == 0) { return SM90_64x32x32_S32S8S8_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { return SM90_64x16x32_S32S8S8_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { return SM90_64x8x32_S32S8S8_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); } } // ElementA == int8_t && ElementB == uint8_t else if constexpr (is_same_v && is_same_v) { static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { return SM90_64x256x32_S32S8U8_SS_TN{}; } else if constexpr (Tile_N % 192 == 0) { return SM90_64x192x32_S32S8U8_SS_TN{}; } else if constexpr (Tile_N % 128 == 0) { return SM90_64x128x32_S32S8U8_SS_TN{}; } else if constexpr (Tile_N % 96 == 0) { return SM90_64x96x32_S32S8U8_SS_TN{}; } else if constexpr (Tile_N % 64 == 0) { return SM90_64x64x32_S32S8U8_SS_TN{}; } else if constexpr (Tile_N % 32 == 0) { return SM90_64x32x32_S32S8U8_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { return SM90_64x16x32_S32S8U8_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { return SM90_64x8x32_S32S8U8_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); } } // ElementA == uint8_t && ElementB == int8_t else if constexpr (is_same_v && is_same_v) { static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { return SM90_64x256x32_S32U8S8_SS_TN{}; } else if constexpr (Tile_N % 192 == 0) { return SM90_64x192x32_S32U8S8_SS_TN{}; } else if constexpr (Tile_N % 128 == 0) { return SM90_64x128x32_S32U8S8_SS_TN{}; } else if constexpr (Tile_N % 96 == 0) { return SM90_64x96x32_S32U8S8_SS_TN{}; } else if constexpr (Tile_N % 64 == 0) { return SM90_64x64x32_S32U8S8_SS_TN{}; } else if constexpr (Tile_N % 32 == 0) { return SM90_64x32x32_S32U8S8_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { return SM90_64x16x32_S32U8S8_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { return SM90_64x8x32_S32U8S8_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); } } // ElementA == uint8_t && ElementB == uint8_t else if constexpr (is_same_v && is_same_v) { static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { return SM90_64x256x32_S32U8U8_SS_TN{}; } else if constexpr (Tile_N % 192 == 0) { return SM90_64x192x32_S32U8U8_SS_TN{}; } else if constexpr (Tile_N % 128 == 0) { return SM90_64x128x32_S32U8U8_SS_TN{}; } else if constexpr (Tile_N % 96 == 0) { return SM90_64x96x32_S32U8U8_SS_TN{}; } else if constexpr (Tile_N % 64 == 0) { return SM90_64x64x32_S32U8U8_SS_TN{}; } else if constexpr (Tile_N % 32 == 0) { return SM90_64x32x32_S32U8U8_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { return SM90_64x16x32_S32U8U8_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { return SM90_64x8x32_S32U8U8_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); } } } // Unknown accumulator type else { static_assert(sizeof(ElementC) == 0, "Unknown ElementC accumulator type."); } } template < class ElementA, class ElementB, class ElementC, class TileShape_MNK, GMMA::Major MajorA = GMMA::Major::K, GMMA::Major MajorB = GMMA::Major::K, auto... Args // e.g. GMMA::ScaleOut::One, [GMMA::ScaleIn::One, GMMA::ScaleIn::One] // But most commonly leave empty for defaults > CUTE_HOST_DEVICE constexpr auto rs_op_selector() { static_assert(is_static::value, "TileShape_MNK must be static."); static_assert(rank(TileShape_MNK{}) == 3, "TileShape_MNK must be rank 3."); static_assert(size<0>(TileShape_MNK{}) % 64 == 0, "Tile_M must be a multiple of 64."); static_assert(MajorA == GMMA::Major::K, "Register source A operand GMMAs must have K-major A layout."); auto Tile_N = size<1>(TileShape_MNK{}); // FP16 accumulator if constexpr (is_same_v) { static_assert(is_same_v, "Element types for AB must be half if ElementC is half."); static_assert(is_same_v, "Element types for AB must be half if ElementC is half."); static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); // Dispatch against the Tile N mode size if constexpr (Tile_N % 256 == 0) { return SM90_64x256x16_F16F16F16_RS{}; } else if constexpr (Tile_N % 192 == 0) { return SM90_64x192x16_F16F16F16_RS{}; } else if constexpr (Tile_N % 128 == 0) { return SM90_64x128x16_F16F16F16_RS{}; } else if constexpr (Tile_N % 96 == 0) { return SM90_64x96x16_F16F16F16_RS{}; } else if constexpr (Tile_N % 64 == 0) { return SM90_64x64x16_F16F16F16_RS{}; } else if constexpr (Tile_N % 32 == 0) { return SM90_64x32x16_F16F16F16_RS{}; } else if constexpr (Tile_N % 16 == 0) { return SM90_64x16x16_F16F16F16_RS{}; } else if constexpr (Tile_N % 8 == 0) { return SM90_64x8x16_F16F16F16_RS{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); } } // FP32 accumulator else if constexpr (is_same_v) { // FP16 inputs if constexpr (is_same_v) { static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); static_assert(is_same_v, "ElementA and ElementB must be the same type for this config."); if constexpr (Tile_N % 256 == 0) { return SM90_64x256x16_F32F16F16_RS{}; } else if constexpr (Tile_N % 240 == 0) { return SM90_64x240x16_F32F16F16_RS{}; } else if constexpr (Tile_N % 224 == 0) { return SM90_64x224x16_F32F16F16_RS{}; } else if constexpr (Tile_N % 192 == 0) { return SM90_64x192x16_F32F16F16_RS{}; } else if constexpr (Tile_N % 176 == 0) { return SM90_64x176x16_F32F16F16_RS{}; } else if constexpr (Tile_N % 160 == 0) { return SM90_64x160x16_F32F16F16_RS{}; } else if constexpr (Tile_N % 144 == 0) { return SM90_64x144x16_F32F16F16_RS{}; } else if constexpr (Tile_N % 128 == 0) { return SM90_64x128x16_F32F16F16_RS{}; } else if constexpr (Tile_N % 112 == 0) { return SM90_64x112x16_F32F16F16_RS{}; } else if constexpr (Tile_N % 96 == 0) { return SM90_64x96x16_F32F16F16_RS{}; } else if constexpr (Tile_N % 80 == 0) { return SM90_64x80x16_F32F16F16_RS{}; } else if constexpr (Tile_N % 64 == 0) { return SM90_64x64x16_F32F16F16_RS{}; } else if constexpr (Tile_N % 48 == 0) { return SM90_64x48x16_F32F16F16_RS{}; } else if constexpr (Tile_N % 32 == 0) { return SM90_64x32x16_F32F16F16_RS{}; } else if constexpr (Tile_N % 16 == 0) { return SM90_64x16x16_F32F16F16_RS{}; } else if constexpr (Tile_N % 8 == 0) { return SM90_64x8x16_F32F16F16_RS{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); } } // BF16 inputs else if constexpr (is_same_v) { static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); static_assert(is_same_v, "ElementA and ElementB must be the same type for this config."); if constexpr (Tile_N % 256 == 0) { return SM90_64x256x16_F32BF16BF16_RS{}; } else if constexpr (Tile_N % 240 == 0) { return SM90_64x240x16_F32BF16BF16_RS{}; } else if constexpr (Tile_N % 224 == 0) { return SM90_64x224x16_F32BF16BF16_RS{}; } else if constexpr (Tile_N % 192 == 0) { return SM90_64x192x16_F32BF16BF16_RS{}; } else if constexpr (Tile_N % 176 == 0) { return SM90_64x176x16_F32BF16BF16_RS{}; } else if constexpr (Tile_N % 160 == 0) { return SM90_64x160x16_F32BF16BF16_RS{}; } else if constexpr (Tile_N % 144 == 0) { return SM90_64x144x16_F32BF16BF16_RS{}; } else if constexpr (Tile_N % 128 == 0) { return SM90_64x128x16_F32BF16BF16_RS{}; } else if constexpr (Tile_N % 112 == 0) { return SM90_64x112x16_F32BF16BF16_RS{}; } else if constexpr (Tile_N % 96 == 0) { return SM90_64x96x16_F32BF16BF16_RS{}; } else if constexpr (Tile_N % 80 == 0) { return SM90_64x80x16_F32BF16BF16_RS{}; } else if constexpr (Tile_N % 64 == 0) { return SM90_64x64x16_F32BF16BF16_RS{}; } else if constexpr (Tile_N % 48 == 0) { return SM90_64x48x16_F32BF16BF16_RS{}; } else if constexpr (Tile_N % 32 == 0) { return SM90_64x32x16_F32BF16BF16_RS{}; } else if constexpr (Tile_N % 16 == 0) { return SM90_64x16x16_F32BF16BF16_RS{}; } else if constexpr (Tile_N % 8 == 0) { return SM90_64x8x16_F32BF16BF16_RS{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); } } // TF32 inputs else if constexpr (is_same_v) { static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); static_assert(size<2>(TileShape_MNK{}) % 8 == 0, "Tile_K must be a multiple of 8."); static_assert(is_same_v, "ElementA and ElementB must be the same type for this config."); if constexpr (Tile_N % 256 == 0) { return SM90_64x256x8_F32TF32TF32_RS_TN{}; } else if constexpr (Tile_N % 192 == 0) { return SM90_64x192x8_F32TF32TF32_RS_TN{}; } else if constexpr (Tile_N % 128 == 0) { return SM90_64x128x8_F32TF32TF32_RS_TN{}; } else if constexpr (Tile_N % 96 == 0) { return SM90_64x96x8_F32TF32TF32_RS_TN{}; } else if constexpr (Tile_N % 64 == 0) { return SM90_64x64x8_F32TF32TF32_RS_TN{}; } else if constexpr (Tile_N % 32 == 0) { return SM90_64x32x8_F32TF32TF32_RS_TN{}; } else if constexpr (Tile_N % 16 == 0) { return SM90_64x16x8_F32TF32TF32_RS_TN{}; } else if constexpr (Tile_N % 8 == 0) { return SM90_64x8x8_F32TF32TF32_RS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); } } // FP8 // Input A: float_e4m3_t ; Input B: float_e4m3_t else if constexpr (is_same_v && is_same_v) { static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { return SM90_64x256x32_F32E4M3E4M3_RS_TN{}; } else if constexpr (Tile_N % 192 == 0) { return SM90_64x192x32_F32E4M3E4M3_RS_TN{}; } else if constexpr (Tile_N % 128 == 0) { return SM90_64x128x32_F32E4M3E4M3_RS_TN{}; } else if constexpr (Tile_N % 96 == 0) { return SM90_64x96x32_F32E4M3E4M3_RS_TN{}; } else if constexpr (Tile_N % 64 == 0) { return SM90_64x64x32_F32E4M3E4M3_RS_TN{}; } else if constexpr (Tile_N % 32 == 0) { return SM90_64x32x32_F32E4M3E4M3_RS_TN{}; } else if constexpr (Tile_N % 16 == 0) { return SM90_64x16x32_F32E4M3E4M3_RS_TN{}; } else if constexpr (Tile_N % 8 == 0) { return SM90_64x8x32_F32E4M3E4M3_RS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); } } // FP8 // Input A: float_e4m3_t ; Input B: float_e5m2_t else if constexpr (is_same_v && is_same_v) { static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { return SM90_64x256x32_F32E4M3E5M2_RS_TN{}; } else if constexpr (Tile_N % 192 == 0) { return SM90_64x192x32_F32E4M3E5M2_RS_TN{}; } else if constexpr (Tile_N % 128 == 0) { return SM90_64x128x32_F32E4M3E5M2_RS_TN{}; } else if constexpr (Tile_N % 96 == 0) { return SM90_64x96x32_F32E4M3E5M2_RS_TN{}; } else if constexpr (Tile_N % 64 == 0) { return SM90_64x64x32_F32E4M3E5M2_RS_TN{}; } else if constexpr (Tile_N % 32 == 0) { return SM90_64x32x32_F32E4M3E5M2_RS_TN{}; } else if constexpr (Tile_N % 16 == 0) { return SM90_64x16x32_F32E4M3E5M2_RS_TN{}; } else if constexpr (Tile_N % 8 == 0) { return SM90_64x8x32_F32E4M3E5M2_RS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); } } // FP8 // Input A: float_e5m2_t ; Input B: float_e5m2_t else if constexpr (is_same_v && is_same_v) { static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { return SM90_64x256x32_F32E5M2E5M2_RS_TN{}; } else if constexpr (Tile_N % 192 == 0) { return SM90_64x192x32_F32E5M2E5M2_RS_TN{}; } else if constexpr (Tile_N % 128 == 0) { return SM90_64x128x32_F32E5M2E5M2_RS_TN{}; } else if constexpr (Tile_N % 96 == 0) { return SM90_64x96x32_F32E5M2E5M2_RS_TN{}; } else if constexpr (Tile_N % 64 == 0) { return SM90_64x64x32_F32E5M2E5M2_RS_TN{}; } else if constexpr (Tile_N % 32 == 0) { return SM90_64x32x32_F32E5M2E5M2_RS_TN{}; } else if constexpr (Tile_N % 16 == 0) { return SM90_64x16x32_F32E5M2E5M2_RS_TN{}; } else if constexpr (Tile_N % 8 == 0) { return SM90_64x8x32_F32E5M2E5M2_RS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); } } // FP8 // Input A: float_e5m2_t ; Input B: float_e4m3_t else if constexpr (is_same_v && is_same_v) { static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { return SM90_64x256x32_F32E5M2E4M3_RS_TN{}; } else if constexpr (Tile_N % 192 == 0) { return SM90_64x192x32_F32E5M2E4M3_RS_TN{}; } else if constexpr (Tile_N % 128 == 0) { return SM90_64x128x32_F32E5M2E4M3_RS_TN{}; } else if constexpr (Tile_N % 96 == 0) { return SM90_64x96x32_F32E5M2E4M3_RS_TN{}; } else if constexpr (Tile_N % 64 == 0) { return SM90_64x64x32_F32E5M2E4M3_RS_TN{}; } else if constexpr (Tile_N % 32 == 0) { return SM90_64x32x32_F32E5M2E4M3_RS_TN{}; } else if constexpr (Tile_N % 16 == 0) { return SM90_64x16x32_F32E5M2E4M3_RS_TN{}; } else if constexpr (Tile_N % 8 == 0) { return SM90_64x8x32_F32E5M2E4M3_RS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); } } else { static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); } } // S32 accumulator else if constexpr (is_same_v) { static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); // ElementA == int8_t && ElementB == int8_t if constexpr (is_same_v && is_same_v) { if constexpr (Tile_N % 256 == 0) { return SM90_64x256x32_S32S8S8_RS_TN{}; } else if constexpr (Tile_N % 192 == 0) { return SM90_64x192x32_S32S8S8_RS_TN{}; } else if constexpr (Tile_N % 128 == 0) { return SM90_64x128x32_S32S8S8_RS_TN{}; } else if constexpr (Tile_N % 96 == 0) { return SM90_64x96x32_S32S8S8_RS_TN{}; } else if constexpr (Tile_N % 64 == 0) { return SM90_64x64x32_S32S8S8_RS_TN{}; } else if constexpr (Tile_N % 32 == 0) { return SM90_64x32x32_S32S8S8_RS_TN{}; } else if constexpr (Tile_N % 16 == 0) { return SM90_64x16x32_S32S8S8_RS_TN{}; } else if constexpr (Tile_N % 8 == 0) { return SM90_64x8x32_S32S8S8_RS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); } } // ElementA == int8_t && ElementB == uint8_t else if constexpr (is_same_v && is_same_v) { static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { return SM90_64x256x32_S32S8U8_RS_TN{}; } else if constexpr (Tile_N % 192 == 0) { return SM90_64x192x32_S32S8U8_RS_TN{}; } else if constexpr (Tile_N % 128 == 0) { return SM90_64x128x32_S32S8U8_RS_TN{}; } else if constexpr (Tile_N % 96 == 0) { return SM90_64x96x32_S32S8U8_RS_TN{}; } else if constexpr (Tile_N % 64 == 0) { return SM90_64x64x32_S32S8U8_RS_TN{}; } else if constexpr (Tile_N % 32 == 0) { return SM90_64x32x32_S32S8U8_RS_TN{}; } else if constexpr (Tile_N % 16 == 0) { return SM90_64x16x32_S32S8U8_RS_TN{}; } else if constexpr (Tile_N % 8 == 0) { return SM90_64x8x32_S32S8U8_RS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); } } // ElementA == uint8_t && ElementB == int8_t else if constexpr (is_same_v && is_same_v) { static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { return SM90_64x256x32_S32U8S8_RS_TN{}; } else if constexpr (Tile_N % 192 == 0) { return SM90_64x192x32_S32U8S8_RS_TN{}; } else if constexpr (Tile_N % 128 == 0) { return SM90_64x128x32_S32U8S8_RS_TN{}; } else if constexpr (Tile_N % 96 == 0) { return SM90_64x96x32_S32U8S8_RS_TN{}; } else if constexpr (Tile_N % 64 == 0) { return SM90_64x64x32_S32U8S8_RS_TN{}; } else if constexpr (Tile_N % 32 == 0) { return SM90_64x32x32_S32U8S8_RS_TN{}; } else if constexpr (Tile_N % 16 == 0) { return SM90_64x16x32_S32U8S8_RS_TN{}; } else if constexpr (Tile_N % 8 == 0) { return SM90_64x8x32_S32U8S8_RS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); } } // ElementA == uint8_t && ElementB == uint8_t else if constexpr (is_same_v && is_same_v) { static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { return SM90_64x256x32_S32U8U8_RS_TN{}; } else if constexpr (Tile_N % 192 == 0) { return SM90_64x192x32_S32U8U8_RS_TN{}; } else if constexpr (Tile_N % 128 == 0) { return SM90_64x128x32_S32U8U8_RS_TN{}; } else if constexpr (Tile_N % 96 == 0) { return SM90_64x96x32_S32U8U8_RS_TN{}; } else if constexpr (Tile_N % 64 == 0) { return SM90_64x64x32_S32U8U8_RS_TN{}; } else if constexpr (Tile_N % 32 == 0) { return SM90_64x32x32_S32U8U8_RS_TN{}; } else if constexpr (Tile_N % 16 == 0) { return SM90_64x16x32_S32U8U8_RS_TN{}; } else if constexpr (Tile_N % 8 == 0) { return SM90_64x8x32_S32U8U8_RS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); } } } // Unknown accumulator type else { static_assert(sizeof(ElementC) == 0, "Unknown ElementC accumulator type."); } } } // end namespace GMMA } // end namespace cute ////////////////////////////////////////////////////////////////////////////////////////////////////