#################################################################################################
#
# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
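
"""Helpers for exchanging tensors with a C++ test binary over a pipe.

The wire format (as implemented below) is a stream of whitespace-separated
tokens. A tensor is announced as ``tensor_begin <dtype>:<name> <byte_count>``;
its raw bytes travel through a temporary file (announced as ``file <path>``)
rather than the pipe itself, followed by ``tensor_end`` and one
``<stride_name> <value>`` token pair per named stride.
"""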

import os
import subprocess
import sys
import tempfile
from typing import List

import numpy as np
import torch

TORCH_DTYPE_NAME = {
    torch.float32: "f32",
    torch.float16: "f16",
    torch.bfloat16: "b16",
}
NAME_TORCH_DTYPE = {v: k for k, v in TORCH_DTYPE_NAME.items()}


def _tensor_from_storage(tensor: torch.Tensor, dtype) -> torch.Tensor:
    # Reinterpret the raw storage of `tensor` as a flat 1D tensor of `dtype`,
    # without copying the underlying bytes.
    # PyTorch >= 2.0
    if hasattr(tensor, 'untyped_storage'):
        return torch.tensor([], dtype=dtype).set_(tensor.untyped_storage())
    return torch.tensor([], dtype=dtype).set_(tensor.storage().untyped())

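
# A minimal round-trip sketch (illustrative, not executed at import time):
# viewing an f32 tensor as raw bytes and back recovers the same values, since
# both views alias the same storage.
#
#     t = torch.randn(4)
#     raw = _tensor_from_storage(t, torch.uint8)      # 16 raw bytes
#     t2 = _tensor_from_storage(raw, torch.float32)   # 4 floats again
#     assert torch.equal(t, t2)
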

class PipedSubprocess:
    def __init__(self, binary: str) -> None:
        self.binary = binary
        self.tempdir_ctx = tempfile.TemporaryDirectory()

    def __enter__(self) -> "PipedSubprocess":
        # Unbuffered, text-mode pipes: tokens reach the other side as soon as
        # they are written.
        self.subp = subprocess.Popen(
            self.binary,
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=sys.stderr,
            text=True,
            bufsize=0,
        )
        self.tempdir = self.tempdir_ctx.__enter__()
        self.file_counter = 0
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        # Note: only the temporary directory is cleaned up here; the child
        # process is not explicitly terminated.
        self.tempdir_ctx.__exit__(exc_type, exc_val, exc_tb)

    def temp_filename(self, suffix: str) -> str:
        self.file_counter += 1
        return os.path.join(self.tempdir, f"{self.file_counter}{suffix}")

    def write(self, *args) -> None:
        for a in args:
            self.subp.stdin.write(str(a) + " ")

    def writeTensor(self, tensor: torch.Tensor, name: str, stride_names: List[str]) -> None:
        print(f"Py ->C++: {TORCH_DTYPE_NAME[tensor.dtype]}:{name}")
        tensor_u8 = _tensor_from_storage(tensor, torch.uint8)
        self.write("tensor_begin", f"{TORCH_DTYPE_NAME[tensor.dtype]}:{name}", tensor_u8.shape[0])
        filename = self.temp_filename(f"{name}.tensor")
        assert tensor.storage_offset() == 0
        with open(filename, "wb+") as fd:
            fd.write(bytes(tensor_u8.numpy()))
        self.write("file", filename)
        self.write("tensor_end")

        for stride_name, stride_value in zip(stride_names, tensor.stride()):
            self.write(stride_name, stride_value)

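    # For example, a contiguous f16 tensor named "query" of N bytes produces
    # (roughly) the token stream:
    #     tensor_begin f16:query N file <tempdir>/1query.tensor tensor_end
    # followed by one "<stride_name> <value>" pair per entry of stride_names.
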
    def readTensor(self, name, stride_name, shape) -> torch.Tensor:
        tmpfile = self.temp_filename(f"{name}.tensor")
        self.write("tmpfile", tmpfile)

        self.readExpect("tensor_begin")
        dtype_str, name = self.read().split(":")
        print(f"C++->Py : {dtype_str}:{name}")
        u8len = int(self.read())
        dtype = NAME_TORCH_DTYPE[dtype_str]

        self.readExpect("file")
        self.readExpect(tmpfile)

        with open(tmpfile, "rb") as fd:
            data = fd.read(u8len)
        # `np.array` is not strictly needed, but avoids a torch warning
        tensor_u8 = torch.frombuffer(np.array(data), dtype=torch.uint8, count=u8len)
        self.readExpect("tensor_end")

        tensor = _tensor_from_storage(tensor_u8, dtype)
        strides = []
        for sn in stride_name:
            self.readExpect(sn)
            strides.append(int(self.read()))
        # If one fewer stride than dimensions was sent, assume the innermost
        # dimension is contiguous.
        if len(strides) != len(shape):
            strides.append(1)
        assert len(strides) == len(shape), name
        return torch.as_strided(tensor, shape, strides)

    def readNamed(self, name: str):
        self.readExpect(name)
        return self.read()

    def readExpect(self, what: str) -> None:
        r = self.read()
        if r != what:
            raise ValueError(f"Read {r} but expected {what}")

    def read(self):
        read_all = []
        # Skip initial whitespace
        while True:
            r = self.subp.stdout.read(1)
            if r not in [' ', "\n"]:
                read_all.append(r)
                break
        # Read data
        while True:
            r = self.subp.stdout.read(1)
            if r in [' ', "\n"]:
                break
            read_all.append(r)
        return ''.join(read_all)
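
# Example usage (a minimal sketch; "./attention_test" and the tensor / stride
# names are hypothetical, and the binary must speak the token protocol above):
#
#     with PipedSubprocess("./attention_test") as p:
#         q = torch.randn(2, 8, 64, dtype=torch.float16)
#         p.writeTensor(q, "query", ["q_s0", "q_s1", "q_s2"])
#         out = p.readTensor("output", ["o_s0", "o_s1", "o_s2"], list(q.shape))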