diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/fp8/common.cu
index 8c5b693b..6120086d 100644
--- a/csrc/quantization/fp8/common.cu
+++ b/csrc/quantization/fp8/common.cu
@@ -23,8 +23,8 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
 
 template <typename scalar_t>
 __device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(
-    const scalar_t val, const float scale) {
-  float x = static_cast<float>(val) / scale;
+    const scalar_t val, const float inverted_scale) {
+  float x = static_cast<float>(val) * inverted_scale;
   float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
   return static_cast<c10::Float8_e4m3fn>(r);
 }
@@ -71,15 +71,56 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
   }
 }
 
+template <typename scalar_t>
+struct __align__(8) vec4_t {
+  scalar_t x;
+  scalar_t y;
+  scalar_t z;
+  scalar_t w;
+};
+
+typedef struct __align__(4) {
+  c10::Float8_e4m3fn x;
+  c10::Float8_e4m3fn y;
+  c10::Float8_e4m3fn z;
+  c10::Float8_e4m3fn w;
+}
+float8x4_t;
+
 template <typename scalar_t>
 __global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
                                         const scalar_t* __restrict__ input,
                                         const float* __restrict__ scale,
                                         int64_t num_elems) {
-  int i = blockDim.x * blockIdx.x + threadIdx.x;
-  while (i < num_elems) {
-    out[i] = scaled_fp8_conversion(input[i], *scale);
-    i += blockDim.x * gridDim.x;
+  int tid = blockDim.x * blockIdx.x + threadIdx.x;
+
+  // Invert the scale so that we can use multiplications to avoid expensive
+  // division.
+  const float inverted_scale = 1.0f / (*scale);
+
+  // Vectorized input/output to better utilize memory bandwidth.
+  const vec4_t<scalar_t>* vectorized_in =
+      reinterpret_cast<const vec4_t<scalar_t>*>(input);
+  float8x4_t* vectorized_out = reinterpret_cast<float8x4_t*>(out);
+
+  int num_vec_elems = num_elems >> 2;
+
+#pragma unroll 4
+  for (int i = tid; i < num_vec_elems; i += blockDim.x * gridDim.x) {
+    vec4_t<scalar_t> in_vec = vectorized_in[i];
+    float8x4_t out_vec;
+
+    out_vec.x = scaled_fp8_conversion(in_vec.x, inverted_scale);
+    out_vec.y = scaled_fp8_conversion(in_vec.y, inverted_scale);
+    out_vec.z = scaled_fp8_conversion(in_vec.z, inverted_scale);
+    out_vec.w = scaled_fp8_conversion(in_vec.w, inverted_scale);
+    vectorized_out[i] = out_vec;
+  }
+
+  // Handle the remaining elements if num_elems is not divisible by 4
+  for (int i = num_vec_elems * 4 + tid; i < num_elems;
+       i += blockDim.x * gridDim.x) {
+    out[i] = scaled_fp8_conversion(input[i], inverted_scale);
   }
 }
 
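For readers skimming the kernel change above: the only numerical difference is that the scale is inverted once and every element is multiplied by that reciprocal instead of being divided by the scale, with the result clamped to the E4M3 range before the cast. Below is a minimal PyTorch sketch of that per-element math (the helper name and the amax-based scale are illustrative assumptions, not part of the diff), checking that the multiply and divide paths agree up to FP8 rounding.

```python
import torch

# Sketch (not part of the diff) of the per-element math the rewritten kernel
# performs: invert the scale once, multiply each element by the reciprocal,
# clamp to the E4M3 range, then cast to FP8.
def scaled_fp8_conversion_ref(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    inverted_scale = 1.0 / scale  # hoisted once, as the kernel does before its loops
    finfo = torch.finfo(torch.float8_e4m3fn)
    r = (x.float() * inverted_scale).clamp(min=finfo.min, max=finfo.max)
    return r.to(torch.float8_e4m3fn)

# Illustrative per-tensor scale (amax / E4M3 max); the real kernel receives the
# scale from its caller.
x = (torch.randn(11, 11) * 13).to(torch.float16)
scale = x.float().abs().max() / torch.finfo(torch.float8_e4m3fn).max

q_mul = scaled_fp8_conversion_ref(x, scale)                                # multiply path (new)
q_div = (x.float() / scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)  # divide path (old)

# The two paths agree up to FP8 rounding; compare after dequantizing.
torch.testing.assert_close(q_mul.float() * scale, q_div.float() * scale,
                           rtol=0.3, atol=1e-2)
```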
diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py
index fccce7f7..7cb65326 100644
--- a/tests/quantization/test_fp8.py
+++ b/tests/quantization/test_fp8.py
@@ -5,6 +5,7 @@ Run `pytest tests/quantization/test_fp8.py --forked`.
 import pytest
 import torch
 
+from vllm._custom_ops import scaled_fp8_quant
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
 
@@ -22,3 +23,49 @@ def test_load_fp16_model(vllm_runner) -> None:
     fc1 = model.model.decoder.layers[0].fc1
     assert isinstance(fc1.quant_method, Fp8LinearMethod)
     assert fc1.weight.dtype == torch.float8_e4m3fn
+
+
+@pytest.mark.skipif(
+    capability < QUANTIZATION_METHODS["fp8"].get_min_capability(),
+    reason="FP8 is not supported on this GPU type.")
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+def test_scaled_fp8_quant(dtype) -> None:
+
+    def quantize_ref(tensor, inv_scale):
+        # The reference implementation that fully aligns with
+        # the kernel being tested.
+        finfo = torch.finfo(torch.float8_e4m3fn)
+        scale = inv_scale.reciprocal()
+        qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min,
+                                                           max=finfo.max)
+        qweight = qweight.to(torch.float8_e4m3fn)
+        return qweight
+
+    def per_tensor_dequantize(tensor, inv_scale, dtype):
+        fake_qweight = tensor.to(dtype)
+        dq_weight = fake_qweight * inv_scale
+        return dq_weight
+
+    # Note that we use a shape % 4 != 0 to cover edge cases,
+    # because scaled_fp8_quant is vectorized by 4.
+    x = (torch.randn(size=(11, 11), device="cuda") * 13).to(dtype)
+
+    # Dynamic quantization
+    ref_y, inv_scale = scaled_fp8_quant(x, None)
+    ref_y = per_tensor_dequantize(ref_y, inv_scale, dtype)
+
+    # Reference dynamic quantization
+    y = quantize_ref(x, inv_scale)
+    assert torch.allclose(ref_y, per_tensor_dequantize(y, inv_scale, dtype))
+
+    # Static quantization
+    y, _ = scaled_fp8_quant(x, inv_scale)
+    assert torch.allclose(ref_y, per_tensor_dequantize(y, inv_scale, dtype))
+
+    # Padding
+    y, _ = scaled_fp8_quant(x, inv_scale, batch_dim_padding=17)
+    assert y.shape[0] == 17
+    assert torch.allclose(
+        ref_y,
+        per_tensor_dequantize(torch.narrow(y, 0, 0, x.shape[0]), inv_scale,
+                              dtype))
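A quick worked example of the edge case the 11x11 test shape is meant to hit, given the kernel above: 121 elements split into 30 groups of four for the vectorized loop, leaving one element that only the scalar tail loop touches. The names below are illustrative, not from the diff.

```python
# How the kernel above partitions the 11x11 test input.
num_elems = 11 * 11                    # 121 elements in the test tensor
num_vec_elems = num_elems >> 2         # groups of 4 handled by the vectorized loop
tail = num_elems - num_vec_elems * 4   # elements left for the scalar tail loop
assert (num_vec_elems, tail) == (30, 1)
```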