Move bfloat16 check to worker (#1259)
This commit is contained in:
parent
09ff7f106a
commit
ee92b58b3a
@ -345,15 +345,6 @@ def _get_and_verify_dtype(
|
||||
# Casting between float16 and bfloat16 is allowed with a warning.
|
||||
logger.warning(f"Casting {config_dtype} to {torch_dtype}.")
|
||||
|
||||
# Check if the GPU supports the dtype.
|
||||
if torch_dtype == torch.bfloat16:
|
||||
compute_capability = torch.cuda.get_device_capability()
|
||||
if compute_capability[0] < 8:
|
||||
gpu_name = torch.cuda.get_device_name()
|
||||
raise ValueError(
|
||||
"Bfloat16 is only supported on GPUs with compute capability "
|
||||
f"of at least 8.0. Your {gpu_name} GPU has compute capability "
|
||||
f"{compute_capability[0]}.{compute_capability[1]}.")
|
||||
return torch_dtype
|
||||
|
||||
|
||||
|
||||
@ -59,6 +59,8 @@ class Worker:
|
||||
raise ValueError("Invalid or unspecified rank.")
|
||||
torch.cuda.set_device(self.device)
|
||||
|
||||
_check_if_gpu_supports_dtype(self.model_config.dtype)
|
||||
|
||||
# Initialize the distributed environment.
|
||||
_init_distributed_environment(self.parallel_config, self.rank,
|
||||
self.distributed_init_method)
|
||||
@ -385,3 +387,15 @@ def _check_if_can_support_max_seq_len(max_seq_len: int,
|
||||
f"(required shared memory {required_shared_mem} > "
|
||||
f"available shared memory {max_shared_mem}). "
|
||||
"This will be fixed in a future release.")
|
||||
|
||||
|
||||
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
|
||||
# Check if the GPU supports the dtype.
|
||||
if torch_dtype == torch.bfloat16:
|
||||
compute_capability = torch.cuda.get_device_capability()
|
||||
if compute_capability[0] < 8:
|
||||
gpu_name = torch.cuda.get_device_name()
|
||||
raise ValueError(
|
||||
"Bfloat16 is only supported on GPUs with compute capability "
|
||||
f"of at least 8.0. Your {gpu_name} GPU has compute capability "
|
||||
f"{compute_capability[0]}.{compute_capability[1]}.")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user