diff --git a/include/cutlass/epilogue/collective/detail.hpp b/include/cutlass/epilogue/collective/detail.hpp index a0178144..a6e5e2f4 100644 --- a/include/cutlass/epilogue/collective/detail.hpp +++ b/include/cutlass/epilogue/collective/detail.hpp @@ -392,6 +392,27 @@ public: tensormaps_fence_acquire([[maybe_unused]] cute::TmaDescriptor const* tensormap) { } }; +// SFINAE helpers for detecting beta/beta_ptr in EVT arguments. +template +struct has_beta { + static constexpr bool value = false; +}; + +template +struct has_beta> { + static constexpr bool value = true; +}; + +template +struct has_beta_ptr { + static constexpr bool value = false; +}; + +template +struct has_beta_ptr> { + static constexpr bool value = true; +}; + } // namespace detail } // namespace collective } // namespace epilogue diff --git a/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp b/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp index 981ea3e2..87e62887 100644 --- a/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp @@ -369,7 +369,23 @@ public: if (!fusion_implementable) { CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements for FusionCallbacks.\n"); } - return implementable && fusion_implementable; + + bool beta_implementable = true; + + if constexpr (cute::is_void_v) { + if constexpr (detail::has_beta::value) { + beta_implementable = args.thread.beta == 0.0; + } + if constexpr (detail::has_beta_ptr::value) { + beta_implementable = beta_implementable && args.thread.beta_ptr == nullptr; + } + } + + if (!beta_implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Beta/beta pointer was set, but epilogue is sourceless (void-C).\n"); + } + + return implementable && fusion_implementable && beta_implementable; } template diff --git a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp index c03aed33..56b55292 100644 --- a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp @@ -339,7 +339,23 @@ public: if (!fusion_implementable) { CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements for FusionCallbacks.\n"); } - return implementable && fusion_implementable; + + bool beta_implementable = true; + + if constexpr (cute::is_void_v) { + if constexpr (detail::has_beta::value) { + beta_implementable = args.thread.beta == 0.0; + } + if constexpr (detail::has_beta_ptr::value) { + beta_implementable = beta_implementable && args.thread.beta_ptr == nullptr; + } + } + + if (!beta_implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Beta/beta pointer was set, but epilogue is sourceless (void-C).\n"); + } + + return implementable && fusion_implementable && beta_implementable; } template