From 2ac4d5e2bf033306fdb9b5002b2adbaafb864a3a Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Thu, 21 Sep 2023 00:51:47 -0700
Subject: [PATCH] Replace DtypeTensor (#1123)

---
 vllm/worker/worker.py | 20 +++++++++++++++-----
 1 file changed, 15 insertions(+), 5 deletions(-)

diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 2d2021d9..586c90e0 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -228,15 +228,25 @@ class Worker:
         input_positions = _pad_to_alignment(input_positions, multiple_of=8)
 
         # Convert to tensors.
-        tokens_tensor = torch.cuda.LongTensor(input_tokens)
-        positions_tensor = torch.cuda.LongTensor(input_positions)
-        slot_mapping_tensor = torch.cuda.IntTensor(slot_mapping)
-        context_lens_tensor = torch.cuda.IntTensor(context_lens)
+        tokens_tensor = torch.tensor(input_tokens,
+                                     dtype=torch.long,
+                                     device="cuda")
+        positions_tensor = torch.tensor(input_positions,
+                                        dtype=torch.long,
+                                        device="cuda")
+        slot_mapping_tensor = torch.tensor(slot_mapping,
+                                           dtype=torch.int,
+                                           device="cuda")
+        context_lens_tensor = torch.tensor(context_lens,
+                                           dtype=torch.int,
+                                           device="cuda")
         padded_block_tables = [
             _pad_to_max(block_table, max_num_blocks_per_seq)
             for block_table in generation_block_tables
         ]
-        block_tables_tensor = torch.cuda.IntTensor(padded_block_tables)
+        block_tables_tensor = torch.tensor(padded_block_tables,
+                                           dtype=torch.int,
+                                           device="cuda")
 
         seq_data: Dict[int, SequenceData] = {}
         for seq_group_metadata in seq_group_metadata_list:
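
Not part of the patch: a minimal standalone sketch of what the replacement amounts to, assuming PyTorch with a CUDA device available. The input list below is illustrative; only torch.cuda.LongTensor and torch.tensor(..., dtype=..., device=...) come from the diff above.

    import torch

    # Illustrative input, standing in for the token id lists built in worker.py.
    input_tokens = [101, 7592, 2088, 102]

    # Old style removed by the patch: dtype and device are implied by the class name.
    legacy_tensor = torch.cuda.LongTensor(input_tokens)

    # New style introduced by the patch: dtype and device are explicit keyword
    # arguments, which makes it straightforward to target a different device
    # or dtype later without changing the constructor class.
    new_tensor = torch.tensor(input_tokens, dtype=torch.long, device="cuda")

    # Both forms produce the same int64 tensor on the CUDA device.
    assert legacy_tensor.dtype == new_tensor.dtype
    assert legacy_tensor.device == new_tensor.device
    assert torch.equal(legacy_tensor, new_tensor)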