Fix num_gpus when TP > 1 (#1852)
commit 464dd985e3
parent c07a442854
@@ -301,7 +301,16 @@ class AsyncLLMEngine:
         elif self.worker_use_ray:
             engine_class = ray.remote(num_cpus=0)(self._engine_class).remote
         else:
-            engine_class = ray.remote(num_gpus=1)(self._engine_class).remote
+            # FIXME(woosuk): This is a bit hacky. Be careful when changing the
+            # order of the arguments.
+            cache_config = args[1]
+            parallel_config = args[2]
+            if parallel_config.tensor_parallel_size == 1:
+                num_gpus = cache_config.gpu_memory_utilization
+            else:
+                num_gpus = 1
+            engine_class = ray.remote(num_gpus=num_gpus)(
+                self._engine_class).remote
         return engine_class(*args, **kwargs)

     async def engine_step(self) -> bool:
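Both hunks implement the same rule: when tensor_parallel_size == 1, the Ray actor requests only a fraction of a GPU (cache_config.gpu_memory_utilization), so the engine actor and its single worker can be packed onto the same device; when TP > 1, every actor claims a whole GPU. Below is a minimal sketch of that pattern outside vLLM; the Engine class, the 0.9 fraction, and the ray.init() setup are illustrative assumptions, not part of this commit.

import ray


class Engine:
    """Stand-in for vLLM's engine class (hypothetical, for illustration)."""

    def ping(self) -> str:
        return "ok"


ray.init()  # assumes a local Ray instance with at least one GPU

tensor_parallel_size = 1      # illustrative value
gpu_memory_utilization = 0.9  # illustrative fraction

# Same decision as the patched branch: a fractional GPU when TP == 1 so the
# engine actor can share the device with its single worker, a whole GPU
# otherwise.
num_gpus = gpu_memory_utilization if tensor_parallel_size == 1 else 1

# ray.remote(...) applied to a plain class returns an actor class; binding
# .remote and calling it later mirrors
# `ray.remote(num_gpus=num_gpus)(self._engine_class).remote` in the diff.
engine_class = ray.remote(num_gpus=num_gpus)(Engine).remote
engine = engine_class()
print(ray.get(engine.ping.remote()))  # -> "ok"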
@@ -159,9 +159,13 @@ class LLMEngine:
         for bundle in placement_group.bundle_specs:
             if not bundle.get("GPU", 0):
                 continue
+            if self.parallel_config.tensor_parallel_size == 1:
+                num_gpus = self.cache_config.gpu_memory_utilization
+            else:
+                num_gpus = 1
             worker = ray.remote(
                 num_cpus=0,
-                num_gpus=self.cache_config.gpu_memory_utilization,
+                num_gpus=num_gpus,
                 scheduling_strategy=PlacementGroupSchedulingStrategy(
                     placement_group=placement_group,
                     placement_group_capture_child_tasks=True),
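The worker-creation loop applies the same fraction inside the placement group. A hedged sketch of that scheduling pattern follows, assuming a local Ray cluster with one GPU; the Worker actor and the single {"GPU": 1} bundle are hypothetical stand-ins for vLLM's setup.

import ray
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy


@ray.remote
class Worker:
    """Hypothetical worker; reports which GPU Ray assigned it."""

    def gpu_ids(self) -> list:
        return ray.get_gpu_ids()


ray.init()  # assumes at least one GPU in the cluster

# One GPU bundle per tensor-parallel rank; TP == 1 here, so a single bundle.
pg = placement_group([{"GPU": 1}])
ray.get(pg.ready())

workers = []
for bundle in pg.bundle_specs:
    if not bundle.get("GPU", 0):
        continue
    # Fractional request, as in the patch: with one worker this leaves
    # headroom on the same GPU for the driver/engine process.
    worker = Worker.options(
        num_cpus=0,
        num_gpus=0.9,
        scheduling_strategy=PlacementGroupSchedulingStrategy(
            placement_group=pg,
            placement_group_capture_child_tasks=True),
    ).remote()
    workers.append(worker)

print(ray.get([w.gpu_ids.remote() for w in workers]))

Keeping num_gpus=1 when TP > 1 is deliberate: each rank needs a dedicated device, and fractional requests could let Ray pack two tensor-parallel workers onto one GPU, which is presumably the bug the commit title refers to.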