[V1] Allow tokenizer_mode and trust_remote_code for Detokenizer (#10211)

Signed-off-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
Roger Wang 2024-11-11 02:01:18 -08:00 committed by GitHub
parent 36e4acd02a
commit 5fb1f935b0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 19 additions and 5 deletions

View File

@ -125,7 +125,10 @@ class LLMEngine:
# Ping the tokenizer to ensure liveness if it runs in a # Ping the tokenizer to ensure liveness if it runs in a
# different process. # different process.
self.tokenizer.ping() self.tokenizer.ping()
self.detokenizer = Detokenizer(self.model_config.tokenizer) self.detokenizer = Detokenizer(
tokenizer_name=self.model_config.tokenizer,
tokenizer_mode=self.model_config.tokenizer_mode,
trust_remote_code=self.model_config.trust_remote_code)
self.generation_config_fields = _load_generation_config_dict( self.generation_config_fields = _load_generation_config_dict(
model_config) model_config)

View File

@ -42,13 +42,17 @@ class DetokenizerOutputs(msgspec.Struct):
class Detokenizer: class Detokenizer:
def __init__(self, tokenizer_name: str): def __init__(self, tokenizer_name: str, tokenizer_mode: str,
trust_remote_code: bool):
# FIXME(woosuk): Currently, the detokenizer is just a hacky prototype. # FIXME(woosuk): Currently, the detokenizer is just a hacky prototype.
# For example, it does not terminate properly. We need to improve this. # For example, it does not terminate properly. We need to improve this.
self.push_port = get_open_port() self.push_port = get_open_port()
self.pull_port = get_open_port() self.pull_port = get_open_port()
self.detokenizer = DetokenizerProc(tokenizer_name, self.push_port, self.detokenizer = DetokenizerProc(tokenizer_name=tokenizer_name,
self.pull_port) tokenizer_mode=tokenizer_mode,
trust_remote_code=trust_remote_code,
push_port=self.push_port,
pull_port=self.pull_port)
self.detokenizer.start() self.detokenizer.start()
self.zmq_context = zmq.Context() self.zmq_context = zmq.Context()
@ -82,11 +86,15 @@ class DetokenizerProc(multiprocessing.Process):
def __init__( def __init__(
self, self,
tokenizer_name: str, tokenizer_name: str,
tokenizer_mode: str,
trust_remote_code: bool,
pull_port: int, pull_port: int,
push_port: int, push_port: int,
): ):
super().__init__() super().__init__()
self.tokenizer_name = tokenizer_name self.tokenizer_name = tokenizer_name
self.tokenizer_mode = tokenizer_mode
self.trust_remote_code = trust_remote_code
# NOTE: The pull_port of the detokenizer should be the same as the # NOTE: The pull_port of the detokenizer should be the same as the
# push_port of the engine. Vice versa. # push_port of the engine. Vice versa.
self.pull_port = pull_port self.pull_port = pull_port
@ -97,7 +105,10 @@ class DetokenizerProc(multiprocessing.Process):
# not picklable. # not picklable.
self.msgpack_encoder = msgpack.Encoder() self.msgpack_encoder = msgpack.Encoder()
self.msgpack_decoder = msgpack.Decoder(DetokenizerInputs) self.msgpack_decoder = msgpack.Decoder(DetokenizerInputs)
self.tokenizer = get_tokenizer(self.tokenizer_name) self.tokenizer = get_tokenizer(
tokenizer_name=self.tokenizer_name,
tokenizer_mode=self.tokenizer_mode,
trust_remote_code=self.trust_remote_code)
# req_id -> RequestState # req_id -> RequestState
self.request_states: Dict[str, RequestState] = {} self.request_states: Dict[str, RequestState] = {}