torch_ext/csrc/quantize.cu
2025-03-27 03:44:28 +08:00

36 lines
1.1 KiB
Plaintext

#include <cuda_fp16.h>
#include <cuda_fp8.h>
// Map a float to an int whose signed-integer ordering matches the float
// ordering (for non-NaN inputs), so a block-wide float max can be reduced
// with the integer atomicMax available on shared memory.
static __device__ __forceinline__ int floatToOrderedInt(float f)
{
    int bits = __float_as_int(f);
    // Negative floats compare in reverse when read as signed ints; flipping the
    // 31 magnitude bits restores the order while keeping the sign bit set.
    return bits >= 0 ? bits : (bits ^ 0x7fffffff);
}

// Inverse of floatToOrderedInt (the negative-branch transform is an involution).
static __device__ __forceinline__ float orderedIntToFloat(int ordered)
{
    return __int_as_float(ordered >= 0 ? ordered : (ordered ^ 0x7fffffff));
}

/*
 * Quantize a row-major fp16 matrix to fp8 E5M2 after dividing every element by
 * the maximum value found in the region this block covers.
 *
 * Layout: element (x, y) lives at src[y * x_len + x] / dest[y * x_len + x].
 *
 * Region: a block with blockDim (bx, by) covers columns [0, bx*bx) and rows
 * [0, by*by), clipped to x_len / y_len. Elements outside that region are
 * neither read nor written.
 * NOTE(review): blockIdx/gridDim are not used, so every block covers the same
 * region. The kernel is presumably launched with a single block. The bx*bx
 * extent may have been meant as blockIdx-based tiling; confirm with the caller.
 * threadIdx.z is ignored (assumes blockDim.z == 1).
 *
 * Scaling: the divisor is the signed maximum of the region, matching __hmax
 * semantics (NaNs are ignored). The divisor is not written out anywhere.
 * NOTE(review): if that maximum is <= 0 (all-non-positive data), the division
 * flips signs or divides by zero. Absmax scaling is the usual FP8 convention;
 * confirm the intended scheme.
 *
 * Conversion saturates to the largest finite E5M2 value (__NV_SATFINITE).
 * Requires sm_53+ for half arithmetic (__hdiv).
 */
__global__ void quantize(const half *src, __nv_fp8_storage_t *dest, int x_len, int y_len)
{
    // Clip the block's region to the matrix bounds once, instead of
    // bounds-checking inside the loops.
    const int region_w = min(x_len, (int)(blockDim.x * blockDim.x));
    const int region_h = min(y_len, (int)(blockDim.y * blockDim.y));
    const float neg_inf = -__int_as_float(0x7f800000);

    // Phase 1: per-thread max. Block-stride loops make adjacent lanes touch
    // adjacent elements (coalesced). Together they cover exactly the same
    // region as one bx-by-by tile per thread.
    float local_max = neg_inf;
    for (int y = threadIdx.y; y < region_h; y += blockDim.y)
    {
        const size_t row = (size_t)y * (size_t)x_len;  // avoid int overflow
        for (int x = threadIdx.x; x < region_w; x += blockDim.x)
        {
            // fmaxf ignores a NaN operand, like __hmax; half->float is exact.
            local_max = fmaxf(local_max, __half2float(src[row + x]));
        }
    }

    // Phase 2: block-wide max. One thread seeds the shared slot. Barriers
    // separate seeding, the atomic updates and the final read. All threads
    // reach both barriers (no early returns).
    __shared__ int block_max_ordered;
    if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0)
    {
        block_max_ordered = floatToOrderedInt(neg_inf);
    }
    __syncthreads();
    atomicMax(&block_max_ordered, floatToOrderedInt(local_max));
    __syncthreads();
    // The max came from a half value, so converting back to half is exact.
    const half max_value = __float2half(orderedIntToFloat(block_max_ordered));

    // Phase 3: scale (in half precision, as before) and convert to E5M2.
    for (int y = threadIdx.y; y < region_h; y += blockDim.y)
    {
        const size_t row = (size_t)y * (size_t)x_len;
        for (int x = threadIdx.x; x < region_w; x += blockDim.x)
        {
            half tmp = __hdiv(src[row + x], max_value);
            dest[row + x] = __nv_cvt_halfraw_to_fp8(__nv_half_raw(tmp), __NV_SATFINITE, __NV_E5M2);
        }
    }
}