Move bfloat16 check to worker (#1259)
commit ee92b58b3a
parent 09ff7f106a
@@ -345,15 +345,6 @@ def _get_and_verify_dtype(
         # Casting between float16 and bfloat16 is allowed with a warning.
         logger.warning(f"Casting {config_dtype} to {torch_dtype}.")
 
-    # Check if the GPU supports the dtype.
-    if torch_dtype == torch.bfloat16:
-        compute_capability = torch.cuda.get_device_capability()
-        if compute_capability[0] < 8:
-            gpu_name = torch.cuda.get_device_name()
-            raise ValueError(
-                "Bfloat16 is only supported on GPUs with compute capability "
-                f"of at least 8.0. Your {gpu_name} GPU has compute capability "
-                f"{compute_capability[0]}.{compute_capability[1]}.")
     return torch_dtype
 
 
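The deleted block probes the GPU that is active at config-resolution time, which is presumably why it is being moved: `_get_and_verify_dtype` may run before the process has been bound to the device that will actually execute the model. For reference, a minimal sketch of the probe itself in plain PyTorch (assuming a CUDA build; the helper name here is illustrative, not part of this commit):

import torch

def current_gpu_supports_bfloat16() -> bool:
    # (major, minor) compute capability of the currently selected device;
    # native bfloat16 support arrived with Ampere, i.e. major >= 8.
    major, _minor = torch.cuda.get_device_capability()
    return major >= 8

if torch.cuda.is_available():
    print(current_gpu_supports_bfloat16())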
@@ -59,6 +59,8 @@ class Worker:
             raise ValueError("Invalid or unspecified rank.")
         torch.cuda.set_device(self.device)
 
+        _check_if_gpu_supports_dtype(self.model_config.dtype)
+
         # Initialize the distributed environment.
         _init_distributed_environment(self.parallel_config, self.rank,
                                       self.distributed_init_method)
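The call site matters: the check runs after `torch.cuda.set_device(self.device)` and before distributed initialization, so each worker validates the dtype against the GPU it has just been bound to and fails fast with a clear error. A simplified sketch of that ordering (function and parameter names are illustrative, not vLLM's actual Worker API; `_check_if_gpu_supports_dtype` is the helper added in the hunk below):

import torch

def init_worker(device_index: int, model_dtype: torch.dtype) -> None:
    # Bind this process to its GPU first, so the capability query inside
    # the dtype check inspects the right device.
    torch.cuda.set_device(torch.device(f"cuda:{device_index}"))
    _check_if_gpu_supports_dtype(model_dtype)  # fail fast on unsupported dtype
    # ... distributed setup and model loading would follow here ...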
@@ -385,3 +387,15 @@ def _check_if_can_support_max_seq_len(max_seq_len: int,
             f"(required shared memory {required_shared_mem} > "
             f"available shared memory {max_shared_mem}). "
             "This will be fixed in a future release.")
+
+
+def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
+    # Check if the GPU supports the dtype.
+    if torch_dtype == torch.bfloat16:
+        compute_capability = torch.cuda.get_device_capability()
+        if compute_capability[0] < 8:
+            gpu_name = torch.cuda.get_device_name()
+            raise ValueError(
+                "Bfloat16 is only supported on GPUs with compute capability "
+                f"of at least 8.0. Your {gpu_name} GPU has compute capability "
+                f"{compute_capability[0]}.{compute_capability[1]}.")
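The new helper raises `ValueError` on pre-Ampere GPUs and is a no-op for every other dtype. A small usage sketch (the fallback policy below is hypothetical, not something this commit implements; it assumes `_check_if_gpu_supports_dtype` from the hunk above is in scope):

import torch

def resolve_dtype(requested: torch.dtype) -> torch.dtype:
    # Hypothetical caller: degrade to float16 when bfloat16 is
    # unsupported (compute capability < 8.0) instead of aborting.
    try:
        _check_if_gpu_supports_dtype(requested)
    except ValueError:
        return torch.float16
    return requested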