Replace DtypeTensor (#1123)
This commit is contained in:
parent
3302f0aef3
commit
2ac4d5e2bf
@ -228,15 +228,25 @@ class Worker:
|
|||||||
input_positions = _pad_to_alignment(input_positions, multiple_of=8)
|
input_positions = _pad_to_alignment(input_positions, multiple_of=8)
|
||||||
|
|
||||||
# Convert to tensors.
|
# Convert to tensors.
|
||||||
tokens_tensor = torch.cuda.LongTensor(input_tokens)
|
tokens_tensor = torch.tensor(input_tokens,
|
||||||
positions_tensor = torch.cuda.LongTensor(input_positions)
|
dtype=torch.long,
|
||||||
slot_mapping_tensor = torch.cuda.IntTensor(slot_mapping)
|
device="cuda")
|
||||||
context_lens_tensor = torch.cuda.IntTensor(context_lens)
|
positions_tensor = torch.tensor(input_positions,
|
||||||
|
dtype=torch.long,
|
||||||
|
device="cuda")
|
||||||
|
slot_mapping_tensor = torch.tensor(slot_mapping,
|
||||||
|
dtype=torch.int,
|
||||||
|
device="cuda")
|
||||||
|
context_lens_tensor = torch.tensor(context_lens,
|
||||||
|
dtype=torch.int,
|
||||||
|
device="cuda")
|
||||||
padded_block_tables = [
|
padded_block_tables = [
|
||||||
_pad_to_max(block_table, max_num_blocks_per_seq)
|
_pad_to_max(block_table, max_num_blocks_per_seq)
|
||||||
for block_table in generation_block_tables
|
for block_table in generation_block_tables
|
||||||
]
|
]
|
||||||
block_tables_tensor = torch.cuda.IntTensor(padded_block_tables)
|
block_tables_tensor = torch.tensor(padded_block_tables,
|
||||||
|
dtype=torch.int,
|
||||||
|
device="cuda")
|
||||||
|
|
||||||
seq_data: Dict[int, SequenceData] = {}
|
seq_data: Dict[int, SequenceData] = {}
|
||||||
for seq_group_metadata in seq_group_metadata_list:
|
for seq_group_metadata in seq_group_metadata_list:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user