Speed up Punica compilation (#2632)
This commit is contained in:
parent
5f036d2bcc
commit
f8ecb84c02
@ -5,7 +5,7 @@
|
|||||||
steps:
|
steps:
|
||||||
- label: ":docker: build image"
|
- label: ":docker: build image"
|
||||||
commands:
|
commands:
|
||||||
- "docker build --tag {{ docker_image }} --target test --progress plain ."
|
- "docker build --build-arg max_jobs=16 --tag {{ docker_image }} --target test --progress plain ."
|
||||||
- "docker push {{ docker_image }}"
|
- "docker push {{ docker_image }}"
|
||||||
env:
|
env:
|
||||||
DOCKER_BUILDKIT: "1"
|
DOCKER_BUILDKIT: "1"
|
||||||
|
|||||||
@ -1,21 +0,0 @@
|
|||||||
#include "bgmv_config.h"
|
|
||||||
#include "bgmv_impl.cuh"
|
|
||||||
|
|
||||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half)
|
|
||||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_bfloat16)
|
|
||||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16)
|
|
||||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_half)
|
|
||||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_half)
|
|
||||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_half)
|
|
||||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_bfloat16)
|
|
||||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_bfloat16)
|
|
||||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half)
|
|
||||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_bfloat16)
|
|
||||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16)
|
|
||||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_half)
|
|
||||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half)
|
|
||||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_bfloat16)
|
|
||||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16)
|
|
||||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_half)
|
|
||||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_half)
|
|
||||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_bfloat16)
|
|
||||||
4
csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu
Normal file
4
csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
#include "bgmv_config.h"
|
||||||
|
#include "bgmv_impl.cuh"
|
||||||
|
|
||||||
|
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16)
|
||||||
4
csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu
Normal file
4
csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
#include "bgmv_config.h"
|
||||||
|
#include "bgmv_impl.cuh"
|
||||||
|
|
||||||
|
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_half)
|
||||||
4
csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu
Normal file
4
csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
#include "bgmv_config.h"
|
||||||
|
#include "bgmv_impl.cuh"
|
||||||
|
|
||||||
|
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_bfloat16)
|
||||||
4
csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu
Normal file
4
csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
#include "bgmv_config.h"
|
||||||
|
#include "bgmv_impl.cuh"
|
||||||
|
|
||||||
|
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_half)
|
||||||
4
csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu
Normal file
4
csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
#include "bgmv_config.h"
|
||||||
|
#include "bgmv_impl.cuh"
|
||||||
|
|
||||||
|
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16)
|
||||||
4
csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu
Normal file
4
csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
#include "bgmv_config.h"
|
||||||
|
#include "bgmv_impl.cuh"
|
||||||
|
|
||||||
|
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_half)
|
||||||
4
csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu
Normal file
4
csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
#include "bgmv_config.h"
|
||||||
|
#include "bgmv_impl.cuh"
|
||||||
|
|
||||||
|
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_bfloat16)
|
||||||
4
csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu
Normal file
4
csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
#include "bgmv_config.h"
|
||||||
|
#include "bgmv_impl.cuh"
|
||||||
|
|
||||||
|
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_half)
|
||||||
4
csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu
Normal file
4
csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
#include "bgmv_config.h"
|
||||||
|
#include "bgmv_impl.cuh"
|
||||||
|
|
||||||
|
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_bfloat16)
|
||||||
4
csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu
Normal file
4
csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
#include "bgmv_config.h"
|
||||||
|
#include "bgmv_impl.cuh"
|
||||||
|
|
||||||
|
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half)
|
||||||
4
csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu
Normal file
4
csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
#include "bgmv_config.h"
|
||||||
|
#include "bgmv_impl.cuh"
|
||||||
|
|
||||||
|
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_bfloat16)
|
||||||
4
csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu
Normal file
4
csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
#include "bgmv_config.h"
|
||||||
|
#include "bgmv_impl.cuh"
|
||||||
|
|
||||||
|
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half)
|
||||||
4
csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu
Normal file
4
csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
#include "bgmv_config.h"
|
||||||
|
#include "bgmv_impl.cuh"
|
||||||
|
|
||||||
|
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16)
|
||||||
4
csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu
Normal file
4
csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
#include "bgmv_config.h"
|
||||||
|
#include "bgmv_impl.cuh"
|
||||||
|
|
||||||
|
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_half)
|
||||||
4
csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu
Normal file
4
csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
#include "bgmv_config.h"
|
||||||
|
#include "bgmv_impl.cuh"
|
||||||
|
|
||||||
|
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_bfloat16)
|
||||||
4
csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu
Normal file
4
csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
#include "bgmv_config.h"
|
||||||
|
#include "bgmv_impl.cuh"
|
||||||
|
|
||||||
|
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half)
|
||||||
4
csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu
Normal file
4
csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
#include "bgmv_config.h"
|
||||||
|
#include "bgmv_impl.cuh"
|
||||||
|
|
||||||
|
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_bfloat16)
|
||||||
4
csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu
Normal file
4
csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
#include "bgmv_config.h"
|
||||||
|
#include "bgmv_impl.cuh"
|
||||||
|
|
||||||
|
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_half)
|
||||||
27
csrc/punica/bgmv/generator.py
Normal file
27
csrc/punica/bgmv/generator.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
# Generator for the per-dtype BGMV kernel instantiation files.
#
# For every (input, output, weight) dtype combination that is emitted, this
# script writes a file named bgmv_<input>_<output>_<weight>.cu into the
# current working directory. Each file contains the two includes and a single
# FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, ...) line from TEMPLATE.

# Short dtype tags, used both for iteration and in the output file names.
DTYPES = ["fp16", "bf16", "fp32"]

# Short dtype tag -> type name substituted into the generated source.
DTYPE_MAP = {
    "fp16": "nv_half",
    "bf16": "nv_bfloat16",
    "fp32": "float",
}

# Body of each generated file; lstrip() drops the newline that follows the
# opening triple quote so the output starts directly with the first #include.
TEMPLATE = """
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, {input_dtype}, {output_dtype}, {weight_dtype})
""".lstrip()

for input_dtype in DTYPES:
    for output_dtype in DTYPES:
        for weight_dtype in DTYPES:
            # FP32 weights are not supported, so nothing is written for them.
            if weight_dtype != "fp32":
                substitutions = {
                    "input_dtype": DTYPE_MAP[input_dtype],
                    "output_dtype": DTYPE_MAP[output_dtype],
                    "weight_dtype": DTYPE_MAP[weight_dtype],
                }
                source_text = TEMPLATE.format(**substitutions)
                target_name = f"bgmv_{input_dtype}_{output_dtype}_{weight_dtype}.cu"
                # Mode "w" overwrites any previously generated file.
                with open(target_name, "w") as out_file:
                    out_file.write(source_text)
|
||||||
Loading…
Reference in New Issue
Block a user