[Core] [Frontend] Priority scheduling for embeddings and in the OpenAI-API (#8965)

This commit is contained in:
Sebastian Schoennenbeck 2024-10-01 11:58:06 +02:00 committed by GitHub
parent 1fe0a4264a
commit 35bd215168
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 53 additions and 5 deletions

View File

@ -1043,6 +1043,7 @@ class AsyncLLMEngine:
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
) -> AsyncGenerator[EmbeddingRequestOutput, None]:
"""Generate outputs for a request from an embedding model.
@ -1057,6 +1058,8 @@ class AsyncLLMEngine:
request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any.
trace_headers: OpenTelemetry trace headers.
priority: The priority of the request.
Only applicable with priority scheduling.
Yields:
The output `EmbeddingRequestOutput` objects from the LLMEngine
@ -1109,6 +1112,7 @@ class AsyncLLMEngine:
pooling_params,
lora_request=lora_request,
trace_headers=trace_headers,
priority=priority,
):
yield LLMEngine.validate_output(output, EmbeddingRequestOutput)

View File

@ -30,6 +30,7 @@ class RPCProcessRequest:
lora_request: Optional[LoRARequest] = None
trace_headers: Optional[Mapping[str, str]] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None
priority: int = 0
@overload # DEPRECATED
def __init__(
@ -41,6 +42,7 @@ class RPCProcessRequest:
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None:
...
@ -53,6 +55,7 @@ class RPCProcessRequest:
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None:
...
@ -68,6 +71,7 @@ class RPCProcessRequest:
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
*,
inputs: Optional[PromptType] = None, # DEPRECATED
) -> None:
@ -84,6 +88,7 @@ class RPCProcessRequest:
self.lora_request = lora_request
self.trace_headers = trace_headers
self.prompt_adapter_request = prompt_adapter_request
self.priority = priority
@dataclass

View File

@ -380,6 +380,7 @@ class MQLLMEngineClient:
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> AsyncGenerator[RequestOutput, None]:
...
@ -392,6 +393,7 @@ class MQLLMEngineClient:
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> AsyncGenerator[RequestOutput, None]:
...
@ -407,6 +409,7 @@ class MQLLMEngineClient:
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
*,
inputs: Optional[PromptType] = None # DEPRECATED
) -> AsyncGenerator[RequestOutput, None]:
@ -425,6 +428,9 @@ class MQLLMEngineClient:
trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request to use
for generation, if any.
priority: Priority of the request (lower means earlier handling).
Any priority other than 0 will lead to an error if the
scheduling policy is not "priority".
"""
if inputs is not None:
prompt = inputs
@ -433,7 +439,7 @@ class MQLLMEngineClient:
return self._process_request(prompt, sampling_params, request_id,
lora_request, trace_headers,
prompt_adapter_request)
prompt_adapter_request, priority)
@overload # DEPRECATED
def encode(
@ -444,6 +450,7 @@ class MQLLMEngineClient:
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
) -> AsyncGenerator[EmbeddingRequestOutput, None]:
...
@ -455,6 +462,7 @@ class MQLLMEngineClient:
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
) -> AsyncGenerator[EmbeddingRequestOutput, None]:
...
@ -469,6 +477,7 @@ class MQLLMEngineClient:
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
*,
inputs: Optional[PromptType] = None # DEPRECATED
) -> AsyncGenerator[EmbeddingRequestOutput, None]:
@ -496,7 +505,7 @@ class MQLLMEngineClient:
and request_id is not None)
return self._process_request(prompt, pooling_params, request_id,
lora_request, trace_headers)
lora_request, trace_headers, priority)
async def _process_request(
self,
@ -505,7 +514,8 @@ class MQLLMEngineClient:
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
EmbeddingRequestOutput, None]]:
"""Send an RPCGenerateRequest to the RPCServer and stream responses."""
@ -550,7 +560,9 @@ class MQLLMEngineClient:
request_id=request_id,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request))
prompt_adapter_request=prompt_adapter_request,
priority=priority,
))
# 3) Send the RPCGenerateRequest to the MQLLMEngine.
parts = (request_bytes,

View File

@ -40,7 +40,8 @@ class EngineClient(Protocol):
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> AsyncGenerator[RequestOutput, None]:
"""Generate outputs for a request."""
...
@ -52,6 +53,7 @@ class EngineClient(Protocol):
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
) -> AsyncGenerator[EmbeddingRequestOutput, None]:
"""Generate outputs for a request from an embedding model."""
...

View File

@ -279,6 +279,12 @@ class ChatCompletionRequest(OpenAIBaseModel):
description=(
"If specified, will override the default whitespace pattern "
"for guided json decoding."))
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))
# doc: end-chat-completion-extra-params
@ -552,6 +558,12 @@ class CompletionRequest(OpenAIBaseModel):
description=(
"If specified, will override the default whitespace pattern "
"for guided json decoding."))
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))
# doc: end-completion-extra-params
@ -665,6 +677,16 @@ class EmbeddingRequest(OpenAIBaseModel):
# doc: end-embedding-pooling-params
# doc: begin-embedding-extra-params
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))
# doc: end-embedding-extra-params
def to_pooling_params(self):
return PoolingParams(additional_data=self.additional_data)

View File

@ -235,6 +235,7 @@ class OpenAIServingChat(OpenAIServing):
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=request.priority,
)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error

View File

@ -148,6 +148,7 @@ class OpenAIServingCompletion(OpenAIServing):
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers,
priority=request.priority,
)
generators.append(generator)

View File

@ -148,6 +148,7 @@ class OpenAIServingEmbedding(OpenAIServing):
pooling_params,
request_id_item,
lora_request=lora_request,
priority=request.priority,
)
generators.append(generator)