[Core] [Frontend] Priority scheduling for embeddings and in the OpenAI-API (#8965)

This commit is contained in:
Sebastian Schoennenbeck 2024-10-01 11:58:06 +02:00 committed by GitHub
parent 1fe0a4264a
commit 35bd215168
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 53 additions and 5 deletions

View File

@ -1043,6 +1043,7 @@ class AsyncLLMEngine:
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
) -> AsyncGenerator[EmbeddingRequestOutput, None]: ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
"""Generate outputs for a request from an embedding model. """Generate outputs for a request from an embedding model.
@ -1057,6 +1058,8 @@ class AsyncLLMEngine:
request_id: The unique id of the request. request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any. lora_request: LoRA request to use for generation, if any.
trace_headers: OpenTelemetry trace headers. trace_headers: OpenTelemetry trace headers.
priority: The priority of the request.
Only applicable with priority scheduling.
Yields: Yields:
The output `EmbeddingRequestOutput` objects from the LLMEngine The output `EmbeddingRequestOutput` objects from the LLMEngine
@ -1109,6 +1112,7 @@ class AsyncLLMEngine:
pooling_params, pooling_params,
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers, trace_headers=trace_headers,
priority=priority,
): ):
yield LLMEngine.validate_output(output, EmbeddingRequestOutput) yield LLMEngine.validate_output(output, EmbeddingRequestOutput)

View File

@ -30,6 +30,7 @@ class RPCProcessRequest:
lora_request: Optional[LoRARequest] = None lora_request: Optional[LoRARequest] = None
trace_headers: Optional[Mapping[str, str]] = None trace_headers: Optional[Mapping[str, str]] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None prompt_adapter_request: Optional[PromptAdapterRequest] = None
priority: int = 0
@overload # DEPRECATED @overload # DEPRECATED
def __init__( def __init__(
@ -41,6 +42,7 @@ class RPCProcessRequest:
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None: ) -> None:
... ...
@ -53,6 +55,7 @@ class RPCProcessRequest:
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None: ) -> None:
... ...
@ -68,6 +71,7 @@ class RPCProcessRequest:
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
*, *,
inputs: Optional[PromptType] = None, # DEPRECATED inputs: Optional[PromptType] = None, # DEPRECATED
) -> None: ) -> None:
@ -84,6 +88,7 @@ class RPCProcessRequest:
self.lora_request = lora_request self.lora_request = lora_request
self.trace_headers = trace_headers self.trace_headers = trace_headers
self.prompt_adapter_request = prompt_adapter_request self.prompt_adapter_request = prompt_adapter_request
self.priority = priority
@dataclass @dataclass

View File

@ -380,6 +380,7 @@ class MQLLMEngineClient:
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> AsyncGenerator[RequestOutput, None]: ) -> AsyncGenerator[RequestOutput, None]:
... ...
@ -392,6 +393,7 @@ class MQLLMEngineClient:
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> AsyncGenerator[RequestOutput, None]: ) -> AsyncGenerator[RequestOutput, None]:
... ...
@ -407,6 +409,7 @@ class MQLLMEngineClient:
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
*, *,
inputs: Optional[PromptType] = None # DEPRECATED inputs: Optional[PromptType] = None # DEPRECATED
) -> AsyncGenerator[RequestOutput, None]: ) -> AsyncGenerator[RequestOutput, None]:
@ -425,6 +428,9 @@ class MQLLMEngineClient:
trace_headers: OpenTelemetry trace headers. trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request to use prompt_adapter_request: Prompt Adapter request to use
for generation, if any. for generation, if any.
priority: Priority of the request (lower means earlier handling).
Any priority other than 0 will lead to an error if the
scheduling policy is not "priority".
""" """
if inputs is not None: if inputs is not None:
prompt = inputs prompt = inputs
@ -433,7 +439,7 @@ class MQLLMEngineClient:
return self._process_request(prompt, sampling_params, request_id, return self._process_request(prompt, sampling_params, request_id,
lora_request, trace_headers, lora_request, trace_headers,
prompt_adapter_request) prompt_adapter_request, priority)
@overload # DEPRECATED @overload # DEPRECATED
def encode( def encode(
@ -444,6 +450,7 @@ class MQLLMEngineClient:
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
) -> AsyncGenerator[EmbeddingRequestOutput, None]: ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
... ...
@ -455,6 +462,7 @@ class MQLLMEngineClient:
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
) -> AsyncGenerator[EmbeddingRequestOutput, None]: ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
... ...
@ -469,6 +477,7 @@ class MQLLMEngineClient:
request_id: Optional[str] = None, request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
*, *,
inputs: Optional[PromptType] = None # DEPRECATED inputs: Optional[PromptType] = None # DEPRECATED
) -> AsyncGenerator[EmbeddingRequestOutput, None]: ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
@ -496,7 +505,7 @@ class MQLLMEngineClient:
and request_id is not None) and request_id is not None)
return self._process_request(prompt, pooling_params, request_id, return self._process_request(prompt, pooling_params, request_id,
lora_request, trace_headers) lora_request, trace_headers, priority)
async def _process_request( async def _process_request(
self, self,
@ -505,7 +514,8 @@ class MQLLMEngineClient:
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[ ) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
EmbeddingRequestOutput, None]]: EmbeddingRequestOutput, None]]:
"""Send an RPCGenerateRequest to the RPCServer and stream responses.""" """Send an RPCGenerateRequest to the RPCServer and stream responses."""
@ -550,7 +560,9 @@ class MQLLMEngineClient:
request_id=request_id, request_id=request_id,
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers, trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request)) prompt_adapter_request=prompt_adapter_request,
priority=priority,
))
# 3) Send the RPCGenerateRequest to the MQLLMEngine. # 3) Send the RPCGenerateRequest to the MQLLMEngine.
parts = (request_bytes, parts = (request_bytes,

View File

@ -40,7 +40,8 @@ class EngineClient(Protocol):
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> AsyncGenerator[RequestOutput, None]: ) -> AsyncGenerator[RequestOutput, None]:
"""Generate outputs for a request.""" """Generate outputs for a request."""
... ...
@ -52,6 +53,7 @@ class EngineClient(Protocol):
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
) -> AsyncGenerator[EmbeddingRequestOutput, None]: ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
"""Generate outputs for a request from an embedding model.""" """Generate outputs for a request from an embedding model."""
... ...

View File

@ -279,6 +279,12 @@ class ChatCompletionRequest(OpenAIBaseModel):
description=( description=(
"If specified, will override the default whitespace pattern " "If specified, will override the default whitespace pattern "
"for guided json decoding.")) "for guided json decoding."))
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))
# doc: end-chat-completion-extra-params # doc: end-chat-completion-extra-params
@ -552,6 +558,12 @@ class CompletionRequest(OpenAIBaseModel):
description=( description=(
"If specified, will override the default whitespace pattern " "If specified, will override the default whitespace pattern "
"for guided json decoding.")) "for guided json decoding."))
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))
# doc: end-completion-extra-params # doc: end-completion-extra-params
@ -665,6 +677,16 @@ class EmbeddingRequest(OpenAIBaseModel):
# doc: end-embedding-pooling-params # doc: end-embedding-pooling-params
# doc: begin-embedding-extra-params
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))
# doc: end-embedding-extra-params
def to_pooling_params(self): def to_pooling_params(self):
return PoolingParams(additional_data=self.additional_data) return PoolingParams(additional_data=self.additional_data)

View File

@ -235,6 +235,7 @@ class OpenAIServingChat(OpenAIServing):
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers, trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
priority=request.priority,
) )
except ValueError as e: except ValueError as e:
# TODO: Use a vllm-specific Validation Error # TODO: Use a vllm-specific Validation Error

View File

@ -148,6 +148,7 @@ class OpenAIServingCompletion(OpenAIServing):
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers, trace_headers=trace_headers,
priority=request.priority,
) )
generators.append(generator) generators.append(generator)

View File

@ -148,6 +148,7 @@ class OpenAIServingEmbedding(OpenAIServing):
pooling_params, pooling_params,
request_id_item, request_id_item,
lora_request=lora_request, lora_request=lora_request,
priority=request.priority,
) )
generators.append(generator) generators.append(generator)