diff --git a/vllm/config.py b/vllm/config.py
index 7a941798..f66fb291 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -345,15 +345,6 @@ def _get_and_verify_dtype(
             # Casting between float16 and bfloat16 is allowed with a warning.
             logger.warning(f"Casting {config_dtype} to {torch_dtype}.")
 
-    # Check if the GPU supports the dtype.
-    if torch_dtype == torch.bfloat16:
-        compute_capability = torch.cuda.get_device_capability()
-        if compute_capability[0] < 8:
-            gpu_name = torch.cuda.get_device_name()
-            raise ValueError(
-                "Bfloat16 is only supported on GPUs with compute capability "
-                f"of at least 8.0. Your {gpu_name} GPU has compute capability "
-                f"{compute_capability[0]}.{compute_capability[1]}.")
     return torch_dtype
 
 
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 6fbc155d..5b0a60db 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -59,6 +59,8 @@ class Worker:
             raise ValueError("Invalid or unspecified rank.")
         torch.cuda.set_device(self.device)
 
+        _check_if_gpu_supports_dtype(self.model_config.dtype)
+
         # Initialize the distributed environment.
         _init_distributed_environment(self.parallel_config, self.rank,
                                       self.distributed_init_method)
@@ -385,3 +387,15 @@ def _check_if_can_support_max_seq_len(max_seq_len: int,
             f"(required shared memory {required_shared_mem} > "
             f"available shared memory {max_shared_mem}). "
             "This will be fixed in a future release.")
+
+
+def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
+    # Check if the GPU supports the dtype.
+    if torch_dtype == torch.bfloat16:
+        compute_capability = torch.cuda.get_device_capability()
+        if compute_capability[0] < 8:
+            gpu_name = torch.cuda.get_device_name()
+            raise ValueError(
+                "Bfloat16 is only supported on GPUs with compute capability "
+                f"of at least 8.0. Your {gpu_name} GPU has compute capability "
+                f"{compute_capability[0]}.{compute_capability[1]}.")
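
The patch moves the bfloat16 capability check from dtype resolution in vllm/config.py to worker initialization, where it runs after torch.cuda.set_device() so it queries the worker's own GPU. A minimal standalone sketch of the same gate is below, for trying the behavior outside the Worker; the helper name check_bfloat16_support is illustrative and not part of the patch, and it assumes a CUDA build of PyTorch.

    # Standalone sketch of the relocated check (illustrative, not from the patch).
    import torch

    def check_bfloat16_support() -> None:
        """Raise if the current CUDA device cannot run bfloat16 natively."""
        if not torch.cuda.is_available():
            raise RuntimeError("No CUDA device is available.")
        # Native bfloat16 needs compute capability >= 8.0 (Ampere or newer),
        # which is the same condition the patched _check_if_gpu_supports_dtype enforces.
        major, minor = torch.cuda.get_device_capability()
        if major < 8:
            gpu_name = torch.cuda.get_device_name()
            raise ValueError(
                f"Bfloat16 is not supported on {gpu_name} "
                f"(compute capability {major}.{minor} < 8.0).")

    if __name__ == "__main__":
        check_bfloat16_support()
        print("This GPU supports bfloat16.")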