1x1x1 cluster launch (#1673)
This commit is contained in:
		
							parent
							
								
									eee0cab26c
								
							
						
					
					
						commit
						06b21349bc
					
				| @ -296,7 +296,9 @@ public: | ||||
| 
 | ||||
|     Status launch_result; | ||||
|     // Use extended launch API only for mainloops that use it
 | ||||
|     if constexpr(ConvKernel::ArchTag::kMinComputeCapability >= 90) { | ||||
|     if constexpr (ConvKernel::ArchTag::kMinComputeCapability >= 90) { | ||||
|       constexpr bool is_static_1x1x1 = cute::is_static_v<typename ConvKernel::DispatchPolicy::ClusterShape> and | ||||
|                                        cute::size(typename ConvKernel::DispatchPolicy::ClusterShape{}) == 1; | ||||
|       dim3 cluster(cute::size<0>(typename ConvKernel::DispatchPolicy::ClusterShape{}), | ||||
|                    cute::size<1>(typename ConvKernel::DispatchPolicy::ClusterShape{}), | ||||
|                    cute::size<2>(typename ConvKernel::DispatchPolicy::ClusterShape{})); | ||||
| @ -324,8 +326,14 @@ public: | ||||
|         CUTLASS_ASSERT(cuda_adapter == nullptr); | ||||
|         void const* kernel = (void const*) device_kernel<ConvKernel>; | ||||
|         if constexpr (ConvKernel::ArchTag::kMinComputeCapability == 90) { | ||||
|           launch_result = ClusterLauncher::launch( | ||||
|               grid, cluster, block, smem_size, stream, kernel, kernel_params); | ||||
|           if constexpr (is_static_1x1x1) { | ||||
|             device_kernel<ConvKernel><<<grid, block, smem_size, stream>>>(params); | ||||
|             launch_result = Status::kSuccess; | ||||
|           } | ||||
|           else { | ||||
|             launch_result = ClusterLauncher::launch( | ||||
|                 grid, cluster, block, smem_size, stream, kernel, kernel_params); | ||||
|           } | ||||
|         } | ||||
|       } | ||||
|     } | ||||
|  | ||||
| @ -350,6 +350,8 @@ public: | ||||
|     Status launch_result{ Status::kSuccess }; | ||||
|     // Use extended launch API only for mainloops that use it
 | ||||
|     if constexpr (GemmKernel::ArchTag::kMinComputeCapability >= 90) { | ||||
|       constexpr bool is_static_1x1x1 = cute::is_static_v<typename GemmKernel::DispatchPolicy::ClusterShape> and | ||||
|                                        cute::size(typename GemmKernel::DispatchPolicy::ClusterShape{}) == 1; | ||||
|       dim3 cluster(cute::size<0>(typename GemmKernel::DispatchPolicy::ClusterShape{}), | ||||
|                    cute::size<1>(typename GemmKernel::DispatchPolicy::ClusterShape{}), | ||||
|                    cute::size<2>(typename GemmKernel::DispatchPolicy::ClusterShape{})); | ||||
| @ -383,8 +385,13 @@ public: | ||||
|         CUTLASS_ASSERT(cuda_adapter == nullptr); | ||||
|         void const* kernel = (void const*) device_kernel<GemmKernel>; | ||||
|         if constexpr (GemmKernel::ArchTag::kMinComputeCapability == 90) { | ||||
|           launch_result = ClusterLauncher::launch( | ||||
|             grid, cluster, block, smem_size, stream, kernel, kernel_params, launch_with_pdl); | ||||
|           if (is_static_1x1x1 && not launch_with_pdl) { | ||||
|             device_kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params); | ||||
|           } | ||||
|           else { | ||||
|             launch_result = ClusterLauncher::launch( | ||||
|               grid, cluster, block, smem_size, stream, kernel, kernel_params, launch_with_pdl); | ||||
|           } | ||||
|         } | ||||
|       } | ||||
|     } | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user
	 dePaul Miller
						dePaul Miller