diff --git a/csrc/quantization/fp8/fp8_cuda_kernels.cu b/csrc/quantization/fp8/fp8_cuda_kernels.cu
index 2477051e..b9c5d392 100644
--- a/csrc/quantization/fp8/fp8_cuda_kernels.cu
+++ b/csrc/quantization/fp8/fp8_cuda_kernels.cu
@@ -17,6 +17,18 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
   return old;
 }
 
+// std::numeric_limits maximum of c10::Float8_e4m3fn; used as the clamp bound.
+#define FP8_E4M3_MAX std::numeric_limits<c10::Float8_e4m3fn>::max()
+
+// Divides val by scale in float, clamps the quotient to
+// [-FP8_E4M3_MAX, FP8_E4M3_MAX], and converts the result to FP8 E4M3.
+template<typename scalar_t>
+__device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(const scalar_t val, const float scale) {
+  float x = static_cast<float>(val) / scale;
+  float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
+  return static_cast<c10::Float8_e4m3fn>(r);
+}
+
 // Compute the absolute maximum m of the input tensor and store
 // m / float8_e4m3::max() in *scale. Each thread block performs a
 // reduction tree and the memory in scale is atomically updated.
@@ -67,7 +79,7 @@ __global__ void scaled_fp8_quant_kernel(
     int64_t num_elems) {
   int i = blockDim.x * blockIdx.x + threadIdx.x;
   while (i < num_elems) {
-    out[i] = static_cast<c10::Float8_e4m3fn>(input[i] / *scale);
+    out[i] = scaled_fp8_conversion(input[i], *scale);
     i += blockDim.x * gridDim.x;
   }
 }