From 35e9c12bfaf8f273281af897b7208dfba53f103c Mon Sep 17 00:00:00 2001
From: Varun Sundar Rabindranath
Date: Wed, 31 Jul 2024 17:40:32 -0400
Subject: [PATCH] [Kernel] Tuned int8 Cutlass Kernels for SM75 (T4) (#6996)

Co-authored-by: Varun Sundar Rabindranath
---
 .../cutlass_benchmarks/w8a8_benchmarks.py     |   9 +-
 .../cutlass_w8a8/scaled_mm_c2x.cu             |  15 +--
 .../scaled_mm_c2x_sm75_dispatch.cuh           | 123 ++++++++++++++++++
 3 files changed, 135 insertions(+), 12 deletions(-)
 create mode 100644 csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm75_dispatch.cuh

diff --git a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
index 70247e94..64011b2d 100644
--- a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
+++ b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
@@ -112,13 +112,20 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
     scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
 
     timers = []
-    # pytorch impl
+    # pytorch impl - bfloat16
     timers.append(
         bench_fn(a.to(dtype=torch.bfloat16, device="cuda"),
                  b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b,
                  torch.bfloat16, label, sub_label, pytorch_mm_impl,
                  "pytorch_bf16_bf16_bf16_matmul-no-scales"))
 
+    # pytorch impl - float16
+    timers.append(
+        bench_fn(a.to(dtype=torch.float16, device="cuda"),
+                 b.to(dtype=torch.float16, device="cuda"), scale_a, scale_b,
+                 torch.float16, label, sub_label, pytorch_mm_impl,
+                 "pytorch_fp16_fp16_fp16_matmul-no-scales"))
+
     # cutlass impl
     timers.append(
         bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
index aac4900f..8d0dfee7 100644
--- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
+++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
@@ -3,6 +3,7 @@
 #include "cutlass/cutlass.h"
 
 #include "scaled_mm_c2x.cuh"
+#include "scaled_mm_c2x_sm75_dispatch.cuh"
 #include "scaled_mm_c2x_sm80_dispatch.cuh"
 #include "scaled_mm_c2x_sm89_fp8_dispatch.cuh"
 #include "scaled_mm_c2x_sm89_int8_dispatch.cuh"
@@ -20,21 +21,13 @@ void cutlass_scaled_mm_sm75_epilogue(torch::Tensor& out, torch::Tensor const& a,
   TORCH_CHECK(a.dtype() == torch::kInt8);
   TORCH_CHECK(b.dtype() == torch::kInt8);
 
-  using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
-  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
-  using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
-
   if (out.dtype() == torch::kBFloat16) {
-    return vllm::cutlass_gemm_caller<
-        vllm::cutlass_2x_gemm<cutlass::arch::Sm75, vllm::enable_sm75_to_sm80,
-                              int8_t, cutlass::bfloat16_t, Epilogue, TileShape,
-                              WarpShape, InstructionShape, 2>>(
+    return vllm::cutlass_gemm_sm75_dispatch<int8_t, cutlass::bfloat16_t,
+                                            Epilogue>(
         out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
   } else {
     TORCH_CHECK(out.dtype() == torch::kFloat16);
-    return vllm::cutlass_gemm_caller<vllm::cutlass_2x_gemm<
-        cutlass::arch::Sm75, vllm::enable_sm75_to_sm80, int8_t,
-        cutlass::half_t, Epilogue, TileShape, WarpShape, InstructionShape, 2>>(
+    return vllm::cutlass_gemm_sm75_dispatch<int8_t, cutlass::half_t, Epilogue>(
         out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
   }
 }
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm75_dispatch.cuh b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm75_dispatch.cuh
new file mode 100644
index 00000000..a562fd89
--- /dev/null
+++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm75_dispatch.cuh
@@ -0,0 +1,123 @@
+#pragma once
+
+#include "scaled_mm_c2x.cuh"
+
+/**
+ * This file defines Gemm kernel configurations for SM75 based on the Gemm
+ * shape.
+ */
+
+namespace vllm {
+
+template <typename InType, typename OutType,
+          template <typename, typename> typename Epilogue>
+struct sm75_config_default {
+  // This config is used in 2 cases,
+  // - M in (256, inf]
+  // - M in (64, 128]
+  // Shared memory required by this Gemm 32768
+  static_assert(std::is_same<InType, int8_t>());
+  using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
+  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
+  using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
+  using Cutlass2xGemm =
+      cutlass_2x_gemm<cutlass::arch::Sm75, enable_sm75_to_sm80, InType, OutType,
+                      Epilogue, TileShape, WarpShape, InstructionShape, 2>;
+};
+
+template <typename InType, typename OutType,
+          template <typename, typename> typename Epilogue>
+struct sm75_config_M256 {
+  // M in (128, 256]
+  // Shared memory required by this Gemm 65536
+  static_assert(std::is_same<InType, int8_t>());
+  using TileShape = typename cutlass::gemm::GemmShape<128, 128, 128>;
+  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
+  using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
+  using Cutlass2xGemm =
+      cutlass_2x_gemm<cutlass::arch::Sm75, enable_sm75_to_sm80, InType, OutType,
+                      Epilogue, TileShape, WarpShape, InstructionShape, 2>;
+};
+
+template <typename InType, typename OutType,
+          template <typename, typename> typename Epilogue>
+struct sm75_config_M64 {
+  // M in (32, 64]
+  // Shared memory required by this Gemm 49152
+  static_assert(std::is_same<InType, int8_t>());
+  using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
+  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
+  using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
+  using Cutlass2xGemm =
+      cutlass_2x_gemm<cutlass::arch::Sm75, enable_sm75_to_sm80, InType, OutType,
+                      Epilogue, TileShape, WarpShape, InstructionShape, 2>;
+};
+
+template <typename InType, typename OutType,
+          template <typename, typename> typename Epilogue>
+struct sm75_config_M32 {
+  // M in [1, 32]
+  // Shared memory required by this Gemm 49152
+  static_assert(std::is_same<InType, int8_t>());
+  using TileShape = typename cutlass::gemm::GemmShape<32, 128, 64>;
+  using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
+  using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
+  using Cutlass2xGemm =
+      cutlass_2x_gemm<cutlass::arch::Sm75, enable_sm75_to_sm80, InType, OutType,
+                      Epilogue, TileShape, WarpShape, InstructionShape, 2>;
+};
+
+template <typename InType, typename OutType,
+          template <typename, typename> typename Epilogue,
+          typename... EpilogueArgs>
+inline void cutlass_gemm_sm75_dispatch(torch::Tensor& out,
+                                       torch::Tensor const& a,
+                                       torch::Tensor const& b,
+                                       EpilogueArgs&&... args) {
+  static_assert(std::is_same<InType, int8_t>());
+  TORCH_CHECK(a.dtype() == torch::kInt8);
+  TORCH_CHECK(b.dtype() == torch::kInt8);
+
+  using Cutlass2xGemmDefault =
+      typename sm75_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
+  using Cutlass2xGemmM256 =
+      typename sm75_config_M256<InType, OutType, Epilogue>::Cutlass2xGemm;
+  using Cutlass2xGemmM128 = Cutlass2xGemmDefault;
+  using Cutlass2xGemmM64 =
+      typename sm75_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm;
+  using Cutlass2xGemmM32 =
+      typename sm75_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm;
+
+  // Due to shared memory requirements, some Gemms may fail to run on some
+  // GPUs. As the name indicates, the Fallback Gemm is used as an alternative
+  // in such cases.
+  // sm75_config_default has the least shared-memory requirements.
+  using FallbackGemm = Cutlass2xGemmDefault;
+
+  uint32_t const m = a.size(0);
+  uint32_t const mp2 =
+      std::max(static_cast<uint32_t>(32), next_pow_2(m));  // next power of 2
+  if (mp2 <= 32) {
+    // M in [1, 32]
+    return fallback_cutlass_gemm_caller<Cutlass2xGemmM32, FallbackGemm>(
+        out, a, b, std::forward<EpilogueArgs>(args)...);
+  } else if (mp2 <= 64) {
+    // M in (32, 64]
+    return fallback_cutlass_gemm_caller<Cutlass2xGemmM64, FallbackGemm>(
+        out, a, b, std::forward<EpilogueArgs>(args)...);
+  } else if (mp2 <= 128) {
+    // M in (64, 128]
+    return fallback_cutlass_gemm_caller<Cutlass2xGemmM128, FallbackGemm>(
+        out, a, b, std::forward<EpilogueArgs>(args)...);
+  } else if (mp2 <= 256) {
+    // M in (128, 256]
+    return fallback_cutlass_gemm_caller<Cutlass2xGemmM256, FallbackGemm>(
+        out, a, b, std::forward<EpilogueArgs>(args)...);
+  } else {
+    // M in (256, inf)
+    return fallback_cutlass_gemm_caller<Cutlass2xGemmDefault, FallbackGemm>(
+        out, a, b, std::forward<EpilogueArgs>(args)...);
+  }
+}
+
+}  // namespace vllm
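Note (not part of the patch): the sketch below is a minimal, standalone C++ illustration of how the new cutlass_gemm_sm75_dispatch() buckets the GEMM M dimension before picking one of the sm75_config_* tile shapes. The helper names next_pow_2_u32() and pick_sm75_config() are hypothetical and exist only for this example; in the real header the selection happens through fallback_cutlass_gemm_caller with the Cutlass2xGemm types defined above.

// Standalone sketch of the M-bucketing used by the SM75 dispatch above.
#include <algorithm>
#include <cstdint>
#include <iostream>

// Round m up to the next power of two (mirrors the next_pow_2 helper the
// dispatch relies on; reimplemented here so the example is self-contained).
static uint32_t next_pow_2_u32(uint32_t m) {
  uint32_t p = 1;
  while (p < m) p <<= 1;
  return p;
}

// Name of the config the dispatch would pick for a given M (hypothetical
// helper, for illustration only).
static const char* pick_sm75_config(uint32_t m) {
  uint32_t const mp2 = std::max(static_cast<uint32_t>(32), next_pow_2_u32(m));
  if (mp2 <= 32) return "sm75_config_M32 (TileShape 32x128x64)";
  if (mp2 <= 64) return "sm75_config_M64 (TileShape 64x128x128)";
  if (mp2 <= 128) return "sm75_config_default (TileShape 128x128x64)";
  if (mp2 <= 256) return "sm75_config_M256 (TileShape 128x128x128)";
  return "sm75_config_default (TileShape 128x128x64)";
}

int main() {
  // E.g. M = 100 rounds up to 128 and lands in the (64, 128] bucket,
  // which reuses the default 128x128x64 tile.
  for (uint32_t m : {1u, 16u, 33u, 100u, 200u, 512u}) {
    std::cout << "M=" << m << " -> " << pick_sm75_config(m) << "\n";
  }
  return 0;
}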