[Core] [Frontend] Priority scheduling for embeddings and in the OpenAI-API (#8965)
This commit is contained in:
parent
1fe0a4264a
commit
35bd215168
@ -1043,6 +1043,7 @@ class AsyncLLMEngine:
|
||||
request_id: str,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
priority: int = 0,
|
||||
) -> AsyncGenerator[EmbeddingRequestOutput, None]:
|
||||
"""Generate outputs for a request from an embedding model.
|
||||
|
||||
@ -1057,6 +1058,8 @@ class AsyncLLMEngine:
|
||||
request_id: The unique id of the request.
|
||||
lora_request: LoRA request to use for generation, if any.
|
||||
trace_headers: OpenTelemetry trace headers.
|
||||
priority: The priority of the request.
|
||||
Only applicable with priority scheduling.
|
||||
|
||||
Yields:
|
||||
The output `EmbeddingRequestOutput` objects from the LLMEngine
|
||||
@ -1109,6 +1112,7 @@ class AsyncLLMEngine:
|
||||
pooling_params,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=priority,
|
||||
):
|
||||
yield LLMEngine.validate_output(output, EmbeddingRequestOutput)
|
||||
|
||||
|
||||
@ -30,6 +30,7 @@ class RPCProcessRequest:
|
||||
lora_request: Optional[LoRARequest] = None
|
||||
trace_headers: Optional[Mapping[str, str]] = None
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None
|
||||
priority: int = 0
|
||||
|
||||
@overload # DEPRECATED
|
||||
def __init__(
|
||||
@ -41,6 +42,7 @@ class RPCProcessRequest:
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
) -> None:
|
||||
...
|
||||
|
||||
@ -53,6 +55,7 @@ class RPCProcessRequest:
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
) -> None:
|
||||
...
|
||||
|
||||
@ -68,6 +71,7 @@ class RPCProcessRequest:
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
*,
|
||||
inputs: Optional[PromptType] = None, # DEPRECATED
|
||||
) -> None:
|
||||
@ -84,6 +88,7 @@ class RPCProcessRequest:
|
||||
self.lora_request = lora_request
|
||||
self.trace_headers = trace_headers
|
||||
self.prompt_adapter_request = prompt_adapter_request
|
||||
self.priority = priority
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@ -380,6 +380,7 @@ class MQLLMEngineClient:
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
) -> AsyncGenerator[RequestOutput, None]:
|
||||
...
|
||||
|
||||
@ -392,6 +393,7 @@ class MQLLMEngineClient:
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
) -> AsyncGenerator[RequestOutput, None]:
|
||||
...
|
||||
|
||||
@ -407,6 +409,7 @@ class MQLLMEngineClient:
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
*,
|
||||
inputs: Optional[PromptType] = None # DEPRECATED
|
||||
) -> AsyncGenerator[RequestOutput, None]:
|
||||
@ -425,6 +428,9 @@ class MQLLMEngineClient:
|
||||
trace_headers: OpenTelemetry trace headers.
|
||||
prompt_adapter_request: Prompt Adapter request to use
|
||||
for generation, if any.
|
||||
priority: Priority of the request (lower means earlier handling).
|
||||
Any priority other than 0 will lead to an error if the
|
||||
scheduling policy is not "priority".
|
||||
"""
|
||||
if inputs is not None:
|
||||
prompt = inputs
|
||||
@ -433,7 +439,7 @@ class MQLLMEngineClient:
|
||||
|
||||
return self._process_request(prompt, sampling_params, request_id,
|
||||
lora_request, trace_headers,
|
||||
prompt_adapter_request)
|
||||
prompt_adapter_request, priority)
|
||||
|
||||
@overload # DEPRECATED
|
||||
def encode(
|
||||
@ -444,6 +450,7 @@ class MQLLMEngineClient:
|
||||
request_id: str,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
priority: int = 0,
|
||||
) -> AsyncGenerator[EmbeddingRequestOutput, None]:
|
||||
...
|
||||
|
||||
@ -455,6 +462,7 @@ class MQLLMEngineClient:
|
||||
request_id: str,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
priority: int = 0,
|
||||
) -> AsyncGenerator[EmbeddingRequestOutput, None]:
|
||||
...
|
||||
|
||||
@ -469,6 +477,7 @@ class MQLLMEngineClient:
|
||||
request_id: Optional[str] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
priority: int = 0,
|
||||
*,
|
||||
inputs: Optional[PromptType] = None # DEPRECATED
|
||||
) -> AsyncGenerator[EmbeddingRequestOutput, None]:
|
||||
@ -496,7 +505,7 @@ class MQLLMEngineClient:
|
||||
and request_id is not None)
|
||||
|
||||
return self._process_request(prompt, pooling_params, request_id,
|
||||
lora_request, trace_headers)
|
||||
lora_request, trace_headers, priority)
|
||||
|
||||
async def _process_request(
|
||||
self,
|
||||
@ -505,7 +514,8 @@ class MQLLMEngineClient:
|
||||
request_id: str,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
|
||||
EmbeddingRequestOutput, None]]:
|
||||
"""Send an RPCGenerateRequest to the RPCServer and stream responses."""
|
||||
@ -550,7 +560,9 @@ class MQLLMEngineClient:
|
||||
request_id=request_id,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
prompt_adapter_request=prompt_adapter_request))
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
priority=priority,
|
||||
))
|
||||
|
||||
# 3) Send the RPCGenerateRequest to the MQLLMEngine.
|
||||
parts = (request_bytes,
|
||||
|
||||
@ -40,7 +40,8 @@ class EngineClient(Protocol):
|
||||
request_id: str,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
) -> AsyncGenerator[RequestOutput, None]:
|
||||
"""Generate outputs for a request."""
|
||||
...
|
||||
@ -52,6 +53,7 @@ class EngineClient(Protocol):
|
||||
request_id: str,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
priority: int = 0,
|
||||
) -> AsyncGenerator[EmbeddingRequestOutput, None]:
|
||||
"""Generate outputs for a request from an embedding model."""
|
||||
...
|
||||
|
||||
@ -279,6 +279,12 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
||||
description=(
|
||||
"If specified, will override the default whitespace pattern "
|
||||
"for guided json decoding."))
|
||||
priority: int = Field(
|
||||
default=0,
|
||||
description=(
|
||||
"The priority of the request (lower means earlier handling; "
|
||||
"default: 0). Any priority other than 0 will raise an error "
|
||||
"if the served model does not use priority scheduling."))
|
||||
|
||||
# doc: end-chat-completion-extra-params
|
||||
|
||||
@ -552,6 +558,12 @@ class CompletionRequest(OpenAIBaseModel):
|
||||
description=(
|
||||
"If specified, will override the default whitespace pattern "
|
||||
"for guided json decoding."))
|
||||
priority: int = Field(
|
||||
default=0,
|
||||
description=(
|
||||
"The priority of the request (lower means earlier handling; "
|
||||
"default: 0). Any priority other than 0 will raise an error "
|
||||
"if the served model does not use priority scheduling."))
|
||||
|
||||
# doc: end-completion-extra-params
|
||||
|
||||
@ -665,6 +677,16 @@ class EmbeddingRequest(OpenAIBaseModel):
|
||||
|
||||
# doc: end-embedding-pooling-params
|
||||
|
||||
# doc: begin-embedding-extra-params
|
||||
priority: int = Field(
|
||||
default=0,
|
||||
description=(
|
||||
"The priority of the request (lower means earlier handling; "
|
||||
"default: 0). Any priority other than 0 will raise an error "
|
||||
"if the served model does not use priority scheduling."))
|
||||
|
||||
# doc: end-embedding-extra-params
|
||||
|
||||
def to_pooling_params(self):
|
||||
return PoolingParams(additional_data=self.additional_data)
|
||||
|
||||
|
||||
@ -235,6 +235,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
priority=request.priority,
|
||||
)
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
|
||||
@ -148,6 +148,7 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
)
|
||||
|
||||
generators.append(generator)
|
||||
|
||||
@ -148,6 +148,7 @@ class OpenAIServingEmbedding(OpenAIServing):
|
||||
pooling_params,
|
||||
request_id_item,
|
||||
lora_request=lora_request,
|
||||
priority=request.priority,
|
||||
)
|
||||
|
||||
generators.append(generator)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user