/***************************************************************************************************
 * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#pragma once

#include "cutlass/gemm/collective/collective_builder.hpp"

#include "dispatch_policy_extra.hpp"
#include "sm90_mma_tma_gmma_ss_warpspecialized_with_prefetch.hpp"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass::gemm::collective {

/////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA_TMA_WS_FP8_FAST_ACCUM_SS + prefetch
template <
  class ElementA,
  class GmemLayoutATag,
  int AlignmentA,
  class ElementB,
  class GmemLayoutBTag,
  int AlignmentB,
  class ElementAccumulator,
  class TileShape_MNK,
  class ClusterShape_MNK,
  class StageCountType,
  class KernelScheduleType
>
struct CollectiveBuilder<
    arch::Sm90,
    arch::OpClassTensorOp,
    ElementA,
    GmemLayoutATag,
    AlignmentA,
    ElementB,
    GmemLayoutBTag,
    AlignmentB,
    ElementAccumulator,
    TileShape_MNK,
    ClusterShape_MNK,
    StageCountType,
    KernelScheduleType,
    cute::enable_if_t<
      cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedFP8FastAccumWithPrefetch>>
> {
  static_assert(is_static<TileShape_MNK>::value);
  static_assert(is_static<ClusterShape_MNK>::value);
  static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
                "Should meet TMA alignment requirement\n");
  static_assert(detail::is_input_fp8<ElementA, ElementB>(),
                "Only FP8 datatypes are compatible with these kernel schedules\n");
  // Dispatch TN fp8 kernels only to TMA warp specialized FP8 builder
  static_assert(!detail::is_use_rmem_A<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>(),
                "Not supported for fp8 non-TN warp specialized kernels yet\n");
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
  static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
#endif

  static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<ElementA, GmemLayoutATag>();
  static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<ElementB, GmemLayoutBTag>();

  using AtomLayoutMNK = Layout<Shape<_1,_1,_1>>;

  using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector<
      ElementA, ElementB, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{}));

  // Multicast A along the cluster's N extent and B along its M extent when those extents exceed 1.
  using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
  using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));

  using SmemLayoutAtomA = decltype(detail::ss_smem_selector<
      GmmaMajorA, ElementA, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
  using SmemLayoutAtomB = decltype(detail::ss_smem_selector<
      GmmaMajorB, ElementB, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());

  static constexpr int PipelineStages = detail::compute_stage_count_or_override<
      detail::sm90_smem_capacity_bytes, ElementA, ElementB, TileShape_MNK>(StageCountType{});
  using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedWithPrefetch<
      PipelineStages, ClusterShape_MNK, KernelScheduleType>;

  using SmemCopyAtomA = void;
  using SmemCopyAtomB = void;

  using CollectiveOp = CollectiveMma<
      DispatchPolicy,
      TileShape_MNK,
      ElementA,
      TagToStrideA_t<GmemLayoutATag>,
      ElementB,
      TagToStrideB_t<GmemLayoutBTag>,
      TiledMma,
      GmemTiledCopyA,
      SmemLayoutAtomA,
      SmemCopyAtomA,
      cute::identity,
      GmemTiledCopyB,
      SmemLayoutAtomB,
      SmemCopyAtomB,
      cute::identity
  >;
};
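/////////////////////////////////////////////////////////////////////////////////////////////////

// Usage sketch (illustrative only, hence the `#if 0` guard): requesting the prefetching FP8
// mainloop through the builder above. The element types, alignments, and tile/cluster shapes are
// assumptions chosen to satisfy the static_asserts (FP8 inputs, 16-element i.e. 128-bit TMA
// alignment, TN layouts); substitute your own. The schedule tag is assumed to be declared by
// dispatch_policy_extra.hpp; qualify it to match wherever that header actually defines it.
#if 0
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
    cutlass::float_e4m3_t, cutlass::layout::RowMajor,    16,  // A: FP8, K-major ("T")
    cutlass::float_e4m3_t, cutlass::layout::ColumnMajor, 16,  // B: FP8, K-major ("N")
    float,                                                    // accumulator
    cute::Shape<cute::_128, cute::_128, cute::_128>,          // CTA tile (M, N, K)
    cute::Shape<cute::_2, cute::_1, cute::_1>,                // cluster shape
    cutlass::gemm::collective::StageCountAuto,                // let the builder size the pipeline
    cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccumWithPrefetch
  >::CollectiveOp;
#endif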
/////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA_TMA_WS_FP8_FAST_ACCUM_SS + prefetch and split DMA warps
template <
  class ElementA,
  class GmemLayoutATag,
  int AlignmentA,
  class ElementB,
  class GmemLayoutBTag,
  int AlignmentB,
  class ElementAccumulator,
  class TileShape_MNK,
  class ClusterShape_MNK,
  class StageCountType,
  class KernelScheduleType
>
struct CollectiveBuilder<
    arch::Sm90,
    arch::OpClassTensorOp,
    ElementA,
    GmemLayoutATag,
    AlignmentA,
    ElementB,
    GmemLayoutBTag,
    AlignmentB,
    ElementAccumulator,
    TileShape_MNK,
    ClusterShape_MNK,
    StageCountType,
    KernelScheduleType,
    cute::enable_if_t<
      cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedFP8FastAccumWithPrefetchAndSplitDMA>>
> {
  static_assert(is_static<TileShape_MNK>::value);
  static_assert(is_static<ClusterShape_MNK>::value);
  static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
                "Should meet TMA alignment requirement\n");
  static_assert(detail::is_input_fp8<ElementA, ElementB>(),
                "Only FP8 datatypes are compatible with these kernel schedules\n");
  // Dispatch TN fp8 kernels only to TMA warp specialized FP8 builder
  static_assert(!detail::is_use_rmem_A<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>(),
                "Not supported for fp8 non-TN warp specialized kernels yet\n");
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
  static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
#endif

  static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<ElementA, GmemLayoutATag>();
  static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<ElementB, GmemLayoutBTag>();

  using AtomLayoutMNK = Layout<Shape<_1,_1,_1>>;

  using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector<
      ElementA, ElementB, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{}));

  using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
  using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));

  using SmemLayoutAtomA = decltype(detail::ss_smem_selector<
      GmmaMajorA, ElementA, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
  using SmemLayoutAtomB = decltype(detail::ss_smem_selector<
      GmmaMajorB, ElementB, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());

  static constexpr int PipelineStages = detail::compute_stage_count_or_override<
      detail::sm90_smem_capacity_bytes, ElementA, ElementB, TileShape_MNK>(StageCountType{});
  using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedWithPrefetch<
      PipelineStages, ClusterShape_MNK, KernelScheduleType>;

  using SmemCopyAtomA = void;
  using SmemCopyAtomB = void;

  using CollectiveOp = CollectiveMma<
      DispatchPolicy,
      TileShape_MNK,
      ElementA,
      TagToStrideA_t<GmemLayoutATag>,
      ElementB,
      TagToStrideB_t<GmemLayoutBTag>,
      TiledMma,
      GmemTiledCopyA,
      SmemLayoutAtomA,
      SmemCopyAtomA,
      cute::identity,
      GmemTiledCopyB,
      SmemLayoutAtomB,
      SmemCopyAtomB,
      cute::identity
  >;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////
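// The split-DMA builder is requested identically; a minimal sketch (illustrative only), reusing
// the assumptions from the earlier example and changing only the final schedule-tag argument:
#if 0
using CollectiveMainloopSplitDMA = typename cutlass::gemm::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
    cutlass::float_e4m3_t, cutlass::layout::RowMajor,    16,
    cutlass::float_e4m3_t, cutlass::layout::ColumnMajor, 16,
    float,
    cute::Shape<cute::_128, cute::_128, cute::_128>,
    cute::Shape<cute::_2, cute::_1, cute::_1>,
    cutlass::gemm::collective::StageCountAuto,
    cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccumWithPrefetchAndSplitDMA
  >::CollectiveOp;
#endif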