1x1x1 cluster launch (#1673)
This commit is contained in:
parent
eee0cab26c
commit
06b21349bc
@ -296,7 +296,9 @@ public:
|
||||
|
||||
Status launch_result;
|
||||
// Use extended launch API only for mainloops that use it
|
||||
if constexpr(ConvKernel::ArchTag::kMinComputeCapability >= 90) {
|
||||
if constexpr (ConvKernel::ArchTag::kMinComputeCapability >= 90) {
|
||||
constexpr bool is_static_1x1x1 = cute::is_static_v<typename ConvKernel::DispatchPolicy::ClusterShape> and
|
||||
cute::size(typename ConvKernel::DispatchPolicy::ClusterShape{}) == 1;
|
||||
dim3 cluster(cute::size<0>(typename ConvKernel::DispatchPolicy::ClusterShape{}),
|
||||
cute::size<1>(typename ConvKernel::DispatchPolicy::ClusterShape{}),
|
||||
cute::size<2>(typename ConvKernel::DispatchPolicy::ClusterShape{}));
|
||||
@ -324,8 +326,14 @@ public:
|
||||
CUTLASS_ASSERT(cuda_adapter == nullptr);
|
||||
void const* kernel = (void const*) device_kernel<ConvKernel>;
|
||||
if constexpr (ConvKernel::ArchTag::kMinComputeCapability == 90) {
|
||||
launch_result = ClusterLauncher::launch(
|
||||
grid, cluster, block, smem_size, stream, kernel, kernel_params);
|
||||
if constexpr (is_static_1x1x1) {
|
||||
device_kernel<ConvKernel><<<grid, block, smem_size, stream>>>(params);
|
||||
launch_result = Status::kSuccess;
|
||||
}
|
||||
else {
|
||||
launch_result = ClusterLauncher::launch(
|
||||
grid, cluster, block, smem_size, stream, kernel, kernel_params);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -350,6 +350,8 @@ public:
|
||||
Status launch_result{ Status::kSuccess };
|
||||
// Use extended launch API only for mainloops that use it
|
||||
if constexpr (GemmKernel::ArchTag::kMinComputeCapability >= 90) {
|
||||
constexpr bool is_static_1x1x1 = cute::is_static_v<typename GemmKernel::DispatchPolicy::ClusterShape> and
|
||||
cute::size(typename GemmKernel::DispatchPolicy::ClusterShape{}) == 1;
|
||||
dim3 cluster(cute::size<0>(typename GemmKernel::DispatchPolicy::ClusterShape{}),
|
||||
cute::size<1>(typename GemmKernel::DispatchPolicy::ClusterShape{}),
|
||||
cute::size<2>(typename GemmKernel::DispatchPolicy::ClusterShape{}));
|
||||
@ -383,8 +385,13 @@ public:
|
||||
CUTLASS_ASSERT(cuda_adapter == nullptr);
|
||||
void const* kernel = (void const*) device_kernel<GemmKernel>;
|
||||
if constexpr (GemmKernel::ArchTag::kMinComputeCapability == 90) {
|
||||
launch_result = ClusterLauncher::launch(
|
||||
grid, cluster, block, smem_size, stream, kernel, kernel_params, launch_with_pdl);
|
||||
if (is_static_1x1x1 && not launch_with_pdl) {
|
||||
device_kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params);
|
||||
}
|
||||
else {
|
||||
launch_result = ClusterLauncher::launch(
|
||||
grid, cluster, block, smem_size, stream, kernel, kernel_params, launch_with_pdl);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user