Fix num_gpus when TP > 1 (#1852)

Woosuk Kwon 2023-12-03 12:24:30 -08:00 committed by GitHub
parent c07a442854
commit 464dd985e3
2 changed files with 15 additions and 2 deletions
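
Background on the mechanism: Ray allocates GPUs to actors through the num_gpus argument, which may be fractional, so several actors can share one physical device. When tensor_parallel_size == 1, the engine actor shares its GPU with the model worker and should claim only the gpu_memory_utilization fraction; when TP > 1, a fractional request could let Ray pack two workers onto the same device, so each actor must reserve a full GPU. Below is a minimal sketch of Ray's fractional GPU scheduling (not vLLM code; assumes a Ray installation and a node with at least one GPU):

import ray

ray.init()

# An actor declared with a fractional num_gpus can be co-located with
# other actors on the same physical GPU, as long as the requested
# fractions on that device sum to at most 1.
@ray.remote(num_gpus=0.9)
class Engine:
    def gpu_ids(self):
        # Returns the GPU ids Ray assigned to this actor.
        return ray.get_gpu_ids()

engine = Engine.remote()
print(ray.get(engine.gpu_ids.remote()))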

vllm/engine/async_llm_engine.py

@@ -301,7 +301,16 @@ class AsyncLLMEngine:
         elif self.worker_use_ray:
             engine_class = ray.remote(num_cpus=0)(self._engine_class).remote
         else:
-            engine_class = ray.remote(num_gpus=1)(self._engine_class).remote
+            # FIXME(woosuk): This is a bit hacky. Be careful when changing the
+            # order of the arguments.
+            cache_config = args[1]
+            parallel_config = args[2]
+            if parallel_config.tensor_parallel_size == 1:
+                num_gpus = cache_config.gpu_memory_utilization
+            else:
+                num_gpus = 1
+            engine_class = ray.remote(num_gpus=num_gpus)(
+                self._engine_class).remote
         return engine_class(*args, **kwargs)
 
     async def engine_step(self) -> bool:

vllm/engine/llm_engine.py

@@ -159,9 +159,13 @@ class LLMEngine:
         for bundle in placement_group.bundle_specs:
             if not bundle.get("GPU", 0):
                 continue
+            if self.parallel_config.tensor_parallel_size == 1:
+                num_gpus = self.cache_config.gpu_memory_utilization
+            else:
+                num_gpus = 1
             worker = ray.remote(
                 num_cpus=0,
-                num_gpus=self.cache_config.gpu_memory_utilization,
+                num_gpus=num_gpus,
                 scheduling_strategy=PlacementGroupSchedulingStrategy(
                     placement_group=placement_group,
                     placement_group_capture_child_tasks=True),
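
For reference, a self-contained sketch of the scheduling pattern in the hunk above, using the conditional num_gpus request this commit introduces (hypothetical standalone script, not vLLM code; assumes a Ray cluster with tensor_parallel_size GPUs available):

import ray
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

ray.init()
tensor_parallel_size = 2      # number of tensor-parallel workers
gpu_memory_utilization = 0.9  # fraction of GPU memory the engine reserves

# One bundle per worker; each bundle pins one GPU.
pg = placement_group([{"GPU": 1}] * tensor_parallel_size)
ray.get(pg.ready())

# The conditional from this commit: share the GPU only when TP == 1.
if tensor_parallel_size == 1:
    num_gpus = gpu_memory_utilization
else:
    num_gpus = 1

@ray.remote
class Worker:
    def gpu_ids(self):
        return ray.get_gpu_ids()

workers = [
    Worker.options(
        num_cpus=0,
        num_gpus=num_gpus,
        scheduling_strategy=PlacementGroupSchedulingStrategy(
            placement_group=pg,
            placement_group_capture_child_tasks=True),
    ).remote()
    for _ in range(tensor_parallel_size)
]
# With a full-GPU request per worker, each should report a distinct GPU id.
print(ray.get([w.gpu_ids.remote() for w in workers]))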