Skip void-C kernels in the profiler when beta is non zero (#1661)
* Skip void-C kernels in the profiler when beta is non zero CUTLASS profiler will only skip disposition for void-C kernels when beta is non zero, when it makes more sense to skip running it in the first place. Not all users are aware of void-C kernels (as far as I know it wasn't a thing in 2.X), and not everyone remembers to filter out voidC kernels when running the profiler with a non zero beta. The easiest solution (and as far as I can tell correct way of handling this) is that `can_implement` return `false` when beta is non zero (or whatever argument indicates an epilogue source) but we have a void-C kernel. Profiler already includes functionality to skip running kernels that fail `can_implement`. * Move checks to collectives instead --------- Co-authored-by: Ali Hassani <ahassani@nvidia.com>
This commit is contained in:
parent
8b2a0408bd
commit
1f2b590da6
@ -392,6 +392,27 @@ public:
|
||||
tensormaps_fence_acquire([[maybe_unused]] cute::TmaDescriptor const* tensormap) { }
|
||||
};
|
||||
|
||||
// SFINAE helpers for detecting beta/beta_ptr in EVT arguments.
|
||||
template <class Arguments, class = void>
|
||||
struct has_beta {
|
||||
static constexpr bool value = false;
|
||||
};
|
||||
|
||||
template <class Arguments>
|
||||
struct has_beta<Arguments, cute::void_t<decltype(Arguments{}.thread.beta)>> {
|
||||
static constexpr bool value = true;
|
||||
};
|
||||
|
||||
template <class Arguments, class = void>
|
||||
struct has_beta_ptr {
|
||||
static constexpr bool value = false;
|
||||
};
|
||||
|
||||
template <class Arguments>
|
||||
struct has_beta_ptr<Arguments, cute::void_t<decltype(Arguments{}.thread.beta_ptr)>> {
|
||||
static constexpr bool value = true;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
} // namespace collective
|
||||
} // namespace epilogue
|
||||
|
@ -369,7 +369,23 @@ public:
|
||||
if (!fusion_implementable) {
|
||||
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements for FusionCallbacks.\n");
|
||||
}
|
||||
return implementable && fusion_implementable;
|
||||
|
||||
bool beta_implementable = true;
|
||||
|
||||
if constexpr (cute::is_void_v<ElementC>) {
|
||||
if constexpr (detail::has_beta<Arguments>::value) {
|
||||
beta_implementable = args.thread.beta == 0.0;
|
||||
}
|
||||
if constexpr (detail::has_beta_ptr<Arguments>::value) {
|
||||
beta_implementable = beta_implementable && args.thread.beta_ptr == nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
if (!beta_implementable) {
|
||||
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Beta/beta pointer was set, but epilogue is sourceless (void-C).\n");
|
||||
}
|
||||
|
||||
return implementable && fusion_implementable && beta_implementable;
|
||||
}
|
||||
|
||||
template<class TileShapeMNK>
|
||||
|
@ -339,7 +339,23 @@ public:
|
||||
if (!fusion_implementable) {
|
||||
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements for FusionCallbacks.\n");
|
||||
}
|
||||
return implementable && fusion_implementable;
|
||||
|
||||
bool beta_implementable = true;
|
||||
|
||||
if constexpr (cute::is_void_v<ElementC>) {
|
||||
if constexpr (detail::has_beta<Arguments>::value) {
|
||||
beta_implementable = args.thread.beta == 0.0;
|
||||
}
|
||||
if constexpr (detail::has_beta_ptr<Arguments>::value) {
|
||||
beta_implementable = beta_implementable && args.thread.beta_ptr == nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
if (!beta_implementable) {
|
||||
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Beta/beta pointer was set, but epilogue is sourceless (void-C).\n");
|
||||
}
|
||||
|
||||
return implementable && fusion_implementable && beta_implementable;
|
||||
}
|
||||
|
||||
template<class TileShapeMNK>
|
||||
|
Loading…
Reference in New Issue
Block a user