workaround of AWQ for Turing GPUs (#1252)
This commit is contained in:
parent
91fce82c6f
commit
8285736840
@ -16,7 +16,7 @@ namespace awq {
|
|||||||
|
|
||||||
__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
|
__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
|
||||||
{
|
{
|
||||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
|
||||||
assert(false);
|
assert(false);
|
||||||
#else
|
#else
|
||||||
uint4 result;
|
uint4 result;
|
||||||
|
|||||||
@ -29,7 +29,7 @@ __pack_half2(const half x, const half y) {
|
|||||||
|
|
||||||
__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
|
__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
|
||||||
{
|
{
|
||||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
|
||||||
assert(false);
|
assert(false);
|
||||||
#else
|
#else
|
||||||
static constexpr uint32_t ZERO = 0x0;
|
static constexpr uint32_t ZERO = 0x0;
|
||||||
@ -191,6 +191,39 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (int j_0_4 = 0; j_0_4 < 4; ++j_0_4) {
|
for (int j_0_4 = 0; j_0_4 < 4; ++j_0_4) {
|
||||||
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||||
|
{
|
||||||
|
__asm__ __volatile__(
|
||||||
|
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||||
|
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||||
|
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
|
||||||
|
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
__asm__ __volatile__(
|
||||||
|
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||||
|
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||||
|
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
||||||
|
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
__asm__ __volatile__(
|
||||||
|
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||||
|
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||||
|
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
|
||||||
|
: "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
__asm__ __volatile__(
|
||||||
|
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||||
|
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||||
|
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
||||||
|
: "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
||||||
|
}
|
||||||
|
#else
|
||||||
{
|
{
|
||||||
__asm__ __volatile__(
|
__asm__ __volatile__(
|
||||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
|
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
|
||||||
@ -206,6 +239,8 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
|
|||||||
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
||||||
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -226,7 +261,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
|
|||||||
|
|
||||||
__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
|
__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
|
||||||
{
|
{
|
||||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
|
||||||
assert(false);
|
assert(false);
|
||||||
#else
|
#else
|
||||||
static constexpr uint32_t ZERO = 0x0;
|
static constexpr uint32_t ZERO = 0x0;
|
||||||
@ -392,7 +427,39 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in
|
|||||||
|
|
||||||
for (int j_0_4 = 0; j_0_4 < 2; ++j_0_4)
|
for (int j_0_4 = 0; j_0_4 < 2; ++j_0_4)
|
||||||
{
|
{
|
||||||
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||||
|
{
|
||||||
|
__asm__ __volatile__(
|
||||||
|
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||||
|
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||||
|
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
|
||||||
|
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
__asm__ __volatile__(
|
||||||
|
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||||
|
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||||
|
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
||||||
|
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
__asm__ __volatile__(
|
||||||
|
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||||
|
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||||
|
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
|
||||||
|
: "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
__asm__ __volatile__(
|
||||||
|
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||||
|
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||||
|
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
||||||
|
: "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
||||||
|
}
|
||||||
|
#else
|
||||||
{
|
{
|
||||||
__asm__ __volatile__(
|
__asm__ __volatile__(
|
||||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
|
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
|
||||||
@ -408,6 +475,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in
|
|||||||
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
||||||
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -42,8 +42,8 @@ class AWQConfig(QuantizationConfig):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_min_capability(cls) -> int:
|
def get_min_capability(cls) -> int:
|
||||||
# The AWQ kernel only supports Ampere or newer GPUs.
|
# The AWQ kernel only supports Turing or newer GPUs.
|
||||||
return 80
|
return 75
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_config_filenames(cls) -> List[str]:
|
def get_config_filenames(cls) -> List[str]:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user