[Minor] Small fix to make distributed init logic in worker look cleaner (#2905)
parent 786b7f18a5
commit 537c9755a7
@@ -93,8 +93,6 @@ class Worker:
         # Initialize the distributed environment.
         init_distributed_environment(self.parallel_config, self.rank,
                                      cupy_port, self.distributed_init_method)
-        if not self.parallel_config.disable_custom_all_reduce:
-            init_custom_ar()
         # Initialize the model.
         set_random_seed(self.model_config.seed)
 
@@ -288,6 +286,10 @@ def init_distributed_environment(
     ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                       parallel_config.pipeline_parallel_size)
 
+    # Initialize a custom fast all-reduce implementation.
+    if not parallel_config.disable_custom_all_reduce:
+        init_custom_ar()
+
 
 def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
     # Check if the GPU supports the dtype.
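In effect, the custom all-reduce setup moves out of the Worker and into init_distributed_environment, so the worker makes a single call and all distributed initialization lives in one place. Below is a minimal, self-contained sketch of the resulting structure; the stub function bodies and the SimpleNamespace config are stand-ins for illustration only, not vLLM's actual implementations.

# Minimal sketch of the structure after this change; the bodies below are
# stubs standing in for vLLM's real implementations.
from types import SimpleNamespace


def init_custom_ar() -> None:
    # Stub: would set up the custom fast all-reduce implementation.
    print("custom all-reduce initialized")


def ensure_model_parallel_initialized(tp_size: int, pp_size: int) -> None:
    # Stub: would create the tensor-/pipeline-parallel process groups.
    print(f"model parallel ready (tp={tp_size}, pp={pp_size})")


def init_distributed_environment(parallel_config, rank, cupy_port, init_method) -> None:
    # Stub for the process-group setup that precedes this call in vLLM.
    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                      parallel_config.pipeline_parallel_size)

    # The custom all-reduce check now lives here rather than in the Worker.
    if not parallel_config.disable_custom_all_reduce:
        init_custom_ar()


# The Worker now only needs the single call:
config = SimpleNamespace(tensor_parallel_size=2, pipeline_parallel_size=1,
                         disable_custom_all_reduce=False)
init_distributed_environment(config, rank=0, cupy_port=None,
                             init_method="tcp://localhost:29500")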