[Bugfix] update neuron for version > 0.5.0 (#7175)

Signed-off-by: omrishiv <327609+omrishiv@users.noreply.github.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Author: omrishiv
Date: 2024-08-15 09:44:14 -07:00
Committed by: GitHub
Parent: fc93e56143
Commit: 9c1f78d5d6
4 changed files with 7 additions and 4 deletions

@@ -316,7 +316,7 @@ class EngineArgs:
         parser.add_argument('--block-size',
                             type=int,
                             default=EngineArgs.block_size,
-                            choices=[8, 16, 32],
+                            choices=[8, 16, 32, 128, 256, 512, 1024, 2048],
                             help='Token block size for contiguous chunks of '
                             'tokens.')
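
On the Neuron backend the KV-cache block size is typically set to the model's full context length rather than a small paged-attention block, so the CLI has to accept values well above 32. As a quick illustration of what the widened choices list does, here is a standalone argparse sketch (not vLLM's actual CLI wiring; the default of 16 is assumed for the sketch):

import argparse

# Standalone sketch: argparse rejects any --block-size outside `choices`.
parser = argparse.ArgumentParser()
parser.add_argument('--block-size',
                    type=int,
                    default=16,  # assumed default, for illustration only
                    choices=[8, 16, 32, 128, 256, 512, 1024, 2048],
                    help='Token block size for contiguous chunks of tokens.')

print(parser.parse_args(['--block-size', '2048']).block_size)  # accepted -> 2048
# parser.parse_args(['--block-size', '64'])  # would exit with "invalid choice: 64"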

@@ -100,9 +100,8 @@ class NeuronExecutorAsync(NeuronExecutor, ExecutorAsyncBase):
         self,
         execute_model_req: ExecuteModelRequest,
     ) -> List[SamplerOutput]:
-        output = await make_async(
-            self.driver_worker.execute_model
-        )(seq_group_metadata_list=execute_model_req.seq_group_metadata_list, )
+        output = await make_async(self.driver_worker.execute_model
+                                  )(execute_model_req=execute_model_req, )
         return output
 
     async def check_health_async(self) -> None:
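
The substantive change here is that the whole ExecuteModelRequest is now forwarded to the driver worker, matching the worker-side execute_model signature in newer vLLM, instead of unpacking only seq_group_metadata_list. The blocking call is still bridged into the async engine via make_async. The sketch below shows that general pattern with a simplified stand-in for make_async (written from its usage in this diff, not copied from vLLM) and a dummy worker function:

import asyncio
import functools
from typing import Awaitable, Callable, TypeVar

T = TypeVar("T")

def make_async(func: Callable[..., T]) -> Callable[..., Awaitable[T]]:
    """Simplified stand-in: run a blocking callable in the default executor."""
    def _wrapper(*args, **kwargs):
        loop = asyncio.get_running_loop()
        return loop.run_in_executor(None, functools.partial(func, *args, **kwargs))
    return _wrapper

async def demo() -> None:
    # Dummy driver-worker method; the real one takes an ExecuteModelRequest.
    def execute_model(execute_model_req=None):
        return [f"executed {execute_model_req}"]

    # Mirrors the diff: the request object is passed through as one keyword arg.
    output = await make_async(execute_model)(execute_model_req="req-0")
    print(output)

asyncio.run(demo())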

@@ -197,6 +197,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
         virtual_engine: int = 0,
         finished_requests_ids: Optional[List[str]] = None
     ) -> ModelInputForNeuron:
+        multi_modal_kwargs = None
         # NOTE: We assume that all sequences in the group are all prompts or
         # all decodes.
         is_prompt = seq_group_metadata_list[0].is_prompt
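
The multi_modal_kwargs = None default is the kind of guard that prevents an UnboundLocalError when a name is assigned on only one code path (here, presumably only when the batch is a prompt) but read unconditionally later. A toy illustration of the pattern, not the real NeuronModelRunner logic:

from typing import Any, Dict, Optional

def prepare_model_input(is_prompt: bool) -> Dict[str, Any]:
    multi_modal_kwargs: Optional[Dict[str, Any]] = None  # the added default
    if is_prompt:
        # Only the prompt path produces multi-modal inputs in this toy version.
        multi_modal_kwargs = {"pixel_values": "..."}
    # Read on both paths; without the default, the decode path would raise
    # UnboundLocalError here.
    return {"multi_modal_kwargs": multi_modal_kwargs}

print(prepare_model_input(is_prompt=False))  # {'multi_modal_kwargs': None}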

@@ -89,6 +89,9 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
         return WorkerInput(num_seq_groups=len(
             execute_model_req.seq_group_metadata_list), )
 
+    def execute_worker(self, worker_input: WorkerInput) -> None:
+        pass
+
     def get_cache_block_size_bytes(self) -> int:
         """Determine the size in bytes of a cache block.