diff --git a/.buildkite/test-template.j2 b/.buildkite/test-template.j2 index b3551129..7c709b60 100644 --- a/.buildkite/test-template.j2 +++ b/.buildkite/test-template.j2 @@ -5,7 +5,7 @@ steps: - label: ":docker: build image" commands: - - "docker build --tag {{ docker_image }} --target test --progress plain ." + - "docker build --build-arg max_jobs=16 --tag {{ docker_image }} --target test --progress plain ." - "docker push {{ docker_image }}" env: DOCKER_BUILDKIT: "1" diff --git a/csrc/punica/bgmv/bgmv_all.cu b/csrc/punica/bgmv/bgmv_all.cu deleted file mode 100644 index 2502a67e..00000000 --- a/csrc/punica/bgmv/bgmv_all.cu +++ /dev/null @@ -1,21 +0,0 @@ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half) -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_bfloat16) -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16) -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_half) -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_half) -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_half) -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_bfloat16) -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_bfloat16) -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half) -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_bfloat16) -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16) -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_half) -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half) -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_bfloat16) -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16) -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_half) -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_half) -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu b/csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu new file mode 100644 index 00000000..c642e949 --- /dev/null +++ b/csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu @@ -0,0 +1,4 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu b/csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu new file mode 100644 index 00000000..e8202dff --- /dev/null +++ b/csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu @@ -0,0 +1,4 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_half) diff --git a/csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu b/csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu new file mode 100644 index 00000000..3e7cf31d --- /dev/null +++ b/csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu @@ -0,0 +1,4 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu b/csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu new file mode 100644 index 00000000..68277fa6 --- /dev/null +++ b/csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu @@ -0,0 +1,4 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_half) diff --git a/csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu b/csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu new file mode 100644 index 00000000..0607cebf --- /dev/null +++ b/csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu @@ -0,0 +1,4 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu b/csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu new file mode 100644 index 00000000..3b7531b8 --- /dev/null +++ b/csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu @@ -0,0 +1,4 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_half) diff --git a/csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu b/csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu new file mode 100644 index 00000000..b3b74aa3 --- /dev/null +++ b/csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu @@ -0,0 +1,4 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu b/csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu new file mode 100644 index 00000000..3cc87f5d --- /dev/null +++ b/csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu @@ -0,0 +1,4 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_half) diff --git a/csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu b/csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu new file mode 100644 index 00000000..9eda98bd --- /dev/null +++ b/csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu @@ -0,0 +1,4 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu b/csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu new file mode 100644 index 00000000..f1db6df5 --- /dev/null +++ b/csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu @@ -0,0 +1,4 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half) diff --git a/csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu b/csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu new file mode 100644 index 00000000..060f9ebb --- /dev/null +++ b/csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu @@ -0,0 +1,4 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu b/csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu new file mode 100644 index 00000000..c01ddd00 --- /dev/null +++ b/csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu @@ -0,0 +1,4 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half) diff --git a/csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu b/csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu new file mode 100644 index 00000000..f45183ff --- /dev/null +++ b/csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu @@ -0,0 +1,4 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu b/csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu new file mode 100644 index 00000000..b37e4457 --- /dev/null +++ b/csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu @@ -0,0 +1,4 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_half) diff --git a/csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu b/csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu new file mode 100644 index 00000000..06718cbb --- /dev/null +++ b/csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu @@ -0,0 +1,4 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu b/csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu new file mode 100644 index 00000000..40977434 --- /dev/null +++ b/csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu @@ -0,0 +1,4 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half) diff --git a/csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu b/csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu new file mode 100644 index 00000000..41fb0e45 --- /dev/null +++ b/csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu @@ -0,0 +1,4 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu b/csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu new file mode 100644 index 00000000..50b7ead9 --- /dev/null +++ b/csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu @@ -0,0 +1,4 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_half) diff --git a/csrc/punica/bgmv/generator.py b/csrc/punica/bgmv/generator.py new file mode 100644 index 00000000..66de56d7 --- /dev/null +++ b/csrc/punica/bgmv/generator.py @@ -0,0 +1,27 @@ +DTYPES = ["fp16", "bf16", "fp32"] +DTYPE_MAP = { + "fp16": "nv_half", + "bf16": "nv_bfloat16", + "fp32": "float", +} + +TEMPLATE = """ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, {input_dtype}, {output_dtype}, {weight_dtype}) +""".lstrip() + +for input_dtype in DTYPES: + for output_dtype in DTYPES: + for weight_dtype in DTYPES: + if weight_dtype == "fp32": + # FP32 weights are not supported. + continue + kernel_definition = TEMPLATE.format( + input_dtype=DTYPE_MAP[input_dtype], + output_dtype=DTYPE_MAP[output_dtype], + weight_dtype=DTYPE_MAP[weight_dtype]) + filename = f"bgmv_{input_dtype}_{output_dtype}_{weight_dtype}.cu" + with open(filename, "w") as f: + f.write(kernel_definition)