简单修改一下。
This commit is contained in:
parent
0a6b5493fa
commit
80d7be70a5
1
csrc/fp8_vec.cu
Normal file
1
csrc/fp8_vec.cu
Normal file
@ -0,0 +1 @@
|
||||
#include <cuda_fp8.h>
|
||||
35
csrc/type_utils.h
Normal file
35
csrc/type_utils.h
Normal file
@ -0,0 +1,35 @@
|
||||
#ifndef TYPE_UTILS_H
|
||||
#define TYPE_UTILS_H
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_bf16.h>
|
||||
#define FP16 __half
|
||||
#define BF16 __nv_bfloat16
|
||||
template <typename src_type, typename dest_type>
|
||||
__device__ dest_type fi_cast(src_type a)
|
||||
{
|
||||
}
|
||||
template <>
|
||||
__device__ float fi_cast<BF16, float>(BF16 a)
|
||||
{
|
||||
return __bfloat162float(a);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ float fi_cast<FP16, float>(FP16 a)
|
||||
{
|
||||
return __half2float(a);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ BF16 fi_cast<float, BF16>(float a)
|
||||
{
|
||||
return __float2bfloat16(a);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ FP16 fi_cast<float, FP16>(float a)
|
||||
{
|
||||
return __float2half(a);
|
||||
}
|
||||
|
||||
#endif
|
||||
Loading…
Reference in New Issue
Block a user