[Server] Add option to specify chat template for chat endpoint (#345)
This commit is contained in:
parent
ffa6d2f9f9
commit
98fe8cb542
@ -9,3 +9,4 @@ xformers >= 0.0.19
|
|||||||
fastapi
|
fastapi
|
||||||
uvicorn
|
uvicorn
|
||||||
pydantic # Required for OpenAI server.
|
pydantic # Required for OpenAI server.
|
||||||
|
fschat # Required for OpenAI ChatCompletion Endpoint.
|
||||||
|
|||||||
@ -36,6 +36,7 @@ TIMEOUT_KEEP_ALIVE = 5 # seconds
|
|||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
served_model = None
|
served_model = None
|
||||||
|
chat_template = None
|
||||||
app = fastapi.FastAPI()
|
app = fastapi.FastAPI()
|
||||||
|
|
||||||
|
|
||||||
@ -62,7 +63,7 @@ async def check_model(request) -> Optional[JSONResponse]:
|
|||||||
|
|
||||||
|
|
||||||
async def get_gen_prompt(request) -> str:
|
async def get_gen_prompt(request) -> str:
|
||||||
conv = get_conv_template(request.model)
|
conv = get_conv_template(chat_template)
|
||||||
conv = Conversation(
|
conv = Conversation(
|
||||||
name=conv.name,
|
name=conv.name,
|
||||||
system=conv.system,
|
system=conv.system,
|
||||||
@ -553,13 +554,20 @@ if __name__ == "__main__":
|
|||||||
type=json.loads,
|
type=json.loads,
|
||||||
default=["*"],
|
default=["*"],
|
||||||
help="allowed headers")
|
help="allowed headers")
|
||||||
|
parser.add_argument("--served-model-name",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="The model name used in the API. If not "
|
||||||
|
"specified, the model name will be the same as "
|
||||||
|
"the huggingface name.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--served-model-name",
|
"--chat-template",
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
help="The model name used in the API. If not specified, "
|
help="The chat template name used in the ChatCompletion endpoint. If "
|
||||||
"the model name will be the same as the "
|
"not specified, we use the API model name as the template name. See "
|
||||||
"huggingface name.")
|
"https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py "
|
||||||
|
"for the list of available templates.")
|
||||||
parser = AsyncEngineArgs.add_cli_args(parser)
|
parser = AsyncEngineArgs.add_cli_args(parser)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@ -573,7 +581,15 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
logger.info(f"args: {args}")
|
logger.info(f"args: {args}")
|
||||||
|
|
||||||
served_model = args.served_model_name or args.model
|
if args.served_model_name is not None:
|
||||||
|
served_model = args.served_model_name
|
||||||
|
else:
|
||||||
|
served_model = args.model
|
||||||
|
|
||||||
|
if args.chat_template is not None:
|
||||||
|
chat_template = args.chat_template
|
||||||
|
else:
|
||||||
|
chat_template = served_model
|
||||||
|
|
||||||
engine_args = AsyncEngineArgs.from_cli_args(args)
|
engine_args = AsyncEngineArgs.from_cli_args(args)
|
||||||
engine = AsyncLLMEngine.from_engine_args(engine_args)
|
engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user