简单修改一下。

This commit is contained in:
long0x0 2024-12-29 15:49:53 +08:00
parent 0a6b5493fa
commit 80d7be70a5
2 changed files with 36 additions and 0 deletions

1
csrc/fp8_vec.cu Normal file
View File

@ -0,0 +1 @@
#include <cuda_fp8.h>

35
csrc/type_utils.h Normal file
View File

@ -0,0 +1,35 @@
#ifndef TYPE_UTILS_H
#define TYPE_UTILS_H
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#define FP16 __half
#define BF16 __nv_bfloat16
/**
 * Generic device-side scalar conversion from src_type to dest_type.
 *
 * dest_type cannot be deduced from the argument, so callers spell out both
 * template arguments: fi_cast<src_type, dest_type>(value).
 *
 * This primary template is the fallback for type pairs that have no dedicated
 * explicit specialization in this header. It forwards to static_cast, so a
 * pair with no valid conversion is rejected at compile time.
 *
 * @tparam src_type   type of the incoming value
 * @tparam dest_type  type of the returned value
 * @param  a          value to convert
 * @return a converted to dest_type
 */
template <typename src_type, typename dest_type>
__device__ __forceinline__ dest_type fi_cast(src_type a)
{
    // The body used to be empty; flowing off the end of a non-void function
    // is undefined behaviour, so always return a well-defined value.
    return static_cast<dest_type>(a);
}
/**
 * fi_cast specialization: bfloat16 -> float.
 *
 * Marked __forceinline__ because a full explicit specialization is an
 * ordinary function definition: without inline linkage, including this header
 * from several translation units yields duplicate definitions.
 *
 * @param a  bfloat16 value to widen
 * @return result of __bfloat162float(a)
 */
template <>
__device__ __forceinline__ float fi_cast<BF16, float>(BF16 a)
{
    return __bfloat162float(a);
}
/**
 * fi_cast specialization: half precision -> float.
 *
 * Marked __forceinline__ because a full explicit specialization is an
 * ordinary function definition: without inline linkage, including this header
 * from several translation units yields duplicate definitions.
 *
 * @param a  half-precision value to widen
 * @return result of __half2float(a)
 */
template <>
__device__ __forceinline__ float fi_cast<FP16, float>(FP16 a)
{
    return __half2float(a);
}
/**
 * fi_cast specialization: float -> bfloat16.
 *
 * Marked __forceinline__ because a full explicit specialization is an
 * ordinary function definition: without inline linkage, including this header
 * from several translation units yields duplicate definitions.
 *
 * @param a  float value to narrow
 * @return result of __float2bfloat16(a); the CUDA math API documents this
 *         intrinsic as rounding to nearest even
 */
template <>
__device__ __forceinline__ BF16 fi_cast<float, BF16>(float a)
{
    return __float2bfloat16(a);
}
/**
 * fi_cast specialization: float -> half precision.
 *
 * Marked __forceinline__ because a full explicit specialization is an
 * ordinary function definition: without inline linkage, including this header
 * from several translation units yields duplicate definitions.
 *
 * @param a  float value to narrow
 * @return result of __float2half(a); the CUDA math API documents this
 *         intrinsic as rounding to nearest even
 */
template <>
__device__ __forceinline__ FP16 fi_cast<float, FP16>(float a)
{
    return __float2half(a);
}
#endif