From 80d7be70a54b98194819b96346034f04d44101fd Mon Sep 17 00:00:00 2001 From: long0x0 Date: Sun, 29 Dec 2024 15:49:53 +0800 Subject: [PATCH] =?UTF-8?q?=E7=AE=80=E5=8D=95=E4=BF=AE=E6=94=B9=E4=B8=80?= =?UTF-8?q?=E4=B8=8B=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- csrc/fp8_vec.cu | 1 + csrc/type_utils.h | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+) create mode 100644 csrc/fp8_vec.cu create mode 100644 csrc/type_utils.h diff --git a/csrc/fp8_vec.cu b/csrc/fp8_vec.cu new file mode 100644 index 0000000..db7b3a4 --- /dev/null +++ b/csrc/fp8_vec.cu @@ -0,0 +1 @@ +#include \ No newline at end of file diff --git a/csrc/type_utils.h b/csrc/type_utils.h new file mode 100644 index 0000000..8dcf7ba --- /dev/null +++ b/csrc/type_utils.h @@ -0,0 +1,35 @@ +#ifndef TYPE_UTILS_H +#define TYPE_UTILS_H +#include +#include +#define FP16 __half +#define BF16 __nv_bfloat16 +template +__device__ dest_type fi_cast(src_type a) +{ +} +template <> +__device__ float fi_cast(BF16 a) +{ + return __bfloat162float(a); +} + +template <> +__device__ float fi_cast(FP16 a) +{ + return __half2float(a); +} + +template <> +__device__ BF16 fi_cast(float a) +{ + return __float2bfloat16(a); +} + +template <> +__device__ FP16 fi_cast(float a) +{ + return __float2half(a); +} + +#endif \ No newline at end of file