[Bugfix] Fix ray workers profiling with nsight (#4095)

This commit is contained in:
Ricky Xu 2024-04-15 14:24:45 -07:00 committed by GitHub
parent d619ae2d19
commit 4695397dcf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -48,6 +48,21 @@ class RayGPUExecutor(ExecutorBase):
if USE_RAY_COMPILED_DAG:
    self.forward_dag = self._compiled_ray_dag()
def _configure_ray_workers_use_nsight(self,
ray_remote_kwargs) -> Dict[str, Any]:
# If nsight profiling is enabled, we need to set the profiling
# configuration for the ray workers as runtime env.
runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
runtime_env.update({
"nsight": {
"t": "cuda,cudnn,cublas",
"o": "'worker_process_%p'",
"cuda-graph-trace": "node",
}
})
return ray_remote_kwargs
def _init_workers_ray(self, placement_group: "PlacementGroup",
                      **ray_remote_kwargs):
    if self.parallel_config.tensor_parallel_size == 1:
@ -63,6 +78,10 @@ class RayGPUExecutor(ExecutorBase):
# The remaining workers are the actual ray actors.
self.workers: List[RayWorkerVllm] = []
if self.parallel_config.ray_workers_use_nsight:
ray_remote_kwargs = self._configure_ray_workers_use_nsight(
ray_remote_kwargs)
# Create the workers.
driver_ip = get_ip()
for bundle_id, bundle in enumerate(placement_group.bundle_specs):