diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu
index 65870df0..088fee47 100644
--- a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu
+++ b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu
@@ -48,9 +48,44 @@ using namespace cute;
 
 namespace {
 
-template <typename Arch, typename ElementAB_, typename ElementD_,
-          typename TileShape, typename WarpShape, typename InstructionShape,
-          int32_t MainLoopStages>
+// Wrappers for the GEMM kernel that is used to guard against compilation on
+// architectures that will never use the kernel. The purpose of this is to
+// reduce the size of the compiled binary.
+// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
+// into code that will be executed on the device where it is defined.
+template <typename Kernel>
+struct enable_sm75_to_sm80 : Kernel {
+  template <typename... Args>
+  CUTLASS_DEVICE static void invoke(Args&&... args) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750 && __CUDA_ARCH__ < 800
+    Kernel::invoke(std::forward<Args>(args)...);
+#endif
+  }
+};
+
+template <typename Kernel>
+struct enable_sm80_to_sm89 : Kernel {
+  template <typename... Args>
+  CUTLASS_DEVICE static void invoke(Args&&... args) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800 && __CUDA_ARCH__ < 890
+    Kernel::invoke(std::forward<Args>(args)...);
+#endif
+  }
+};
+
+template <typename Kernel>
+struct enable_sm89_to_sm90 : Kernel {
+  template <typename... Args>
+  CUTLASS_DEVICE static void invoke(Args&&... args) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 890 && __CUDA_ARCH__ < 900
+    Kernel::invoke(std::forward<Args>(args)...);
+#endif
+  }
+};
+
+template <typename Arch, template <typename> typename ArchGuard,
+          typename ElementAB_, typename ElementD_, typename TileShape,
+          typename WarpShape, typename InstructionShape, int32_t MainLoopStages>
 struct cutlass_2x_gemm {
   using ElementAB = ElementAB_;
   using ElementD = ElementD_;
@@ -101,7 +136,7 @@ struct cutlass_2x_gemm {
   using RowMajor = typename cutlass::layout::RowMajor;
   using ColumnMajor = typename cutlass::layout::ColumnMajor;
   using KernelType =
-    typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
+    ArchGuard<typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
       ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16,
       ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16,
       float, cutlass::layout::RowMajor, 4,
@@ -112,7 +147,7 @@ struct cutlass_2x_gemm {
       cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
       MainLoopStages, Operator,
       1 /* epilogue stages */
-      >::GemmKernel;
+      >::GemmKernel>;
   // clang-format on
 
   using Op = cutlass::gemm::device::GemmUniversalAdapter<KernelType>;
@@ -208,16 +243,16 @@ void cutlass_scaled_mm_dq_sm75(torch::Tensor& out, torch::Tensor const& a,
   using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
 
   if (out.dtype() == torch::kBFloat16) {
-    return cutlass_scaled_mm_dq_dispatcher<
-        cutlass_2x_gemm<cutlass::arch::Sm75, int8_t, cutlass::bfloat16_t,
-                        TileShape, WarpShape, InstructionShape, 2>>(
-        out, a, b, a_scales, b_scales);
+    return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
+        cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::bfloat16_t,
+        TileShape, WarpShape, InstructionShape, 2>>(out, a, b, a_scales,
+                                                    b_scales);
   } else {
     TORCH_CHECK(out.dtype() == torch::kFloat16);
-    return cutlass_scaled_mm_dq_dispatcher<
-        cutlass_2x_gemm<cutlass::arch::Sm75, int8_t, cutlass::half_t, TileShape,
-                        WarpShape, InstructionShape, 2>>(out, a, b, a_scales,
-                                                         b_scales);
+    return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
+        cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::half_t,
+        TileShape, WarpShape, InstructionShape, 2>>(out, a, b, a_scales,
+                                                    b_scales);
   }
 }
 
@@ -235,16 +270,16 @@ void cutlass_scaled_mm_dq_sm80(torch::Tensor& out, torch::Tensor const& a,
   using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
 
   if (out.dtype() == torch::kBFloat16) {
-    return cutlass_scaled_mm_dq_dispatcher<
-        cutlass_2x_gemm<cutlass::arch::Sm80, int8_t, cutlass::bfloat16_t,
-                        TileShape, WarpShape, InstructionShape, 5>>(
-        out, a, b, a_scales, b_scales);
+    return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
+        cutlass::arch::Sm80, enable_sm80_to_sm89, int8_t, cutlass::bfloat16_t,
+        TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
+                                                    b_scales);
   } else {
     TORCH_CHECK(out.dtype() == torch::kFloat16);
-    return cutlass_scaled_mm_dq_dispatcher<
-        cutlass_2x_gemm<cutlass::arch::Sm80, int8_t, cutlass::half_t, TileShape,
-                        WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
-                                                         b_scales);
+    return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
+        cutlass::arch::Sm80, enable_sm80_to_sm89, int8_t, cutlass::half_t,
+        TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
+                                                    b_scales);
   }
 }
 
@@ -263,16 +298,16 @@ void cutlass_scaled_mm_dq_sm89(torch::Tensor& out, torch::Tensor const& a,
     TORCH_CHECK(b.dtype() == torch::kInt8);
 
     if (out.dtype() == torch::kBFloat16) {
-      return cutlass_scaled_mm_dq_dispatcher<
-          cutlass_2x_gemm<cutlass::arch::Sm89, int8_t, cutlass::bfloat16_t,
-                          TileShape, WarpShape, InstructionShape, 5>>(
-          out, a, b, a_scales, b_scales);
+      return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
+          cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::bfloat16_t,
+          TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
+                                                      b_scales);
     } else {
       assert(out.dtype() == torch::kFloat16);
-      return cutlass_scaled_mm_dq_dispatcher<
-          cutlass_2x_gemm<cutlass::arch::Sm89, int8_t, cutlass::half_t,
-                          TileShape, WarpShape, InstructionShape, 5>>(
-          out, a, b, a_scales, b_scales);
+      return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
+          cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::half_t,
+          TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
+                                                      b_scales);
     }
   } else {
     TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
@@ -280,15 +315,15 @@ void cutlass_scaled_mm_dq_sm89(torch::Tensor& out, torch::Tensor const& a,
 
     if (out.dtype() == torch::kBFloat16) {
       return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
-          cutlass::arch::Sm89, cutlass::float_e4m3_t, cutlass::bfloat16_t,
-          TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
-                                                      b_scales);
+          cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t,
+          cutlass::bfloat16_t, TileShape, WarpShape, InstructionShape, 5>>(
+          out, a, b, a_scales, b_scales);
     } else {
       TORCH_CHECK(out.dtype() == torch::kFloat16);
       return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
-          cutlass::arch::Sm89, cutlass::float_e4m3_t, cutlass::half_t,
-          TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
-                                                      b_scales);
+          cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t,
+          cutlass::half_t, TileShape, WarpShape, InstructionShape, 5>>(
+          out, a, b, a_scales, b_scales);
     }
   }
 }
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu
index 4c1aec03..8fc4ba66 100644
--- a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu
+++ b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu
@@ -56,6 +56,21 @@ uint32_t next_pow_2(uint32_t const num) {
   return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
 }
 
+// A wrapper for the GEMM kernel that is used to guard against compilation on
+// architectures that will never use the kernel. The purpose of this is to
+// reduce the size of the compiled binary.
+// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
+// into code that will be executed on the device where it is defined.
+template <typename Kernel>
+struct enable_sm90_or_later : Kernel {
+  template <typename... Args>
+  CUTLASS_DEVICE void operator()(Args&&... args) {
+  #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
+    Kernel::operator()(std::forward<Args>(args)...);
+  #endif
+  }
+};
+
 template <typename ElementAB_, typename ElementD_, typename TileShape,
           typename ClusterShape, typename KernelSchedule,
           typename EpilogueSchedule>
@@ -126,9 +141,9 @@ struct cutlass_3x_gemm {
           KernelSchedule>::CollectiveOp;
   // clang-format on
 
-  using KernelType = cutlass::gemm::kernel::GemmUniversal<
+  using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
       cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
-      cutlass::gemm::PersistentScheduler>;
+      cutlass::gemm::PersistentScheduler>>;
 
   struct GemmKernel : public KernelType {};
 };
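For context, a minimal standalone sketch of the same guarding idea, reduced to a toy kernel so it can be compiled and run on its own. It is not part of the patch; the names enable_sm80_or_later, ToyKernel, and guard_demo are made up for illustration. The wrapper inherits from the kernel functor and forwards the call only when __CUDA_ARCH__ falls in the targeted range, so the device passes for other architectures compile an empty body and contribute almost nothing to the fatbinary.

// Hypothetical, standalone illustration -- not part of the patch above.
// Build sketch (assumption): nvcc -gencode arch=compute_80,code=sm_80 guard_demo.cu
#include <cstdio>
#include <cuda_runtime.h>

// Guard wrapper in the style of enable_sm90_or_later above: the call body is
// compiled only when the device pass targets sm_80 or newer.
template <typename Kernel>
struct enable_sm80_or_later : Kernel {
  template <typename... Args>
  __device__ void operator()(Args... args) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    Kernel::operator()(args...);
#endif
  }
};

// Toy device functor standing in for a CUTLASS kernel.
struct ToyKernel {
  __device__ void operator()(float* out, float x) { *out = 2.0f * x; }
};

__global__ void guard_demo(float* out, float x) {
  enable_sm80_or_later<ToyKernel> k;
  k(out, x);
}

int main() {
  float* d_out = nullptr;
  float h_out = 0.0f;
  cudaMalloc(&d_out, sizeof(float));
  guard_demo<<<1, 1>>>(d_out, 21.0f);
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  // Prints 42 when running an sm_80+ image; with the guard compiled out for
  // an older architecture the body is empty and the result stays 0.
  std::printf("out = %g\n", h_out);
  cudaFree(d_out);
  return 0;
}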