diff --git a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
index 234c2c8a..70247e94 100644
--- a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
+++ b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
@@ -13,7 +13,7 @@ from weight_shapes import WEIGHT_SHAPES
 from vllm import _custom_ops as ops
 from vllm.utils import FlexibleArgumentParser
 
-DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())[1:]
+DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
 DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
 DEFAULT_TP_SIZES = [1]
 
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
index 6ce25c5a..d26c43de 100644
--- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
+++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
@@ -1,470 +1,16 @@
 #include <stddef.h>
 #include <torch/all.h>
-
-#include <ATen/cuda/CUDAContext.h>
-
-// clang-format will break include orders
-// clang-format off
-#include "cute/tensor.hpp"
-#include "cute/atom/mma_atom.hpp"
-#include "cutlass/numeric_types.h"
-
-#include "cutlass/util/device_memory.h"
-
 #include "cutlass/cutlass.h"
-#include "cutlass/gemm_coord.h"
-#include "cutlass/arch/mma_sm75.h"
-#include "cutlass/arch/arch.h"
-#include "cutlass/arch/mma.h"
-#include "cutlass/gemm/device/gemm.h"
-#include "cutlass/gemm/device/gemm_universal_adapter.h"
-#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
-#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
-
-#include "broadcast_load_epilogue_c2x.hpp"
-#include "common.hpp"
-// clang-format on
-
-using namespace cute;
+#include "scaled_mm_c2x.cuh"
+#include "scaled_mm_c2x_sm80_dispatch.cuh"
+#include "scaled_mm_c2x_sm89_dispatch.cuh"
 
 /*
    This file defines quantized GEMM operations using the CUTLASS 2.x API, for
    NVIDIA GPUs with SM versions prior to sm90 (Hopper).
-
-   Epilogue functions can be defined to post-process the output before it is
-   written to GPU memory.
-   Epilogues must contain a public type named EVTCompute of type Sm80EVT,
-   as well as a static prepare_args function that constructs an
-   EVTCompute::Arguments struct.
 */
-namespace {
-
-// Wrappers for the GEMM kernel that is used to guard against compilation on
-// architectures that will never use the kernel. The purpose of this is to
-// reduce the size of the compiled binary.
-// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
-// into code that will be executed on the device where it is defined.
-template <typename Kernel>
-struct enable_sm75_to_sm80 : Kernel {
-  template <typename... Args>
-  CUTLASS_DEVICE static void invoke(Args&&... args) {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750 && __CUDA_ARCH__ < 800
-    Kernel::invoke(std::forward<Args>(args)...);
-#endif
-  }
-};
-
-template <typename Kernel>
-struct enable_sm80_to_sm89 : Kernel {
-  template <typename... Args>
-  CUTLASS_DEVICE static void invoke(Args&&... args) {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800 && __CUDA_ARCH__ < 890
-    Kernel::invoke(std::forward<Args>(args)...);
-#endif
-  }
-};
-
-template <typename Kernel>
-struct enable_sm89_to_sm90 : Kernel {
-  template <typename... Args>
-  CUTLASS_DEVICE static void invoke(Args&&... args) {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 890 && __CUDA_ARCH__ < 900
-    Kernel::invoke(std::forward<Args>(args)...);
-#endif
-  }
-};
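
// [Editor's note] A minimal sketch of how the guards above are meant to be
// used; `MyKernelBody` is hypothetical, not vLLM code. When the translation
// unit is compiled for an architecture outside [sm80, sm89), the guarded
// invoke() has an empty body, so the kernel's device code is never emitted
// for that architecture and the fat binary stays small.
struct MyKernelBody {
  template <typename... Args>
  CUTLASS_DEVICE static void invoke(Args&&... args) {
    // ... the real GEMM mainloop would run here ...
  }
};
using GuardedKernel = enable_sm80_to_sm89<MyKernelBody>;
// A device-side call such as GuardedKernel::invoke(params) compiles to a
// no-op everywhere except __CUDA_ARCH__ in [800, 890).
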
-
-/*
- * This class provides the common ScaleA and ScaleB descriptors for the
- * ScaledEpilogue and ScaledEpilogueBias classes.
- */
-template <typename ElementD, typename OutputTileThreadMap>
-struct ScaledEpilogueBase {
- protected:
-  using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
-
-  using ScaleA = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
-      OutputTileThreadMap, float, Stride<Int<1>, Int<0>, Int<0>>>;
-
-  using ScaleB = cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
-      OutputTileThreadMap, float, Stride<Int<0>, Int<1>, Int<0>>>;
-};
-
-/*
- This epilogue function defines a quantized GEMM operation similar to
- torch._scaled_mm.
-
- A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
- per-row. B can be quantized per-tensor or per-column.
- Any combination of per-tensor and per-row or column is supported.
- A and B must have symmetric quantization (zero point == 0).
-
- So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
- scales are applied elementwise with numpy-style broadcasting.
-
- ScaleA and ScaleB define the epilogue functions that apply the scales for
- the A and B operands respectively. These scales may be either per-tensor or
- per row or column.
-*/
-template <typename ElementD, typename OutputTileThreadMap>
-struct ScaledEpilogue
-    : private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
- private:
-  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
-  using Accum = typename SUPER::Accum;
-  using ScaleA = typename SUPER::ScaleA;
-  using ScaleB = typename SUPER::ScaleB;
-
-  using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
-      cutlass::multiplies, float, float,
-      cutlass::FloatRoundStyle::round_to_nearest>;
-
-  using EVTCompute0 =
-      cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;
-
-  using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
-      cutlass::multiplies, ElementD, float,
-      cutlass::FloatRoundStyle::round_to_nearest>;
-
- public:
-  using EVTCompute =
-      cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
-  using ArgumentType = typename EVTCompute::Arguments;
-
-  static ArgumentType prepare_args(torch::Tensor const& a_scales,
-                                   torch::Tensor const& b_scales) {
-    using ScaleAArgs = typename ScaleA::Arguments;
-    using ScaleBArgs = typename ScaleB::Arguments;
-
-    ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
-    ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
-
-    typename EVTCompute0::Arguments evt0_compute_args{b_args};
-
-    typename EVTCompute::Arguments evt_compute_args{a_args, evt0_compute_args};
-    return evt_compute_args;
-  }
-};
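
// [Editor's note] A scalar reference for the semantics described above,
// assuming int8 inputs; this is an illustrative sketch, not vLLM code.
// D = (a_scales * A) @ (b_scales * B): a_scales is a scalar or one value per
// row of A, b_scales is a scalar or one value per column of B, and the int8
// products accumulate in int32 (ElementAcc) before the float scales apply.
#include <cstdint>
#include <vector>
inline void scaled_mm_reference(std::vector<int8_t> const& A,        // M x K
                                std::vector<int8_t> const& B,        // K x N
                                std::vector<float> const& a_scales,  // 1 or M
                                std::vector<float> const& b_scales,  // 1 or N
                                std::vector<float>& D, int M, int N, int K) {
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      int32_t acc = 0;
      for (int k = 0; k < K; ++k) {
        acc += int32_t(A[i * K + k]) * int32_t(B[k * N + j]);
      }
      float const sa = a_scales[a_scales.size() == 1 ? 0 : i];
      float const sb = b_scales[b_scales.size() == 1 ? 0 : j];
      D[i * N + j] = sa * sb * static_cast<float>(acc);
    }
  }
}
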
-
-template <typename ElementD, typename OutputTileThreadMap>
-struct ScaledEpilogueBias
-    : private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
- private:
-  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
-  using Accum = typename SUPER::Accum;
-  using ScaleA = typename SUPER::ScaleA;
-  using ScaleB = typename SUPER::ScaleB;
-
-  using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
-      cutlass::multiplies, float, float,
-      cutlass::FloatRoundStyle::round_to_nearest>;
-
-  using EVTCompute0 =
-      cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;
-
-  using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
-      cutlass::multiply_add, ElementD, float,
-      cutlass::FloatRoundStyle::round_to_nearest>;
-
-  using Bias = cutlass::epilogue::threadblock::VisitorRowBroadcast<
-      OutputTileThreadMap, ElementD, Stride<Int<0>, Int<1>, Int<0>>>;
-
- public:
-  using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
-                                                             EVTCompute0, Bias>;
-  using ArgumentType = typename EVTCompute::Arguments;
-
-  static ArgumentType prepare_args(torch::Tensor const& a_scales,
-                                   torch::Tensor const& b_scales,
-                                   torch::Tensor const& bias) {
-    using ScaleAArgs = typename ScaleA::Arguments;
-    using ScaleBArgs = typename ScaleB::Arguments;
-    using BiasArgs = typename Bias::Arguments;
-
-    ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
-    ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
-    BiasArgs bias_args{static_cast<ElementD*>(bias.data_ptr()), {}};
-
-    typename EVTCompute0::Arguments evt0_compute_args{b_args};
-
-    typename EVTCompute::Arguments evt_compute_args{a_args, evt0_compute_args,
-                                                    bias_args};
-    return evt_compute_args;
-  }
-};
-
-template <typename Arch, template <typename> typename ArchGuard,
-          typename ElementAB_, typename ElementD_,
-          template <typename, typename> typename Epilogue_, typename TileShape,
-          typename WarpShape, typename InstructionShape, int32_t MainLoopStages>
-struct cutlass_2x_gemm {
-  using ElementAB = ElementAB_;
-  using ElementD = ElementD_;
-
-  using ElementAcc =
-      typename std::conditional<std::is_same<ElementAB, int8_t>::value, int32_t,
-                                float>::type;
-
-  using Operator =
-      typename std::conditional<std::is_same_v<ElementAB, int8_t>,
-                                cutlass::arch::OpMultiplyAddSaturate,
-                                cutlass::arch::OpMultiplyAdd>::type;
-
-  using OutputTileThreadMap =
-      cutlass::epilogue::threadblock::OutputTileThreadLayout<
-          TileShape, WarpShape, float, 4, 1 /* epilogue stages */
-          >;
-
-  using Epilogue = Epilogue_<ElementD, OutputTileThreadMap>;
-  using EVTCompute = typename Epilogue::EVTCompute;
-
-  using D = cutlass::epilogue::threadblock::VisitorAuxStore<
-      OutputTileThreadMap, ElementD, cutlass::FloatRoundStyle::round_to_nearest,
-      Stride<int64_t, Int<1>, Int<0>>>;
-
-  using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute>;
-
-  // clang-format off
-  using RowMajor = typename cutlass::layout::RowMajor;
-  using ColumnMajor = typename cutlass::layout::ColumnMajor;
-  using KernelType =
-    ArchGuard<typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
-      ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16,
-      ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16,
-      float, cutlass::layout::RowMajor, 4,
-      ElementAcc, float, cutlass::arch::OpClassTensorOp,
-      Arch,
-      TileShape, WarpShape, InstructionShape,
-      EVTD,
-      cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
-      MainLoopStages, Operator,
-      1 /* epilogue stages */
-      >::GemmKernel>;
-  // clang-format on
-
-  using Op = cutlass::gemm::device::GemmUniversalAdapter<KernelType>;
-};
-
-template <typename Gemm, typename... EpilogueArgs>
-void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
-                         torch::Tensor const& b,
-                         EpilogueArgs&&... epilogue_params) {
-  using ElementAB = typename Gemm::ElementAB;
-  using ElementD = typename Gemm::ElementD;
-
-  int32_t m = a.size(0);
-  int32_t n = b.size(1);
-  int32_t k = a.size(1);
-  cutlass::gemm::GemmCoord problem_size{m, n, k};
-
-  int64_t lda = a.stride(0);
-  int64_t ldb = b.stride(1);
-  int64_t ldc = out.stride(0);
-
-  using StrideC = Stride<int64_t, Int<1>, Int<0>>;
-  StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
-
-  auto a_ptr = static_cast<ElementAB const*>(a.data_ptr());
-  auto b_ptr = static_cast<ElementAB const*>(b.data_ptr());
-  auto c_ptr = static_cast<ElementD*>(out.data_ptr());
-
-  typename Gemm::D::Arguments d_args{c_ptr, c_stride};
-
-  using Epilogue = typename Gemm::Epilogue;
-  auto evt_args =
-      Epilogue::prepare_args(std::forward<EpilogueArgs>(epilogue_params)...);
-
-  typename Gemm::EVTD::Arguments epilogue_args{
-      evt_args,
-      d_args,
-  };
-
-  typename Gemm::Op::Arguments args{
-      cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel,  // universal mode
-      problem_size,                                           // problem size
-      1,                                                      // batch count
-      epilogue_args,
-      a_ptr,
-      b_ptr,
-      nullptr,
-      nullptr,
-      0,
-      0,
-      0,
-      0,
-      lda,
-      ldb,
-      ldc,
-      ldc};
-
-  // Launch the CUTLASS GEMM kernel.
-  typename Gemm::Op gemm_op;
-  size_t workspace_size = gemm_op.get_workspace_size(args);
-  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
-
-  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
-
-  CUTLASS_CHECK(gemm_op.can_implement(args));
-  cutlass::Status status = gemm_op(args, workspace.get(), stream);
-  CUTLASS_CHECK(status);
-}
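
// [Editor's note] The strides taken above imply fixed layouts: A is
// row-major (lda = a.stride(0)), B is column-major (ldb = b.stride(1)), and
// the output is row-major (ldc = out.stride(0)). A hypothetical guard that
// makes those assumptions explicit (a sketch for context only; not the
// checks vLLM actually performs at its entry points):
#include <torch/all.h>
inline void check_scaled_mm_layouts(torch::Tensor const& a,
                                    torch::Tensor const& b,
                                    torch::Tensor const& out) {
  TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && out.dim() == 2);
  TORCH_CHECK(a.stride(1) == 1, "A must be row-major");
  TORCH_CHECK(b.stride(0) == 1, "B must be column-major");
  TORCH_CHECK(out.stride(1) == 1, "out must be row-major");
  TORCH_CHECK(a.size(1) == b.size(0), "inner (K) dimensions must match");
}
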
-
-template <typename Gemm, typename FallbackGemm, typename... EpilogueArgs>
-void fallback_cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
-                                  torch::Tensor const& b,
-                                  EpilogueArgs&&... args) {
-  // In some cases, the GPU isn't able to accommodate the
-  // shared memory requirements of the Gemm. In such cases, use
-  // the FallbackGemm instead.
-  static const int max_shared_mem_per_block_opt_in =
-      get_cuda_max_shared_memory_per_block_opt_in(0);
-
-  size_t const gemm_shared_mem_size =
-      sizeof(typename Gemm::KernelType::SharedStorage);
-  size_t const fallback_gemm_shared_mem_size =
-      sizeof(typename FallbackGemm::KernelType::SharedStorage);
-
-  if (gemm_shared_mem_size <= max_shared_mem_per_block_opt_in) {
-    return cutlass_gemm_caller<Gemm>(out, a, b,
-                                     std::forward<EpilogueArgs>(args)...);
-  } else {
-    TORCH_CHECK(fallback_gemm_shared_mem_size <=
-                max_shared_mem_per_block_opt_in);
-    return cutlass_gemm_caller<FallbackGemm>(
-        out, a, b, std::forward<EpilogueArgs>(args)...);
-  }
-}
-
-template <typename InType, typename OutType,
-          template <typename, typename> typename Epilogue>
-struct sm80_config_default {
-  // This config is used in 2 cases,
-  // - M in (128, inf)
-  // - M in (64, 128] and N >= 8192
-  // Shared Memory required by this Gemm - 81920 bytes
-  static_assert(std::is_same<InType, int8_t>());
-  using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
-  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
-  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
-  using Cutlass2xGemm =
-      cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
-                      Epilogue, TileShape, WarpShape, InstructionShape, 5>;
-};
-
-template <typename InType, typename OutType,
-          template <typename, typename> typename Epilogue>
-struct sm80_config_M64 {
-  // This config is used in 2 cases,
-  // - M in (32, 64]
-  // - M in (64, 128] and N < 8192
-  // Shared Memory required by this Gemm - 122880 bytes
-  static_assert(std::is_same<InType, int8_t>());
-  using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
-  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
-  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
-  using Cutlass2xGemm =
-      cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
-                      Epilogue, TileShape, WarpShape, InstructionShape, 5>;
-};
-
-template <typename InType, typename OutType,
-          template <typename, typename> typename Epilogue>
-struct sm80_config_M32 {
-  // M in (16, 32]
-  // Shared Memory required by this Gemm - 61440 bytes
-  static_assert(std::is_same<InType, int8_t>());
-  using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
-  using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
-  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
-  using Cutlass2xGemm =
-      cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
-                      Epilogue, TileShape, WarpShape, InstructionShape, 5>;
-};
-
-template <typename InType, typename OutType,
-          template <typename, typename> typename Epilogue>
-struct sm80_config_M16 {
-  // M in [1, 16]
-  // Shared Memory required by this Gemm - 51200 bytes
-  static_assert(std::is_same<InType, int8_t>());
-  using TileShape = typename cutlass::gemm::GemmShape<16, 64, 128>;
-  using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
-  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
-  using Cutlass2xGemm =
-      cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
-                      Epilogue, TileShape, WarpShape, InstructionShape, 5>;
-};
-
-}  // namespace
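
// [Editor's note] get_cuda_max_shared_memory_per_block_opt_in() is a vLLM
// helper defined outside this hunk; a plausible implementation (an
// assumption, shown only to make the fallback logic above concrete) queries
// the opt-in limit that kernels may request beyond the default 48 KB:
#include <cuda_runtime.h>
inline int max_shared_memory_opt_in(int device_id) {
  int value = 0;
  cudaDeviceGetAttribute(&value, cudaDevAttrMaxSharedMemoryPerBlockOptin,
                         device_id);
  return value;  // e.g. 166912 bytes on sm80 (A100), 101376 bytes on sm86
}
// The sm80_config_M64 tile needs 122880 bytes of SharedStorage: it fits on
// sm80 but not on sm86, which is exactly when the fallback path is taken.
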
-
-template <typename InType, typename OutType,
-          template <typename, typename> typename Epilogue,
-          typename... EpilogueArgs>
-void cutlass_gemm_sm80_dispatch(torch::Tensor& out, torch::Tensor const& a,
-                                torch::Tensor const& b,
-                                EpilogueArgs&&... args) {
-  static_assert(std::is_same<InType, int8_t>());
-  TORCH_CHECK(a.dtype() == torch::kInt8);
-  TORCH_CHECK(b.dtype() == torch::kInt8);
-
-  using Cutlass2xGemmDefault =
-      typename sm80_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
-  using Cutlass2xGemmM128BigN =
-      typename sm80_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
-  using Cutlass2xGemmM128SmallN =
-      typename sm80_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm;
-  using Cutlass2xGemmM64 =
-      typename sm80_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm;
-  using Cutlass2xGemmM32 =
-      typename sm80_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm;
-  using Cutlass2xGemmM16 =
-      typename sm80_config_M16<InType, OutType, Epilogue>::Cutlass2xGemm;
-
-  // Due to shared memory requirements, some Gemms may fail to run on some
-  // GPUs. As the name indicates, the Fallback Gemm is used as an alternative
-  // in such cases.
-  // sm80_config_M16 has the least shared-memory requirement. However,
-  // based on some profiling, we select sm80_config_M32 as a better alternative
-  // performance wise.
-  using FallbackGemm =
-      typename sm80_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm;
-
-  uint32_t const m = a.size(0);
-  uint32_t const mp2 =
-      std::max(static_cast<uint32_t>(16), next_pow_2(m));  // next power of 2
-  if (mp2 <= 16) {
-    // M in [1, 16]
-    return fallback_cutlass_gemm_caller<Cutlass2xGemmM16, FallbackGemm>(
-        out, a, b, std::forward<EpilogueArgs>(args)...);
-  } else if (mp2 <= 32) {
-    // M in (16, 32]
-    return fallback_cutlass_gemm_caller<Cutlass2xGemmM32, FallbackGemm>(
-        out, a, b, std::forward<EpilogueArgs>(args)...);
-  } else if (mp2 <= 64) {
-    // M in (32, 64]
-    return fallback_cutlass_gemm_caller<Cutlass2xGemmM64, FallbackGemm>(
-        out, a, b, std::forward<EpilogueArgs>(args)...);
-  } else if (mp2 <= 128) {
-    // M in (64, 128]
-    uint32_t const n = out.size(1);
-    bool const small_n = n < 8192;
-    if (small_n) {
-      return fallback_cutlass_gemm_caller<Cutlass2xGemmM128SmallN,
-                                          FallbackGemm>(
-          out, a, b, std::forward<EpilogueArgs>(args)...);
-    } else {
-      return fallback_cutlass_gemm_caller<Cutlass2xGemmM128BigN, FallbackGemm>(
-          out, a, b, std::forward<EpilogueArgs>(args)...);
-    }
-  } else {
-    // M in (128, inf)
-    return fallback_cutlass_gemm_caller<Cutlass2xGemmDefault, FallbackGemm>(
-        out, a, b, std::forward<EpilogueArgs>(args)...);
-  }
-}
-
 template
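
// [Editor's note] next_pow_2() is a vLLM utility defined outside this hunk;
// the dispatch above only relies on it returning the smallest power of two
// >= m. A sketch of that behavior (an assumption, not the actual
// definition; __builtin_clz assumes a GCC/Clang host compiler):
#include <cstdint>
inline uint32_t next_pow_2_sketch(uint32_t x) {
  if (x <= 1) return 1;
  return 1u << (32 - __builtin_clz(x - 1));
}
// Bucketing examples: m = 24 -> mp2 = 32 (Cutlass2xGemmM32); m = 100 ->
// mp2 = 128 (the N-dependent M128 branch); m = 129 -> mp2 = 256 (default).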