1x1x1 cluster launch (#1673)

This commit is contained in:
dePaul Miller 2024-08-01 09:20:28 -07:00 committed by GitHub
parent eee0cab26c
commit 06b21349bc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 20 additions and 5 deletions

View File

@ -296,7 +296,9 @@ public:
Status launch_result;
// Use extended launch API only for mainloops that use it
if constexpr(ConvKernel::ArchTag::kMinComputeCapability >= 90) {
if constexpr (ConvKernel::ArchTag::kMinComputeCapability >= 90) {
constexpr bool is_static_1x1x1 = cute::is_static_v<typename ConvKernel::DispatchPolicy::ClusterShape> and
cute::size(typename ConvKernel::DispatchPolicy::ClusterShape{}) == 1;
dim3 cluster(cute::size<0>(typename ConvKernel::DispatchPolicy::ClusterShape{}),
cute::size<1>(typename ConvKernel::DispatchPolicy::ClusterShape{}),
cute::size<2>(typename ConvKernel::DispatchPolicy::ClusterShape{}));
@ -324,8 +326,14 @@ public:
CUTLASS_ASSERT(cuda_adapter == nullptr);
void const* kernel = (void const*) device_kernel<ConvKernel>;
if constexpr (ConvKernel::ArchTag::kMinComputeCapability == 90) {
launch_result = ClusterLauncher::launch(
grid, cluster, block, smem_size, stream, kernel, kernel_params);
if constexpr (is_static_1x1x1) {
device_kernel<ConvKernel><<<grid, block, smem_size, stream>>>(params);
launch_result = Status::kSuccess;
}
else {
launch_result = ClusterLauncher::launch(
grid, cluster, block, smem_size, stream, kernel, kernel_params);
}
}
}
}

View File

@ -350,6 +350,8 @@ public:
Status launch_result{ Status::kSuccess };
// Use extended launch API only for mainloops that use it
if constexpr (GemmKernel::ArchTag::kMinComputeCapability >= 90) {
constexpr bool is_static_1x1x1 = cute::is_static_v<typename GemmKernel::DispatchPolicy::ClusterShape> and
cute::size(typename GemmKernel::DispatchPolicy::ClusterShape{}) == 1;
dim3 cluster(cute::size<0>(typename GemmKernel::DispatchPolicy::ClusterShape{}),
cute::size<1>(typename GemmKernel::DispatchPolicy::ClusterShape{}),
cute::size<2>(typename GemmKernel::DispatchPolicy::ClusterShape{}));
@ -383,8 +385,13 @@ public:
CUTLASS_ASSERT(cuda_adapter == nullptr);
void const* kernel = (void const*) device_kernel<GemmKernel>;
if constexpr (GemmKernel::ArchTag::kMinComputeCapability == 90) {
launch_result = ClusterLauncher::launch(
grid, cluster, block, smem_size, stream, kernel, kernel_params, launch_with_pdl);
if (is_static_1x1x1 && not launch_with_pdl) {
device_kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params);
}
else {
launch_result = ClusterLauncher::launch(
grid, cluster, block, smem_size, stream, kernel, kernel_params, launch_with_pdl);
}
}
}
}