diff --git a/csrc/fp8_vec.cu b/csrc/fp8_vec.cu new file mode 100644 index 0000000..db7b3a4 --- /dev/null +++ b/csrc/fp8_vec.cu @@ -0,0 +1 @@ +#include \ No newline at end of file diff --git a/csrc/type_utils.h b/csrc/type_utils.h new file mode 100644 index 0000000..8dcf7ba --- /dev/null +++ b/csrc/type_utils.h @@ -0,0 +1,35 @@ +#ifndef TYPE_UTILS_H +#define TYPE_UTILS_H +#include +#include +#define FP16 __half +#define BF16 __nv_bfloat16 +template +__device__ dest_type fi_cast(src_type a) +{ +} +template <> +__device__ float fi_cast(BF16 a) +{ + return __bfloat162float(a); +} + +template <> +__device__ float fi_cast(FP16 a) +{ + return __half2float(a); +} + +template <> +__device__ BF16 fi_cast(float a) +{ + return __float2bfloat16(a); +} + +template <> +__device__ FP16 fi_cast(float a) +{ + return __float2half(a); +} + +#endif \ No newline at end of file