diff --git a/csrc/quantization/gptq/q_gemm.cu b/csrc/quantization/gptq/q_gemm.cu
index 6d070c65..eb0d75f1 100644
--- a/csrc/quantization/gptq/q_gemm.cu
+++ b/csrc/quantization/gptq/q_gemm.cu
@@ -28,6 +28,7 @@ namespace gptq {
 #define DIVIDE(x, size) (((x) + (size) - 1) / (size))
 
 #if defined(USE_ROCM)
+#include <hipblas/hipblas.h>
 __host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle,
                                                                hipblasOperation_t transA,
                                                                hipblasOperation_t transB,
@@ -520,12 +521,21 @@ __global__ void gemm_half_q_half_alt_kernel(
       zeros_tmp[tmp_k] = zero;
     }
     for (int m = 0; m < b_end; m++) {
+#ifndef USE_ROCM
       res2 = {};
+#else
+      res2.x = __half_as_ushort(__float2half(0));
+      res2.y = __half_as_ushort(__float2half(0));
+#endif
       res2 = __hfma2(__hfma2(deq2[(tmp >> 0) & 0xff][off], scales_tmp[0], zeros_tmp[0]), blockvec[m][k + 0], res2);
       res2 = __hfma2(__hfma2(deq2[(tmp >> 8) & 0xff][off], scales_tmp[1], zeros_tmp[1]), blockvec[m][k + 1], res2);
       res2 = __hfma2(__hfma2(deq2[(tmp >> 16) & 0xff][off], scales_tmp[2], zeros_tmp[2]), blockvec[m][k + 2], res2);
       res2 = __hfma2(__hfma2(deq2[(tmp >> 24) & 0xff][off], scales_tmp[3], zeros_tmp[3]), blockvec[m][k + 3], res2);
+#ifndef USE_ROCM
       res[m] = __hadd(res[m], __hadd(res2.x, res2.y));
+#else
+      res[m] = __hadd(res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)));
+#endif
     }
     i += width;
     k += 4;
diff --git a/docs/source/getting_started/amd-installation.rst b/docs/source/getting_started/amd-installation.rst
index 6fb072a0..181c970e 100644
--- a/docs/source/getting_started/amd-installation.rst
+++ b/docs/source/getting_started/amd-installation.rst
@@ -116,6 +116,7 @@ Alternatively, if you plan to install vLLM-ROCm on a local machine or start from
 
 - `ROCm `_
 - `Pytorch `_
+- `hipBLAS `_
 
 1. Install `flash attention for ROCm `_
 
diff --git a/setup.py b/setup.py
index 45a18776..811d494e 100644
--- a/setup.py
+++ b/setup.py
@@ -219,13 +219,13 @@ vllm_extension_sources = [
     "csrc/activation_kernels.cu",
     "csrc/layernorm_kernels.cu",
     "csrc/quantization/squeezellm/quant_cuda_kernel.cu",
+    "csrc/quantization/gptq/q_gemm.cu",
     "csrc/cuda_utils_kernels.cu",
     "csrc/pybind.cpp",
 ]
 
 if _is_cuda():
     vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu")
-    vllm_extension_sources.append("csrc/quantization/gptq/q_gemm.cu")
 
 vllm_extension = CUDAExtension(
     name="vllm._C",
diff --git a/vllm/config.py b/vllm/config.py
index 353189f6..ff9a1308 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -112,24 +112,20 @@ class ModelConfig:
         supported_load_format = [
             "auto", "pt", "safetensors", "npcache", "dummy"
         ]
-        rocm_not_supported_load_format = ["safetensors"]
+        rocm_not_supported_load_format = []
         if load_format not in supported_load_format:
             raise ValueError(
                 f"Unknown load format: {self.load_format}. Must be one of "
                 "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.")
-        if is_hip():
-            if load_format in ["safetensors"]:
-                rocm_supported_load_format = [
-                    f for f in supported_load_format
-                    if (f not in rocm_not_supported_load_format)
-                ]
-                raise ValueError(
-                    f"load format \'{load_format}\' is not supported in ROCm. "
-                    f"Supported load format are "
-                    f"{rocm_supported_load_format}")
-            # Force ROCm to load from pt weights if nothing specific is set
-            if load_format == "auto":
-                load_format = "pt"
+        if is_hip() and load_format in rocm_not_supported_load_format:
+            rocm_supported_load_format = [
+                f for f in supported_load_format
+                if (f not in rocm_not_supported_load_format)
+            ]
+            raise ValueError(
+                f"load format \'{load_format}\' is not supported in ROCm. "
+                f"Supported load format are "
+                f"{rocm_supported_load_format}")
 
         # TODO: Remove this check once HF updates the pt weights of Mixtral.
         architectures = getattr(self.hf_config, "architectures", [])
@@ -149,7 +145,7 @@ class ModelConfig:
 
     def _verify_quantization(self) -> None:
         supported_quantization = ["awq", "gptq", "squeezellm"]
-        rocm_not_supported_quantization = ["awq", "gptq"]
+        rocm_not_supported_quantization = ["awq"]
         if self.quantization is not None:
             self.quantization = self.quantization.lower()