Move bfloat16 check to worker (#1259)

This commit is contained in:
Antoni Baum 2023-10-07 22:10:44 -07:00 committed by GitHub
parent 09ff7f106a
commit ee92b58b3a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 14 additions and 9 deletions

View File

@ -345,15 +345,6 @@ def _get_and_verify_dtype(
# Casting between float16 and bfloat16 is allowed with a warning. # Casting between float16 and bfloat16 is allowed with a warning.
logger.warning(f"Casting {config_dtype} to {torch_dtype}.") logger.warning(f"Casting {config_dtype} to {torch_dtype}.")
# Check if the GPU supports the dtype.
if torch_dtype == torch.bfloat16:
compute_capability = torch.cuda.get_device_capability()
if compute_capability[0] < 8:
gpu_name = torch.cuda.get_device_name()
raise ValueError(
"Bfloat16 is only supported on GPUs with compute capability "
f"of at least 8.0. Your {gpu_name} GPU has compute capability "
f"{compute_capability[0]}.{compute_capability[1]}.")
return torch_dtype return torch_dtype

View File

@ -59,6 +59,8 @@ class Worker:
raise ValueError("Invalid or unspecified rank.") raise ValueError("Invalid or unspecified rank.")
torch.cuda.set_device(self.device) torch.cuda.set_device(self.device)
_check_if_gpu_supports_dtype(self.model_config.dtype)
# Initialize the distributed environment. # Initialize the distributed environment.
_init_distributed_environment(self.parallel_config, self.rank, _init_distributed_environment(self.parallel_config, self.rank,
self.distributed_init_method) self.distributed_init_method)
@ -385,3 +387,15 @@ def _check_if_can_support_max_seq_len(max_seq_len: int,
f"(required shared memory {required_shared_mem} > " f"(required shared memory {required_shared_mem} > "
f"available shared memory {max_shared_mem}). " f"available shared memory {max_shared_mem}). "
"This will be fixed in a future release.") "This will be fixed in a future release.")
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
# Check if the GPU supports the dtype.
if torch_dtype == torch.bfloat16:
compute_capability = torch.cuda.get_device_capability()
if compute_capability[0] < 8:
gpu_name = torch.cuda.get_device_name()
raise ValueError(
"Bfloat16 is only supported on GPUs with compute capability "
f"of at least 8.0. Your {gpu_name} GPU has compute capability "
f"{compute_capability[0]}.{compute_capability[1]}.")