[Minor] Revert change in offline inference example (#10545)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
parent cf656f5a02
commit 46fe9b46d8
examples/offline_inference.py
@@ -1,80 +1,22 @@
-from dataclasses import asdict
-
 from vllm import LLM, SamplingParams
-from vllm.engine.arg_utils import EngineArgs
-from vllm.utils import FlexibleArgumentParser
 
+# Sample prompts.
+prompts = [
+    "Hello, my name is",
+    "The president of the United States is",
+    "The capital of France is",
+    "The future of AI is",
+]
+# Create a sampling params object.
+sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
-
-def get_prompts(num_prompts: int):
-    # The default sample prompts.
-    prompts = [
-        "Hello, my name is",
-        "The president of the United States is",
-        "The capital of France is",
-        "The future of AI is",
-    ]
-
-    if num_prompts != len(prompts):
-        prompts = (prompts * ((num_prompts // len(prompts)) + 1))[:num_prompts]
-
-    return prompts
-
-
-def main(args):
-    # Create prompts
-    prompts = get_prompts(args.num_prompts)
-
-    # Create a sampling params object.
-    sampling_params = SamplingParams(n=args.n,
-                                     temperature=args.temperature,
-                                     top_p=args.top_p,
-                                     top_k=args.top_k,
-                                     max_tokens=args.max_tokens)
-
-    # Create an LLM.
-    # The default model is 'facebook/opt-125m'
-    engine_args = EngineArgs.from_cli_args(args)
-    llm = LLM(**asdict(engine_args))
-
-    # Generate texts from the prompts.
-    # The output is a list of RequestOutput objects
-    # that contain the prompt, generated text, and other information.
-    outputs = llm.generate(prompts, sampling_params)
-    # Print the outputs.
-    for output in outputs:
-        prompt = output.prompt
-        generated_text = output.outputs[0].text
-        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
-
-
-if __name__ == '__main__':
-    parser = FlexibleArgumentParser()
-    parser = EngineArgs.add_cli_args(parser)
-    group = parser.add_argument_group("SamplingParams options")
-    group.add_argument("--num-prompts",
-                       type=int,
-                       default=4,
-                       help="Number of prompts used for inference")
-    group.add_argument("--max-tokens",
-                       type=int,
-                       default=16,
-                       help="Generated output length for sampling")
-    group.add_argument('--n',
-                       type=int,
-                       default=1,
-                       help='Number of generated sequences per prompt')
-    group.add_argument('--temperature',
-                       type=float,
-                       default=0.8,
-                       help='Temperature for text generation')
-    group.add_argument('--top-p',
-                       type=float,
-                       default=0.95,
-                       help='top_p for text generation')
-    group.add_argument('--top-k',
-                       type=int,
-                       default=-1,
-                       help='top_k for text generation')
-
-    args = parser.parse_args()
-    main(args)
+
+# Create an LLM.
+llm = LLM(model="facebook/opt-125m")
+# Generate texts from the prompts. The output is a list of RequestOutput objects
+# that contain the prompt, generated text, and other information.
+outputs = llm.generate(prompts, sampling_params)
+# Print the outputs.
+for output in outputs:
+    prompt = output.prompt
+    generated_text = output.outputs[0].text
+    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
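
Both the reverted example above and the new CLI file below print only the first candidate per prompt (output.outputs[0].text). When more than one sequence per prompt is requested (the CLI script exposes this as --n), each RequestOutput carries one entry per candidate in its outputs list, so the print loop could be extended as in the sketch below. The loop is illustrative only and is not part of this commit.

# Illustrative only (not part of this commit): print every sampled candidate
# per prompt when SamplingParams(n=...) is greater than 1. Relies on the same
# attribute layout the examples already use: output.prompt and
# output.outputs[i].text.
for output in outputs:
    for i, candidate in enumerate(output.outputs):
        print(f"Prompt: {output.prompt!r}, candidate {i}: {candidate.text!r}")
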
examples/offline_inference_cli.py (new file)
@@ -0,0 +1,80 @@
+from dataclasses import asdict
+
+from vllm import LLM, SamplingParams
+from vllm.engine.arg_utils import EngineArgs
+from vllm.utils import FlexibleArgumentParser
+
+
+def get_prompts(num_prompts: int):
+    # The default sample prompts.
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+
+    if num_prompts != len(prompts):
+        prompts = (prompts * ((num_prompts // len(prompts)) + 1))[:num_prompts]
+
+    return prompts
+
+
+def main(args):
+    # Create prompts
+    prompts = get_prompts(args.num_prompts)
+
+    # Create a sampling params object.
+    sampling_params = SamplingParams(n=args.n,
+                                     temperature=args.temperature,
+                                     top_p=args.top_p,
+                                     top_k=args.top_k,
+                                     max_tokens=args.max_tokens)
+
+    # Create an LLM.
+    # The default model is 'facebook/opt-125m'
+    engine_args = EngineArgs.from_cli_args(args)
+    llm = LLM(**asdict(engine_args))
+
+    # Generate texts from the prompts.
+    # The output is a list of RequestOutput objects
+    # that contain the prompt, generated text, and other information.
+    outputs = llm.generate(prompts, sampling_params)
+    # Print the outputs.
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+
+
+if __name__ == '__main__':
+    parser = FlexibleArgumentParser()
+    parser = EngineArgs.add_cli_args(parser)
+    group = parser.add_argument_group("SamplingParams options")
+    group.add_argument("--num-prompts",
+                       type=int,
+                       default=4,
+                       help="Number of prompts used for inference")
+    group.add_argument("--max-tokens",
+                       type=int,
+                       default=16,
+                       help="Generated output length for sampling")
+    group.add_argument('--n',
+                       type=int,
+                       default=1,
+                       help='Number of generated sequences per prompt')
+    group.add_argument('--temperature',
+                       type=float,
+                       default=0.8,
+                       help='Temperature for text generation')
+    group.add_argument('--top-p',
+                       type=float,
+                       default=0.95,
+                       help='top_p for text generation')
+    group.add_argument('--top-k',
+                       type=int,
+                       default=-1,
+                       help='top_k for text generation')
+
+    args = parser.parse_args()
+    main(args)
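
The one non-obvious line in the new file is the padding expression in get_prompts, which tiles the four default prompts and then truncates to the requested count. A small standalone sketch of that logic follows; the helper name pad_prompts and the asserted values are illustrative, not part of the commit.

# Standalone sketch of the padding logic used in get_prompts above.
defaults = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]


def pad_prompts(num_prompts: int) -> list[str]:
    # Repeat the default list enough times, then cut it to the exact length.
    prompts = defaults
    if num_prompts != len(prompts):
        prompts = (prompts * ((num_prompts // len(prompts)) + 1))[:num_prompts]
    return prompts


assert pad_prompts(6) == defaults + defaults[:2]  # 4 defaults plus the first 2 again
assert pad_prompts(2) == defaults[:2]             # truncation below 4 also works

Because EngineArgs.add_cli_args(parser) registers the engine flags on the same parser, the script accepts them alongside the sampling options; an invocation along the lines of `python examples/offline_inference_cli.py --model facebook/opt-125m --num-prompts 8 --temperature 0.7` would be expected to work, though the exact engine flag set depends on the vLLM version.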