[Minor] Small fix to make distributed init logic in worker look cleaner (#2905)

commit 537c9755a7 (parent 786b7f18a5)
Author: Zhuohan Li
Date: 2024-02-18 14:39:00 -08:00
Committed by: GitHub


@@ -93,8 +93,6 @@ class Worker:
         # Initialize the distributed environment.
         init_distributed_environment(self.parallel_config, self.rank,
                                      cupy_port, self.distributed_init_method)
-        if not self.parallel_config.disable_custom_all_reduce:
-            init_custom_ar()
         # Initialize the model.
         set_random_seed(self.model_config.seed)
@@ -288,6 +286,10 @@ def init_distributed_environment(
     ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                       parallel_config.pipeline_parallel_size)
+
+    # Initialize a custom fast all-reduce implementation.
+    if not parallel_config.disable_custom_all_reduce:
+        init_custom_ar()


 def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
     # Check if the GPU supports the dtype.
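
Taken together, the two hunks move the custom all-reduce setup out of the worker and into init_distributed_environment. Below is a minimal sketch of the resulting call structure, not vLLM's actual code: ParallelConfig, Worker, init_distributed_environment, and init_custom_ar are stripped-down stand-ins that omit the real arguments and the real NCCL/cupy setup, and exist only to show where the check now lives.

# Sketch only: simplified stand-ins for the vLLM classes and functions
# touched by this commit.

class ParallelConfig:
    def __init__(self, tensor_parallel_size: int = 1,
                 pipeline_parallel_size: int = 1,
                 disable_custom_all_reduce: bool = False) -> None:
        self.tensor_parallel_size = tensor_parallel_size
        self.pipeline_parallel_size = pipeline_parallel_size
        self.disable_custom_all_reduce = disable_custom_all_reduce


def init_custom_ar() -> None:
    # Stand-in for the custom fast all-reduce initialization.
    print("custom all-reduce initialized")


def init_distributed_environment(parallel_config: ParallelConfig) -> None:
    # ... torch.distributed / model-parallel group setup elided ...
    # After this commit, the custom all-reduce check lives here, so every
    # caller that initializes the distributed environment gets it for free.
    if not parallel_config.disable_custom_all_reduce:
        init_custom_ar()


class Worker:
    def __init__(self, parallel_config: ParallelConfig) -> None:
        self.parallel_config = parallel_config

    def init_model(self) -> None:
        # The worker no longer special-cases custom all-reduce; a single
        # call now covers the whole distributed setup.
        init_distributed_environment(self.parallel_config)


if __name__ == "__main__":
    Worker(ParallelConfig()).init_model()

With this shape, disabling custom all-reduce stays a single config flag, and the worker's init path no longer needs to know the feature exists.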