1x1x1 cluster launch (#1673)
This commit is contained in:
parent
eee0cab26c
commit
06b21349bc
@ -297,6 +297,8 @@ public:
|
|||||||
Status launch_result;
|
Status launch_result;
|
||||||
// Use extended launch API only for mainloops that use it
|
// Use extended launch API only for mainloops that use it
|
||||||
if constexpr (ConvKernel::ArchTag::kMinComputeCapability >= 90) {
|
if constexpr (ConvKernel::ArchTag::kMinComputeCapability >= 90) {
|
||||||
|
constexpr bool is_static_1x1x1 = cute::is_static_v<typename ConvKernel::DispatchPolicy::ClusterShape> and
|
||||||
|
cute::size(typename ConvKernel::DispatchPolicy::ClusterShape{}) == 1;
|
||||||
dim3 cluster(cute::size<0>(typename ConvKernel::DispatchPolicy::ClusterShape{}),
|
dim3 cluster(cute::size<0>(typename ConvKernel::DispatchPolicy::ClusterShape{}),
|
||||||
cute::size<1>(typename ConvKernel::DispatchPolicy::ClusterShape{}),
|
cute::size<1>(typename ConvKernel::DispatchPolicy::ClusterShape{}),
|
||||||
cute::size<2>(typename ConvKernel::DispatchPolicy::ClusterShape{}));
|
cute::size<2>(typename ConvKernel::DispatchPolicy::ClusterShape{}));
|
||||||
@ -324,11 +326,17 @@ public:
|
|||||||
CUTLASS_ASSERT(cuda_adapter == nullptr);
|
CUTLASS_ASSERT(cuda_adapter == nullptr);
|
||||||
void const* kernel = (void const*) device_kernel<ConvKernel>;
|
void const* kernel = (void const*) device_kernel<ConvKernel>;
|
||||||
if constexpr (ConvKernel::ArchTag::kMinComputeCapability == 90) {
|
if constexpr (ConvKernel::ArchTag::kMinComputeCapability == 90) {
|
||||||
|
if constexpr (is_static_1x1x1) {
|
||||||
|
device_kernel<ConvKernel><<<grid, block, smem_size, stream>>>(params);
|
||||||
|
launch_result = Status::kSuccess;
|
||||||
|
}
|
||||||
|
else {
|
||||||
launch_result = ClusterLauncher::launch(
|
launch_result = ClusterLauncher::launch(
|
||||||
grid, cluster, block, smem_size, stream, kernel, kernel_params);
|
grid, cluster, block, smem_size, stream, kernel, kernel_params);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
else {
|
else {
|
||||||
launch_result = Status::kSuccess;
|
launch_result = Status::kSuccess;
|
||||||
|
|
||||||
|
@ -350,6 +350,8 @@ public:
|
|||||||
Status launch_result{ Status::kSuccess };
|
Status launch_result{ Status::kSuccess };
|
||||||
// Use extended launch API only for mainloops that use it
|
// Use extended launch API only for mainloops that use it
|
||||||
if constexpr (GemmKernel::ArchTag::kMinComputeCapability >= 90) {
|
if constexpr (GemmKernel::ArchTag::kMinComputeCapability >= 90) {
|
||||||
|
constexpr bool is_static_1x1x1 = cute::is_static_v<typename GemmKernel::DispatchPolicy::ClusterShape> and
|
||||||
|
cute::size(typename GemmKernel::DispatchPolicy::ClusterShape{}) == 1;
|
||||||
dim3 cluster(cute::size<0>(typename GemmKernel::DispatchPolicy::ClusterShape{}),
|
dim3 cluster(cute::size<0>(typename GemmKernel::DispatchPolicy::ClusterShape{}),
|
||||||
cute::size<1>(typename GemmKernel::DispatchPolicy::ClusterShape{}),
|
cute::size<1>(typename GemmKernel::DispatchPolicy::ClusterShape{}),
|
||||||
cute::size<2>(typename GemmKernel::DispatchPolicy::ClusterShape{}));
|
cute::size<2>(typename GemmKernel::DispatchPolicy::ClusterShape{}));
|
||||||
@ -383,11 +385,16 @@ public:
|
|||||||
CUTLASS_ASSERT(cuda_adapter == nullptr);
|
CUTLASS_ASSERT(cuda_adapter == nullptr);
|
||||||
void const* kernel = (void const*) device_kernel<GemmKernel>;
|
void const* kernel = (void const*) device_kernel<GemmKernel>;
|
||||||
if constexpr (GemmKernel::ArchTag::kMinComputeCapability == 90) {
|
if constexpr (GemmKernel::ArchTag::kMinComputeCapability == 90) {
|
||||||
|
if (is_static_1x1x1 && not launch_with_pdl) {
|
||||||
|
device_kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params);
|
||||||
|
}
|
||||||
|
else {
|
||||||
launch_result = ClusterLauncher::launch(
|
launch_result = ClusterLauncher::launch(
|
||||||
grid, cluster, block, smem_size, stream, kernel, kernel_params, launch_with_pdl);
|
grid, cluster, block, smem_size, stream, kernel, kernel_params, launch_with_pdl);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
else {
|
else {
|
||||||
launch_result = Status::kSuccess;
|
launch_result = Status::kSuccess;
|
||||||
if constexpr (kEnableCudaHostAdapter) {
|
if constexpr (kEnableCudaHostAdapter) {
|
||||||
|
Loading…
Reference in New Issue
Block a user