diff --git a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py new file mode 100644 index 00000000..6de56f61 --- /dev/null +++ b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py @@ -0,0 +1,352 @@ +import argparse +import copy +import itertools +import pickle as pkl +import time +from typing import Callable, Iterable, List, Tuple + +import torch +import torch.utils.benchmark as TBenchmark +from torch.utils.benchmark import Measurement as TMeasurement +from weight_shapes import WEIGHT_SHAPES + +from vllm import _custom_ops as ops + +DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())[1:] +DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] +DEFAULT_TP_SIZES = [1] + +# helpers + + +def to_fp8(tensor: torch.tensor) -> torch.tensor: + finfo = torch.finfo(torch.float8_e4m3fn) + return torch.round(tensor.clamp( + min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) + + +def to_int8(tensor: torch.tensor) -> torch.tensor: + return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) + + +def make_rand_tensors(dtype: torch.dtype, m: int, n: int, + k: int) -> Tuple[torch.tensor, torch.tensor]: + + a = torch.randn((m, k), device='cuda') * 5 + b = torch.randn((n, k), device='cuda').t() * 5 + + if dtype == torch.int8: + return to_int8(a), to_int8(b) + if dtype == torch.float8_e4m3fn: + return to_fp8(a), to_fp8(b) + + raise ValueError("unsupported dtype") + + +# impl + + +def pytorch_i8_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor, + scale_b: torch.tensor, + out_dtype: torch.dtype) -> torch.tensor: + return torch.mm(a, b) + + +def pytorch_fp8_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor, + scale_b: torch.tensor, + out_dtype: torch.dtype) -> torch.tensor: + return torch._scaled_mm(a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=out_dtype) + + +def pytorch_fp8_impl_fast_accum(a: torch.tensor, b: torch.tensor, + scale_a: torch.tensor, scale_b: torch.tensor, + out_dtype: torch.dtype) -> torch.tensor: + return torch._scaled_mm(a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=out_dtype, + use_fast_accum=True) + + +def cutlass_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor, + scale_b: torch.tensor, + out_dtype: torch.dtype) -> torch.tensor: + return ops.cutlass_scaled_mm_dq(a, + b, + scale_a, + scale_b, + out_dtype=out_dtype) + + +# bench +def bench_fn(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor, + scale_b: torch.tensor, out_dtype: torch.dtype, label: str, + sub_label: str, fn: Callable, description: str) -> TMeasurement: + + min_run_time = 1 + + globals = { + "a": a, + "b": b, + "scale_a": scale_a, + "scale_b": scale_b, + "out_dtype": out_dtype, + "fn": fn, + } + return TBenchmark.Timer( + stmt="fn(a, b, scale_a, scale_b, out_dtype)", + globals=globals, + label=label, + sub_label=sub_label, + description=description, + ).blocked_autorange(min_run_time=min_run_time) + + +def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str, + sub_label: str) -> Iterable[TMeasurement]: + assert dtype == torch.int8 + a, b = make_rand_tensors(torch.int8, m, n, k) + scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) + scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) + + timers = [] + # pytorch impl + timers.append( + bench_fn(a.to(dtype=torch.bfloat16, device="cuda"), + b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b, + torch.bfloat16, label, sub_label, pytorch_i8_impl, + "pytorch_bf16_bf16_bf16_matmul-no-scales")) + + # cutlass impl + timers.append( + bench_fn(a, b, scale_a.to(device="cpu"), scale_b.to(device="cpu"), + torch.bfloat16, label, sub_label, cutlass_impl, + "cutlass_i8_i8_bf16_scaled_mm")) + + return timers + + +def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str, + sub_label: str) -> Iterable[TMeasurement]: + assert dtype == torch.float8_e4m3fn + a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k) + scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) + scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) + + timers = [] + + # pytorch impl: bf16 output, without fp8 fast accum + timers.append( + bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label, + pytorch_fp8_impl, "pytorch_fp8_fp8_bf16_scaled_mm")) + + # pytorch impl: bf16 output, with fp8 fast accum + timers.append( + bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label, + pytorch_fp8_impl_fast_accum, + "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum")) + + # pytorch impl: fp16 output, without fp8 fast accum + timers.append( + bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label, + pytorch_fp8_impl, "pytorch_fp8_fp8_fp16_scaled_mm")) + + # pytorch impl: fp16 output, with fp8 fast accum + timers.append( + bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label, + pytorch_fp8_impl_fast_accum, + "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum")) + + # cutlass impl: bf16 output + timers.append( + bench_fn(a, b, scale_a.to(device="cpu"), scale_b.to(device="cpu"), + torch.bfloat16, label, sub_label, cutlass_impl, + "cutlass_fp8_fp8_bf16_scaled_mm")) + # cutlass impl: fp16 output + timers.append( + bench_fn(a, b, scale_a.to(device="cpu"), scale_b.to(device="cpu"), + torch.float16, label, sub_label, cutlass_impl, + "cutlass_fp8_fp8_fp16_scaled_mm")) + return timers + + +def bench(dtype: torch.dtype, m: int, k: int, n: int, label: str, + sub_label: str) -> Iterable[TMeasurement]: + if dtype == torch.int8: + return bench_int8(dtype, m, k, n, label, sub_label) + if dtype == torch.float8_e4m3fn: + return bench_fp8(dtype, m, k, n, label, sub_label) + raise ValueError("unsupported type") + + +# runner +def print_timers(timers: Iterable[TMeasurement]): + compare = TBenchmark.Compare(timers) + compare.print() + + +def run(dtype: torch.dtype, + MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]: + + results = [] + for m, k, n in MKNs: + timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm", + f"MKN=({m}x{k}x{n})") + print_timers(timers) + results.extend(timers) + + return results + + +# output makers +def make_output(data: Iterable[TMeasurement], + MKNs: Iterable[Tuple[int, int, int]], + base_description: str, + timestamp=None): + + print(f"== All Results {base_description} ====") + print_timers(data) + + # pickle all the results + timestamp = int(time.time()) if timestamp is None else timestamp + with open(f"{base_description}-{timestamp}.pkl", "wb") as f: + pkl.dump(data, f) + + +# argparse runners + + +def run_square_bench(args): + dim_sizes = list( + range(args.dim_start, args.dim_end + 1, args.dim_increment)) + MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) + data = run(args.dtype, MKNs) + + make_output(data, MKNs, f"square_bench-{args.dtype}") + + +def run_range_bench(args): + dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment)) + n = len(dim_sizes) + Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes + Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes + Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes + MKNs = list(zip(Ms, Ks, Ns)) + data = run(args.dtype, MKNs) + + make_output(data, MKNs, f"range_bench-{args.dtype}") + + +def run_model_bench(args): + + print("Benchmarking models:") + for i, model in enumerate(args.models): + print(f"[{i}] {model}") + + def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]: + KNs = [] + for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]): + KN[tp_split_dim] = KN[tp_split_dim] // tp_size + KNs.append(KN) + return KNs + + model_bench_data = [] + models_tps = list(itertools.product(args.models, args.tp_sizes)) + for model, tp_size in models_tps: + Ms = args.batch_sizes + KNs = model_shapes(model, tp_size) + MKNs = [] + for m in Ms: + for k, n in KNs: + MKNs.append((m, k, n)) + + data = run(args.dtype, MKNs) + model_bench_data.append(data) + + # Print all results + for data, model_tp in zip(model_bench_data, models_tps): + model, tp_size = model_tp + print(f"== Results {args.dtype} {model}-TP{tp_size} ====") + print_timers(data) + + timestamp = int(time.time()) + + all_data = [] + for d in model_bench_data: + all_data.extend(d) + # pickle all data + with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f: + pkl.dump(all_data, f) + + +if __name__ == '__main__': + + def to_torch_dtype(dt): + if dt == "int8": + return torch.int8 + if dt == "fp8": + return torch.float8_e4m3fn + raise ValueError("unsupported dtype") + + parser = argparse.ArgumentParser( + description=""" +Benchmark Cutlass GEMM. + + To run square GEMMs: + python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 square_bench --dim-start 128 --dim-end 512 --dim-increment 64 + + To run constant N and K and sweep M: + python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384 + + To run dimensions from a model: + python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1 + + Output: + - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs. + """, # noqa: E501 + formatter_class=argparse.RawTextHelpFormatter) + + parser.add_argument("--dtype", + type=to_torch_dtype, + required=True, + help="Available options are ['int8', 'fp8']") + subparsers = parser.add_subparsers(dest="cmd") + + square_parser = subparsers.add_parser("square_bench") + square_parser.add_argument("--dim-start", type=int, required=True) + square_parser.add_argument("--dim-end", type=int, required=True) + square_parser.add_argument("--dim-increment", type=int, required=True) + square_parser.set_defaults(func=run_square_bench) + + range_parser = subparsers.add_parser("range_bench") + range_parser.add_argument("--dim-start", type=int, required=True) + range_parser.add_argument("--dim-end", type=int, required=True) + range_parser.add_argument("--dim-increment", type=int, required=True) + range_parser.add_argument("--m-constant", type=int, default=None) + range_parser.add_argument("--n-constant", type=int, default=None) + range_parser.add_argument("--k-constant", type=int, default=None) + range_parser.set_defaults(func=run_range_bench) + + model_parser = subparsers.add_parser("model_bench") + model_parser.add_argument("--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES.keys()) + model_parser.add_argument("--tp-sizes", + nargs="+", + type=int, + default=DEFAULT_TP_SIZES) + model_parser.add_argument("--batch-sizes", + nargs="+", + type=int, + default=DEFAULT_BATCH_SIZES) + model_parser.set_defaults(func=run_model_bench) + + args = parser.parse_args() + args.func(args) diff --git a/benchmarks/cutlass_benchmarks/weight_shapes.py b/benchmarks/cutlass_benchmarks/weight_shapes.py new file mode 100644 index 00000000..7ad4a53d --- /dev/null +++ b/benchmarks/cutlass_benchmarks/weight_shapes.py @@ -0,0 +1,37 @@ +# Weight Shapes are in the format +# ([K, N], TP_SPLIT_DIM) +# Example: +# A shape of ([14336, 4096], 0) indicates the following GEMM shape, +# - TP1 : K = 14336, N = 4096 +# - TP2 : K = 7168, N = 4096 +# A shape of ([4096, 6144], 1) indicates the following GEMM shape, +# - TP1 : K = 4096, N = 6144 +# - TP4 : K = 4096, N = 1536 + +# TP1 shapes +WEIGHT_SHAPES = { + "mistralai/Mistral-7B-v0.1": [ + ([4096, 6144], 1), + ([4096, 4096], 0), + ([4096, 28672], 1), + ([14336, 4096], 0), + ], + "meta-llama/Llama-2-7b-hf": [ + ([4096, 12288], 1), + ([4096, 4096], 0), + ([4096, 22016], 1), + ([11008, 4096], 0), + ], + "meta-llama/Llama-2-13b-hf": [ + ([5120, 15360], 1), + ([5120, 5120], 0), + ([5120, 27648], 1), + ([13824, 5120], 0), + ], + "meta-llama/Llama-2-70b-hf": [ + ([8192, 10240], 1), + ([8192, 8192], 0), + ([8192, 57344], 1), + ([28672, 8192], 0), + ], +} diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu index 2383760a..4c1aec03 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu @@ -51,6 +51,11 @@ using namespace cute; namespace { +uint32_t next_pow_2(uint32_t const num) { + if (num <= 1) return num; + return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1)); +} + template @@ -188,8 +193,89 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a, cutlass::Status status = gemm_op.run(args, workspace.get(), stream); CUTLASS_CHECK(status); } + +template +struct sm90_fp8_config { + static_assert(std::is_same()); + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_2, _1, _1>; + + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + +template +struct sm90_fp8_config { + static_assert(std::is_same()); + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileShape = Shape<_64, _128, _128>; + using ClusterShape = Shape<_2, _1, _1>; + + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + +template +struct sm90_fp8_config { + static_assert(std::is_same()); + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileShape = Shape<_64, _64, _128>; + using ClusterShape = Shape<_1, _8, _1>; + + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + } // namespace +template +void cutlass_scaled_mm_dq_sm90_fp8_dispatch(torch::Tensor& out, + torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + static_assert(std::is_same()); + TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(a_scales.dtype() == torch::kFloat32); + TORCH_CHECK(b_scales.dtype() == torch::kFloat32); + + using Cutlass3xGemmDefault = + typename sm90_fp8_config::Cutlass3xGemm; + using Cutlass3xGemmM64 = + typename sm90_fp8_config::Cutlass3xGemm; + using Cutlass3xGemmM128 = + typename sm90_fp8_config::Cutlass3xGemm; + + uint32_t const m = a.size(0); + uint32_t const mp2 = + std::max(static_cast(64), next_pow_2(m)); // next power of 2 + + if (mp2 <= 64) { + // m in [1, 64] + return cutlass_scaled_mm_dq_dispatcher( + out, a, b, a_scales, b_scales); + } else if (mp2 <= 128) { + // m in (64, 128] + return cutlass_scaled_mm_dq_dispatcher( + out, a, b, a_scales, b_scales); + } else { + // m in (128, inf) + return cutlass_scaled_mm_dq_dispatcher( + out, a, b, a_scales, b_scales); + } +} + void cutlass_scaled_mm_dq_sm90(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, @@ -223,24 +309,14 @@ void cutlass_scaled_mm_dq_sm90(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); - using TileShape = Shape<_128, _128, _128>; - using ClusterShape = Shape<_1, _2, _1>; - using KernelSchedule = - typename cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative; - using EpilogueSchedule = - typename cutlass::epilogue::TmaWarpSpecializedCooperative; - if (out.dtype() == torch::kBFloat16) { - return cutlass_scaled_mm_dq_dispatcher< - cutlass_3x_gemm>( + return cutlass_scaled_mm_dq_sm90_fp8_dispatch( out, a, b, a_scales, b_scales); } else { TORCH_CHECK(out.dtype() == torch::kFloat16); - - return cutlass_scaled_mm_dq_dispatcher< - cutlass_3x_gemm>( + return cutlass_scaled_mm_dq_sm90_fp8_dispatch( out, a, b, a_scales, b_scales); } } diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py index 5a18dd5c..079d9650 100644 --- a/tests/kernels/test_cutlass.py +++ b/tests/kernels/test_cutlass.py @@ -82,7 +82,7 @@ def cutlass_int8_gemm_helper(m: int, assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0) -@pytest.mark.parametrize("m", [512, 222, 33, 1]) +@pytest.mark.parametrize("m", [512, 222, 100, 33, 1]) @pytest.mark.parametrize("n", [2048, 256, 1024]) @pytest.mark.parametrize("k", [128, 496, 1024]) @pytest.mark.parametrize("per_act_token", [True, False])