fix some bugs (#2689)

zspo 2024-02-01 02:09:23 +08:00 committed by GitHub
parent d69ff0cbbb
commit c664b0e683
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 8 additions and 3 deletions

vllm/config.py

@@ -355,6 +355,9 @@ class ParallelConfig:
         worker_use_ray: Whether to use Ray for model workers. Will be set to
             True if either pipeline_parallel_size or tensor_parallel_size is
             greater than 1.
+        max_parallel_loading_workers: Maximum number of multiple batches
+            when load model sequentially. To avoid RAM OOM when using tensor
+            parallel and large models.
         disable_custom_all_reduce: Disable the custom all-reduce kernel and
             fall back to NCCL.
     """
@@ -470,7 +473,7 @@ class LoRAConfig:
         elif self.max_cpu_loras < self.max_loras:
             raise ValueError(
                 f"max_cpu_loras ({self.max_cpu_loras}) must be >= "
-                f"max_num_seqs ({self.max_loras})")
+                f"max_loras ({self.max_loras})")

     def verify_with_model_config(self, model_config: ModelConfig):
         if self.lora_dtype in (None, "auto"):
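
This hunk only corrects the wording of the error message: the check already compared max_cpu_loras against max_loras, but the message mislabeled the value as max_num_seqs. A self-contained sketch of the invariant being enforced (values are made up; this is not the vLLM class itself):

```python
# Sketch of the invariant behind the corrected message: the CPU LoRA cache
# must hold at least as many adapters as a single batch may use.
max_loras = 8       # max LoRAs active in a single batch (illustrative)
max_cpu_loras = 4   # max LoRAs cached in CPU memory (illustrative)

try:
    if max_cpu_loras < max_loras:
        raise ValueError(f"max_cpu_loras ({max_cpu_loras}) must be >= "
                         f"max_loras ({max_loras})")
except ValueError as e:
    print(e)  # -> max_cpu_loras (4) must be >= max_loras (8)
```
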

vllm/engine/async_llm_engine.py

@@ -296,6 +296,8 @@ class AsyncLLMEngine:
             async frontend will be executed in a separate process as the
             model workers.
         log_requests: Whether to log the requests.
+        max_log_len: Maximum number of prompt characters or prompt ID numbers
+            being printed in log.
         start_engine_loop: If True, the background task to run the engine
             will be automatically started in the generate call.
         *args: Arguments for LLMEngine.
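
max_log_len itself is configured through AsyncEngineArgs (or the corresponding --max-log-len server flag). A minimal sketch of wiring it up, assuming the from_engine_args construction path; the model and limit are illustrative:

```python
# Sketch: cap how much of each request's prompt / token ids is echoed
# into the engine log. Model name and limit are illustrative.
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

engine_args = AsyncEngineArgs(
    model="facebook/opt-125m",  # example model
    max_log_len=100,            # log at most 100 prompt chars / token ids per request
)
engine = AsyncLLMEngine.from_engine_args(engine_args)
```
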
@@ -431,8 +433,8 @@ class AsyncLLMEngine:
             logger.info(f"Received request {request_id}: "
                         f"prompt: {shortened_prompt!r}, "
                         f"prefix_pos: {prefix_pos},"
-                        f"sampling params: {sampling_params}, "
-                        f"prompt token ids: {shortened_token_ids}, "
+                        f"sampling_params: {sampling_params}, "
+                        f"prompt_token_ids: {shortened_token_ids}, "
                         f"lora_request: {lora_request}.")

         if not self.is_running:
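
For context on the log line above: shortened_prompt and shortened_token_ids are the prompt and prompt token ids after truncation by max_log_len earlier in add_request. A standalone sketch of that truncation; shorten_for_log is a hypothetical helper for illustration, not a function in vLLM:

```python
from typing import List, Optional, Tuple

def shorten_for_log(prompt: Optional[str],
                    prompt_token_ids: Optional[List[int]],
                    max_log_len: Optional[int]
                    ) -> Tuple[Optional[str], Optional[List[int]]]:
    # Hypothetical helper mirroring the truncation done in add_request:
    # keep only the first max_log_len characters / token ids for logging.
    if max_log_len is not None:
        if prompt is not None:
            prompt = prompt[:max_log_len]
        if prompt_token_ids is not None:
            prompt_token_ids = prompt_token_ids[:max_log_len]
    return prompt, prompt_token_ids

# With max_log_len=5, only the first 5 characters and token ids are logged.
print(shorten_for_log("Hello, world!", [101, 102, 103, 104, 105, 106], 5))
# -> ('Hello', [101, 102, 103, 104, 105])
```
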