[Core] [Frontend] Priority scheduling for embeddings and the OpenAI API (#8965)
Commit 35bd215168 (parent 1fe0a4264a)
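As a rough usage sketch (not part of the diff below), the new `priority` request field can be sent through the OpenAI-compatible API as an extra body parameter. The server URL, model name, and the `--scheduling-policy priority` launch flag are assumptions for illustration; any non-zero priority is rejected when the served model does not use priority scheduling.

```python
# Illustrative sketch only; assumes a vLLM OpenAI-compatible server started
# with priority scheduling enabled (e.g. `--scheduling-policy priority`,
# an assumption, not something introduced by this commit).
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# `priority` is a vLLM-specific extension, so it is passed via `extra_body`.
# Lower values mean earlier handling; the default is 0.
response = client.chat.completions.create(
    model="my-model",  # placeholder name
    messages=[{"role": "user", "content": "Hello!"}],
    extra_body={"priority": 1},  # scheduled after priority-0 requests
)
print(response.choices[0].message.content)
```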
@@ -1043,6 +1043,7 @@ class AsyncLLMEngine:
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
+        priority: int = 0,
     ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
         """Generate outputs for a request from an embedding model.

@@ -1057,6 +1058,8 @@ class AsyncLLMEngine:
             request_id: The unique id of the request.
             lora_request: LoRA request to use for generation, if any.
             trace_headers: OpenTelemetry trace headers.
+            priority: The priority of the request.
+                Only applicable with priority scheduling.

         Yields:
             The output `EmbeddingRequestOutput` objects from the LLMEngine
@@ -1109,6 +1112,7 @@ class AsyncLLMEngine:
             pooling_params,
             lora_request=lora_request,
             trace_headers=trace_headers,
+            priority=priority,
         ):
             yield LLMEngine.validate_output(output, EmbeddingRequestOutput)

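A minimal sketch of the engine-level path touched by the hunks above: calling `AsyncLLMEngine.encode` with the new `priority` keyword. The engine arguments, embedding model name, and pooling parameters are assumptions; non-zero priorities are only accepted when the scheduler is configured for priority scheduling.

```python
# Sketch only: exercise AsyncLLMEngine.encode(..., priority=...) as added above.
# The model name and engine args are assumptions made for illustration.
from vllm import AsyncEngineArgs, AsyncLLMEngine, PoolingParams

engine = AsyncLLMEngine.from_engine_args(
    AsyncEngineArgs(model="intfloat/e5-mistral-7b-instruct"))


async def embed(text: str, request_id: str, priority: int = 0):
    """Return the embedding for `text`, optionally ahead of other requests."""
    final_output = None
    # `priority` is forwarded to the scheduler; values other than 0 require
    # the priority scheduling policy to be enabled.
    async for output in engine.encode(text,
                                      PoolingParams(),
                                      request_id,
                                      priority=priority):
        final_output = output
    return final_output.outputs.embedding
```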
@@ -30,6 +30,7 @@ class RPCProcessRequest:
     lora_request: Optional[LoRARequest] = None
     trace_headers: Optional[Mapping[str, str]] = None
     prompt_adapter_request: Optional[PromptAdapterRequest] = None
+    priority: int = 0

     @overload # DEPRECATED
     def __init__(
@@ -41,6 +42,7 @@ class RPCProcessRequest:
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
     ) -> None:
         ...

@@ -53,6 +55,7 @@ class RPCProcessRequest:
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
     ) -> None:
         ...

@@ -68,6 +71,7 @@ class RPCProcessRequest:
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
         *,
         inputs: Optional[PromptType] = None, # DEPRECATED
     ) -> None:
@@ -84,6 +88,7 @@ class RPCProcessRequest:
         self.lora_request = lora_request
         self.trace_headers = trace_headers
         self.prompt_adapter_request = prompt_adapter_request
+        self.priority = priority


 @dataclass
@@ -380,6 +380,7 @@ class MQLLMEngineClient:
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
     ) -> AsyncGenerator[RequestOutput, None]:
         ...

@@ -392,6 +393,7 @@ class MQLLMEngineClient:
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
     ) -> AsyncGenerator[RequestOutput, None]:
         ...

@@ -407,6 +409,7 @@ class MQLLMEngineClient:
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
         *,
         inputs: Optional[PromptType] = None # DEPRECATED
     ) -> AsyncGenerator[RequestOutput, None]:
@@ -425,6 +428,9 @@ class MQLLMEngineClient:
             trace_headers: OpenTelemetry trace headers.
             prompt_adapter_request: Prompt Adapter request to use
                 for generation, if any.
+            priority: Priority of the request (lower means earlier handling).
+                Any priority other than 0 will lead to an error if the
+                scheduling policy is not "priority".
         """
         if inputs is not None:
             prompt = inputs
@@ -433,7 +439,7 @@ class MQLLMEngineClient:

         return self._process_request(prompt, sampling_params, request_id,
                                      lora_request, trace_headers,
-                                     prompt_adapter_request)
+                                     prompt_adapter_request, priority)

     @overload # DEPRECATED
     def encode(
@@ -444,6 +450,7 @@ class MQLLMEngineClient:
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
+        priority: int = 0,
     ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
         ...

@@ -455,6 +462,7 @@ class MQLLMEngineClient:
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
+        priority: int = 0,
     ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
         ...

@@ -469,6 +477,7 @@ class MQLLMEngineClient:
         request_id: Optional[str] = None,
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
+        priority: int = 0,
         *,
         inputs: Optional[PromptType] = None # DEPRECATED
     ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
@@ -496,7 +505,7 @@ class MQLLMEngineClient:
                 and request_id is not None)

         return self._process_request(prompt, pooling_params, request_id,
-                                     lora_request, trace_headers)
+                                     lora_request, trace_headers, priority)

     async def _process_request(
         self,
@@ -505,7 +514,8 @@ class MQLLMEngineClient:
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
     ) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
             EmbeddingRequestOutput, None]]:
         """Send an RPCGenerateRequest to the RPCServer and stream responses."""
@@ -550,7 +560,9 @@ class MQLLMEngineClient:
                     request_id=request_id,
                     lora_request=lora_request,
                     trace_headers=trace_headers,
-                    prompt_adapter_request=prompt_adapter_request))
+                    prompt_adapter_request=prompt_adapter_request,
+                    priority=priority,
+                ))

             # 3) Send the RPCGenerateRequest to the MQLLMEngine.
             parts = (request_bytes,
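As the `generate` docstring above notes, non-zero priorities are only accepted when the scheduling policy is "priority", and lower values are handled earlier. A hedged sketch of a call through an engine client such as the one patched above; the client object, prompt, and request id are placeholders.

```python
# Sketch only: issue a request with an explicit priority through an engine
# client (e.g. the MQLLMEngineClient used behind the OpenAI server).
from vllm import SamplingParams


async def generate_with_priority(engine_client, prompt: str, priority: int):
    # Lower priority values are scheduled earlier; non-zero values raise an
    # error unless the scheduling policy is "priority".
    last = None
    async for output in engine_client.generate(prompt,
                                               SamplingParams(max_tokens=64),
                                               request_id="req-0",
                                               priority=priority):
        last = output
    return last.outputs[0].text
```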
@@ -40,7 +40,8 @@ class EngineClient(Protocol):
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
     ) -> AsyncGenerator[RequestOutput, None]:
         """Generate outputs for a request."""
         ...
@@ -52,6 +53,7 @@ class EngineClient(Protocol):
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
+        priority: int = 0,
     ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
         """Generate outputs for a request from an embedding model."""
         ...
@@ -279,6 +279,12 @@ class ChatCompletionRequest(OpenAIBaseModel):
         description=(
             "If specified, will override the default whitespace pattern "
             "for guided json decoding."))
+    priority: int = Field(
+        default=0,
+        description=(
+            "The priority of the request (lower means earlier handling; "
+            "default: 0). Any priority other than 0 will raise an error "
+            "if the served model does not use priority scheduling."))

     # doc: end-chat-completion-extra-params

@@ -552,6 +558,12 @@ class CompletionRequest(OpenAIBaseModel):
         description=(
             "If specified, will override the default whitespace pattern "
             "for guided json decoding."))
+    priority: int = Field(
+        default=0,
+        description=(
+            "The priority of the request (lower means earlier handling; "
+            "default: 0). Any priority other than 0 will raise an error "
+            "if the served model does not use priority scheduling."))

     # doc: end-completion-extra-params

@@ -665,6 +677,16 @@ class EmbeddingRequest(OpenAIBaseModel):

     # doc: end-embedding-pooling-params

+    # doc: begin-embedding-extra-params
+    priority: int = Field(
+        default=0,
+        description=(
+            "The priority of the request (lower means earlier handling; "
+            "default: 0). Any priority other than 0 will raise an error "
+            "if the served model does not use priority scheduling."))
+
+    # doc: end-embedding-extra-params
+
     def to_pooling_params(self):
         return PoolingParams(additional_data=self.additional_data)

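The same extra field is exposed on the embeddings endpoint; a sketch using the OpenAI SDK against a vLLM server (server URL and model name are placeholders, and priority scheduling must be enabled for non-zero values):

```python
# Sketch only: the embeddings endpoint also accepts the vLLM-specific
# `priority` field via `extra_body`.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

embedding = client.embeddings.create(
    model="my-embedding-model",  # placeholder
    input=["vLLM priority scheduling"],
    extra_body={"priority": 1},  # handled after priority-0 requests
)
print(len(embedding.data[0].embedding))
```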
@@ -235,6 +235,7 @@ class OpenAIServingChat(OpenAIServing):
                 lora_request=lora_request,
                 trace_headers=trace_headers,
                 prompt_adapter_request=prompt_adapter_request,
+                priority=request.priority,
             )
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
@@ -148,6 +148,7 @@ class OpenAIServingCompletion(OpenAIServing):
                     lora_request=lora_request,
                     prompt_adapter_request=prompt_adapter_request,
                     trace_headers=trace_headers,
+                    priority=request.priority,
                 )

                 generators.append(generator)
@@ -148,6 +148,7 @@ class OpenAIServingEmbedding(OpenAIServing):
                     pooling_params,
                     request_id_item,
                     lora_request=lora_request,
+                    priority=request.priority,
                 )

                 generators.append(generator)
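Because `priority` is just an extra field on the request models above, it can also be sent as plain JSON without the SDK; a sketch with `requests` (server URL and model name are placeholders):

```python
# Sketch only: `priority` travels as a plain JSON field on /v1/completions.
import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",  # placeholder server URL
    json={
        "model": "my-model",   # placeholder
        "prompt": "Say hello",
        "max_tokens": 16,
        "priority": 1,         # lower means earlier handling; default 0
    },
)
print(resp.json()["choices"][0]["text"])
```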