From a98187cf7227695819e199e2e3ad35be0a9a84f3 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 6 May 2024 17:39:28 -0700 Subject: [PATCH] [Kernel] Make static FP8 scaling more robust (#4570) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously FP8 static scaling works if the scales are overestimating the maxima of all activation tensors during computation. However this will not always be the case even if the scales were calibrated very carefully. For example, with the activations in my checkpoint https://huggingface.co/pcmoritz/Mixtral-8x7B-v0.1-fp8-act-scale (which was calibrated on https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k), I'm getting the following mostly random performance on MMLU: | Groups |Version|Filter|n-shot|Metric|Value | |Stderr| |------------------|-------|------|-----:|------|-----:|---|-----:| |mmlu |N/A |none | 0|acc |0.2295|± |0.0035| | - humanities |N/A |none | 5|acc |0.2421|± |0.0062| | - other |N/A |none | 5|acc |0.2398|± |0.0076| | - social_sciences|N/A |none | 5|acc |0.2171|± |0.0074| | - stem |N/A |none | 5|acc |0.2125|± |0.0073| With the fix in this PR where the scaled activations are clamped between [-std::numeric_limits::max(), std::numeric_limits::max()] to make sure there are no NaNs, the performance is | Groups |Version|Filter|n-shot|Metric|Value | |Stderr| |------------------|-------|------|-----:|------|-----:|---|-----:| |mmlu |N/A |none | 0|acc |0.7008|± |0.0036| | - humanities |N/A |none | 5|acc |0.6453|± |0.0065| | - other |N/A |none | 5|acc |0.7692|± |0.0072| | - social_sciences|N/A |none | 5|acc |0.8083|± |0.0070| | - stem |N/A |none | 5|acc |0.6115|± |0.0083| This is not perfect yet but is getting very close to the FP16 / dynamic activation scale performance. --- csrc/quantization/fp8/fp8_cuda_kernels.cu | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/csrc/quantization/fp8/fp8_cuda_kernels.cu b/csrc/quantization/fp8/fp8_cuda_kernels.cu index 2477051e..b9c5d392 100644 --- a/csrc/quantization/fp8/fp8_cuda_kernels.cu +++ b/csrc/quantization/fp8/fp8_cuda_kernels.cu @@ -17,6 +17,15 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) { return old; } +#define FP8_E4M3_MAX std::numeric_limits::max() + +template +__device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(const scalar_t val, const float scale) { + float x = static_cast(val) / scale; + float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX)); + return static_cast(r); +} + // Compute the absolute maximum m of the input tensor and store // m / float8_e4m3::max() in *scale. Each thread block performs a // reduction tree and the memory in scale is atomically updated. @@ -67,7 +76,7 @@ __global__ void scaled_fp8_quant_kernel( int64_t num_elems) { int i = blockDim.x * blockIdx.x + threadIdx.x; while (i < num_elems) { - out[i] = static_cast(input[i] / *scale); + out[i] = scaled_fp8_conversion(input[i], *scale); i += blockDim.x * gridDim.x; } }