[Bugfix] Fix Llava inference with Tensor Parallelism. (#3883)

This commit is contained in:
Isotr0py 2024-04-07 22:54:13 +08:00 committed by GitHub
parent 2f19283549
commit 0ce0539d47
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -154,6 +154,7 @@ class RayGPUExecutor(ExecutorBase):
scheduler_config = copy.deepcopy(self.scheduler_config)
device_config = copy.deepcopy(self.device_config)
lora_config = copy.deepcopy(self.lora_config)
vision_language_config = copy.deepcopy(self.vision_language_config)
kv_cache_dtype = self.cache_config.cache_dtype
# Initialize the actual workers with the Worker class.
@ -172,6 +173,7 @@ class RayGPUExecutor(ExecutorBase):
rank,
distributed_init_method,
lora_config=lora_config,
vision_language_config=vision_language_config,
kv_cache_dtype=kv_cache_dtype,
))