cutlass/tools/profiler/src/device_allocation.cu

1505 lines
47 KiB
Plaintext
Raw Normal View History

/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Device memory allocation utilities used by the CUTLASS profiler
*/
#include <cstring>
#include "cutlass/numeric_types.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/library/util.h"
#include "device_allocation.h"
namespace cutlass {
namespace profiler {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Returns the number of bytes occupied by `capacity` elements of numeric type `type`.
/// The total bit count is converted to bytes with truncating integer division.
size_t DeviceAllocation::bytes(library::NumericTypeID type, size_t capacity) {
  size_t const bits_per_element = size_t(cutlass::library::sizeof_bits(type));
  size_t const total_bits = bits_per_element * capacity;
  return total_bits / 8;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Computes the stride vector of the packed form of `Layout` for the given extent.
///
/// \param extent  one entry per logical dimension; must contain exactly Layout::kRank entries
/// \return        vector of Layout::kStrideRank stride values taken from Layout::packed(extent)
/// \throws std::runtime_error if the extent vector's size differs from Layout::kRank
template <typename Layout>
static std::vector<int> get_packed_layout_stride(std::vector<int> const &extent) {

  if (extent.size() != size_t(Layout::kRank)) {
    throw std::runtime_error("Layout does not have same rank as extent vector.");
  }

  // Translate the generic extent vector into the layout's coordinate type.
  typename Layout::TensorCoord packed_extent;
  for (int dim = 0; dim < Layout::kRank; ++dim) {
    packed_extent[dim] = extent[dim];
  }

  // Let the layout itself decide what "packed" means, then read back its stride.
  Layout packed_layout = Layout::packed(packed_extent);
  typename Layout::Stride packed_stride = packed_layout.stride();

  std::vector<int> result(size_t(Layout::kStrideRank), 0);
  for (int idx = 0; idx < Layout::kStrideRank; ++idx) {
    result[idx] = packed_stride[idx];
  }

  return result;
}
/// Returns the stride vector of the packed layout identified by `layout_id` for a tensor
/// of the given extent. An empty vector is returned when `layout_id` is not handled here.
std::vector<int> DeviceAllocation::get_packed_layout(
  library::LayoutTypeID layout_id,
  std::vector<int> const &extent) {

  // Each case maps the runtime layout ID onto its compile-time CUTLASS layout type.
  switch (layout_id) {
    case library::LayoutTypeID::kColumnMajor:
      return get_packed_layout_stride<cutlass::layout::ColumnMajor>(extent);
    case library::LayoutTypeID::kRowMajor:
      return get_packed_layout_stride<cutlass::layout::RowMajor>(extent);
    case library::LayoutTypeID::kColumnMajorInterleavedK2:
      return get_packed_layout_stride<cutlass::layout::ColumnMajorInterleaved<2>>(extent);
    case library::LayoutTypeID::kRowMajorInterleavedK2:
      return get_packed_layout_stride<cutlass::layout::RowMajorInterleaved<2>>(extent);
    case library::LayoutTypeID::kColumnMajorInterleavedK4:
      return get_packed_layout_stride<cutlass::layout::ColumnMajorInterleaved<4>>(extent);
    case library::LayoutTypeID::kRowMajorInterleavedK4:
      return get_packed_layout_stride<cutlass::layout::RowMajorInterleaved<4>>(extent);
    case library::LayoutTypeID::kColumnMajorInterleavedK16:
      return get_packed_layout_stride<cutlass::layout::ColumnMajorInterleaved<16>>(extent);
    case library::LayoutTypeID::kRowMajorInterleavedK16:
      return get_packed_layout_stride<cutlass::layout::RowMajorInterleaved<16>>(extent);
    case library::LayoutTypeID::kColumnMajorInterleavedK32:
      return get_packed_layout_stride<cutlass::layout::ColumnMajorInterleaved<32>>(extent);
    case library::LayoutTypeID::kRowMajorInterleavedK32:
      return get_packed_layout_stride<cutlass::layout::RowMajorInterleaved<32>>(extent);
    case library::LayoutTypeID::kColumnMajorInterleavedK64:
      return get_packed_layout_stride<cutlass::layout::ColumnMajorInterleaved<64>>(extent);
    case library::LayoutTypeID::kRowMajorInterleavedK64:
      return get_packed_layout_stride<cutlass::layout::RowMajorInterleaved<64>>(extent);
    case library::LayoutTypeID::kTensorNCHW:
      return get_packed_layout_stride<cutlass::layout::TensorNCHW>(extent);
    case library::LayoutTypeID::kTensorNHWC:
      return get_packed_layout_stride<cutlass::layout::TensorNHWC>(extent);
    case library::LayoutTypeID::kTensorNDHWC:
      return get_packed_layout_stride<cutlass::layout::TensorNDHWC>(extent);
    case library::LayoutTypeID::kTensorNC32HW32:
      return get_packed_layout_stride<cutlass::layout::TensorNCxHWx<32>>(extent);
    case library::LayoutTypeID::kTensorNC64HW64:
      return get_packed_layout_stride<cutlass::layout::TensorNCxHWx<64>>(extent);
    case library::LayoutTypeID::kTensorC32RSK32:
      return get_packed_layout_stride<cutlass::layout::TensorCxRSKx<32>>(extent);
    case library::LayoutTypeID::kTensorC64RSK64:
      return get_packed_layout_stride<cutlass::layout::TensorCxRSKx<64>>(extent);
    default:
      break;
  }

  // Unhandled layout: report "no stride".
  return std::vector<int>();
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Template to use CUTLASS Layout functions to construct a layout object from extent and
/// stride vectors and to compute the capacity it requires.
///
/// \param bytes      if non-null, the constructed Layout object is assigned into this storage
/// \param layout_id  runtime layout identifier (not consulted by this function)
/// \param extent     tensor extent; must contain Layout::kRank entries
/// \param stride     in/out: if empty and the layout has a non-zero stride rank, it is replaced
///                   by the packed stride; otherwise it must hold Layout::kStrideRank entries
/// \return           layout.capacity(extent) for the constructed layout
/// \throws std::runtime_error on a rank mismatch in `extent` or `stride`
template <typename Layout>
static size_t construct_layout_(
  void *bytes,
  library::LayoutTypeID layout_id,
  std::vector<int> const &extent,
  std::vector<int> &stride) {

  if (extent.size() != size_t(Layout::kRank)) {
    throw std::runtime_error(
      "Layout must have same rank as extent vector.");
  }

  if (Layout::kStrideRank) {
    if (stride.empty()) {
      // No stride supplied: fall back to the packed stride for this extent.
      stride = get_packed_layout_stride<Layout>(extent);
    }
    else if (stride.size() != size_t(Layout::kStrideRank)) {
      throw std::runtime_error(
        "Layout requires either empty stride or stride vector matching Layout::kStrideRank");
    }
  }

  // Convert the generic vectors into the layout's own coordinate types.
  typename Layout::Stride layout_stride;
  for (int idx = 0; idx < Layout::kStrideRank; ++idx) {
    layout_stride[idx] = stride.at(idx);
  }

  typename Layout::TensorCoord layout_extent;
  for (int dim = 0; dim < Layout::kRank; ++dim) {
    layout_extent[dim] = extent.at(dim);
  }

  Layout constructed(layout_stride);

  // Optionally hand the layout object back to the caller through raw storage.
  if (bytes != nullptr) {
    *reinterpret_cast<Layout *>(bytes) = constructed;
  }

  return constructed.capacity(layout_extent);
}
/// Constructs the layout object identified by `layout_id` and returns the capacity needed.
///
/// \param bytes      if non-null, receives the constructed layout object
/// \param layout_id  runtime identifier selecting the CUTLASS layout type
/// \param extent     tensor extent; its size must match the rank of the selected layout
/// \param stride     in/out stride vector; filled with the packed stride when passed in empty
/// \return           capacity reported by the layout, or 0 when `layout_id` is not handled here
/// \throws std::runtime_error if `extent` or `stride` has the wrong rank for the layout
size_t DeviceAllocation::construct_layout(
  void *bytes,
  library::LayoutTypeID layout_id,
  std::vector<int> const &extent,
  std::vector<int> &stride) {

  switch (layout_id) {
    case library::LayoutTypeID::kColumnMajor:
      return construct_layout_<cutlass::layout::ColumnMajor>(bytes, layout_id, extent, stride);

    case library::LayoutTypeID::kRowMajor:
      return construct_layout_<cutlass::layout::RowMajor>(bytes, layout_id, extent, stride);

    case library::LayoutTypeID::kColumnMajorInterleavedK2:
      return construct_layout_<cutlass::layout::ColumnMajorInterleaved<2>>(bytes, layout_id, extent, stride);

    case library::LayoutTypeID::kRowMajorInterleavedK2:
      return construct_layout_<cutlass::layout::RowMajorInterleaved<2>>(bytes, layout_id, extent, stride);

    case library::LayoutTypeID::kColumnMajorInterleavedK4:
      return construct_layout_<cutlass::layout::ColumnMajorInterleaved<4>>(bytes, layout_id, extent, stride);

    case library::LayoutTypeID::kRowMajorInterleavedK4:
      return construct_layout_<cutlass::layout::RowMajorInterleaved<4>>(bytes, layout_id, extent, stride);

    case library::LayoutTypeID::kColumnMajorInterleavedK16:
      return construct_layout_<cutlass::layout::ColumnMajorInterleaved<16>>(bytes, layout_id, extent, stride);

    case library::LayoutTypeID::kRowMajorInterleavedK16:
      return construct_layout_<cutlass::layout::RowMajorInterleaved<16>>(bytes, layout_id, extent, stride);

    case library::LayoutTypeID::kColumnMajorInterleavedK32:
      return construct_layout_<cutlass::layout::ColumnMajorInterleaved<32>>(bytes, layout_id, extent, stride);

    case library::LayoutTypeID::kRowMajorInterleavedK32:
      return construct_layout_<cutlass::layout::RowMajorInterleaved<32>>(bytes, layout_id, extent, stride);

    case library::LayoutTypeID::kColumnMajorInterleavedK64:
      return construct_layout_<cutlass::layout::ColumnMajorInterleaved<64>>(bytes, layout_id, extent, stride);

    case library::LayoutTypeID::kRowMajorInterleavedK64:
      return construct_layout_<cutlass::layout::RowMajorInterleaved<64>>(bytes, layout_id, extent, stride);

    case library::LayoutTypeID::kTensorNCHW:
      // NCHW must be built with the NCHW layout type (it was previously dispatched to
      // TensorNHWC), so that the strides agree with what get_packed_layout() reports.
      return construct_layout_<cutlass::layout::TensorNCHW>(bytes, layout_id, extent, stride);

    case library::LayoutTypeID::kTensorNHWC:
      return construct_layout_<cutlass::layout::TensorNHWC>(bytes, layout_id, extent, stride);

    case library::LayoutTypeID::kTensorNDHWC:
      return construct_layout_<cutlass::layout::TensorNDHWC>(bytes, layout_id, extent, stride);

    case library::LayoutTypeID::kTensorNC32HW32:
      return construct_layout_<cutlass::layout::TensorNCxHWx<32>>(bytes, layout_id, extent, stride);

    case library::LayoutTypeID::kTensorNC64HW64:
      return construct_layout_<cutlass::layout::TensorNCxHWx<64>>(bytes, layout_id, extent, stride);

    case library::LayoutTypeID::kTensorC32RSK32:
      return construct_layout_<cutlass::layout::TensorCxRSKx<32>>(bytes, layout_id, extent, stride);

    case library::LayoutTypeID::kTensorC64RSK64:
      return construct_layout_<cutlass::layout::TensorCxRSKx<64>>(bytes, layout_id, extent, stride);

    default: break;
  }

  // Unhandled layout: no capacity can be computed.
  return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Constructs an empty allocation: invalid element type, unknown layout, zero capacity,
/// no device pointer, and a batch count of one.
DeviceAllocation::DeviceAllocation()
  : type_(library::NumericTypeID::kInvalid), batch_stride_(0), capacity_(0),
    pointer_(nullptr), layout_(library::LayoutTypeID::kUnknown), batch_count_(1) { }
/// Allocates device memory for `capacity` elements of numeric type `type` with no
/// associated layout.
///
/// On success the device pointer is also packed into tensor_ref_buffer_, matching what
/// reset(type, capacity) does for the same kind of allocation.
///
/// \throws std::bad_alloc if cudaMalloc fails
DeviceAllocation::DeviceAllocation(
  library::NumericTypeID type,
  size_t capacity
):
  type_(type), batch_stride_(capacity), capacity_(capacity), pointer_(nullptr),
  layout_(library::LayoutTypeID::kUnknown), batch_count_(1) {

  cudaError_t result = cudaMalloc((void **)&pointer_, bytes(type, capacity));

  if (result != cudaSuccess) {
    // Return every field to the "invalid" state before reporting the failure.
    type_ = library::NumericTypeID::kInvalid;
    batch_stride_ = 0;
    capacity_ = 0;
    pointer_ = nullptr;
    throw std::bad_alloc();
  }

  // Keep the packed pointer buffer in sync with the allocation.
  tensor_ref_buffer_.resize(sizeof(pointer_), 0);
  std::memcpy(tensor_ref_buffer_.data(), &pointer_, sizeof(pointer_));
}
/// Constructs an allocation for a (possibly batched) tensor with the given element type,
/// layout, extent and stride. Members are first put into an empty state, then all of the
/// real work is delegated to the matching reset() overload.
DeviceAllocation::DeviceAllocation(
  library::NumericTypeID type,
  library::LayoutTypeID layout_id,
  std::vector<int> const &extent,
  std::vector<int> const &stride,
  int batch_count)
  : type_(type),
    batch_stride_(size_t(0)),
    capacity_(size_t(0)),
    pointer_(nullptr),
    batch_count_(1) {

  this->reset(type, layout_id, extent, stride, batch_count);
}
/// Releases the device memory, if any is held. The result of cudaFree is not inspected.
DeviceAllocation::~DeviceAllocation() {
  if (pointer_ != nullptr) {
    cudaFree(pointer_);
  }
}
/// Frees any held device memory and returns the object to its empty state
/// (invalid type, unknown layout, zero capacity, batch count of one).
DeviceAllocation &DeviceAllocation::reset() {

  // Release the device buffer first; the result of cudaFree is not inspected.
  if (pointer_ != nullptr) {
    cudaFree(pointer_);
  }
  pointer_ = nullptr;

  // Drop all shape and layout bookkeeping.
  tensor_ref_buffer_.clear();
  extent_.clear();
  stride_.clear();
  layout_ = library::LayoutTypeID::kUnknown;

  // Scalar state.
  type_ = library::NumericTypeID::kInvalid;
  capacity_ = 0;
  batch_stride_ = 0;
  batch_count_ = 1;

  return *this;
}
/// Releases any existing allocation and allocates device memory for `capacity` elements
/// of numeric type `type`, with no associated layout and a batch count of one.
///
/// On success the device pointer is packed into tensor_ref_buffer_. On failure the object
/// is returned to its empty state, so capacity()/bytes() do not describe memory that was
/// never obtained.
///
/// \throws std::bad_alloc if cudaMalloc fails
DeviceAllocation &DeviceAllocation::reset(library::NumericTypeID type, size_t capacity) {

  // reset() already clears layout_, stride_, extent_, batch_count_ and tensor_ref_buffer_.
  reset();

  type_ = type;
  batch_stride_ = capacity;
  capacity_ = capacity;

  cudaError_t result = cudaMalloc((void **)&pointer_, bytes(type_, capacity_));
  if (result != cudaSuccess) {
    // Nothing was allocated: make sure reset() does not try to free anything, then
    // restore the empty state before reporting the failure.
    pointer_ = nullptr;
    reset();
    throw std::bad_alloc();
  }

  tensor_ref_buffer_.resize(sizeof(pointer_), 0);
  std::memcpy(tensor_ref_buffer_.data(), &pointer_, sizeof(pointer_));

  return *this;
}
/// Allocates memory for a given layout and tensor.
///
/// Any existing allocation is released first. The layout object is constructed into
/// tensor_ref_buffer_ just after the space reserved for the device pointer, and the
/// capacity returned by construct_layout() is used as the per-batch stride.
///
/// \param type         element type
/// \param layout_id    layout of each tensor in the batch
/// \param extent       tensor extent
/// \param stride       tensor stride; when empty, construct_layout() substitutes the packed stride
/// \param batch_count  number of tensors stored back to back
/// \throws std::runtime_error if construct_layout() rejects the extent/stride
/// \throws std::bad_alloc     if cudaMalloc fails
///
/// If either exception is thrown, the object is returned to its empty state first.
DeviceAllocation &DeviceAllocation::reset(
  library::NumericTypeID type,
  library::LayoutTypeID layout_id,
  std::vector<int> const &extent,
  std::vector<int> const &stride,
  int batch_count) {

  reset();

  // Room for the device pointer followed by one int per stride dimension of the layout.
  tensor_ref_buffer_.resize(sizeof(pointer_) + (sizeof(int) * library::get_layout_stride_rank(layout_id)), 0);

  type_ = type;
  layout_ = layout_id;
  stride_ = stride;
  extent_ = extent;
  batch_count_ = batch_count;

  try {
    batch_stride_ = construct_layout(
      tensor_ref_buffer_.data() + sizeof(pointer_),
      layout_id,
      extent,
      stride_);
  }
  catch (...) {
    // Do not leave a half-described tensor behind when the layout is rejected.
    reset();
    throw;
  }

  capacity_ = batch_stride_ * batch_count_;

  cudaError_t result = cudaMalloc((void **)&pointer_, bytes(type, capacity_));
  if (result != cudaSuccess) {
    // Nothing was allocated: make sure reset() does not try to free anything, then
    // restore the empty state before reporting the failure.
    pointer_ = nullptr;
    reset();
    throw std::bad_alloc();
  }

  // Publish the device pointer at the front of the packed buffer.
  std::memcpy(tensor_ref_buffer_.data(), &pointer_, sizeof(pointer_));

  return *this;
}
/// Returns true when the allocation has both a non-zero capacity and a non-null pointer.
bool DeviceAllocation::good() const {
  return capacity_ != 0 && pointer_ != nullptr;
}
/// Gets the numeric type of the elements in this allocation.
library::NumericTypeID DeviceAllocation::type() const {
  return this->type_;
}
void *DeviceAllocation::data() const {
return pointer_;
}
/// Gets the pointer to the start of batch item `batch_idx`, i.e. data() advanced by
/// `batch_idx` times the batch stride in bytes.
void *DeviceAllocation::batch_data(int batch_idx) const {
  char *base = static_cast<char *>(data());
  return base + batch_stride_bytes() * batch_idx;
}
/// Gets the layout identifier recorded for this allocation.
library::LayoutTypeID DeviceAllocation::layout() const {
  return this->layout_;
}
std::vector<int> const & DeviceAllocation::stride() const {
return stride_;
}
/// Gets the extent vector
std::vector<int> const & DeviceAllocation::extent() const {
return extent_;
}
/// Gets the number of adjacent tensors in memory
int DeviceAllocation::batch_count() const {
return batch_count_;
}
/// Gets the stride (in units of elements) beteween items
int64_t DeviceAllocation::batch_stride() const {
return batch_stride_;
}
/// Gets the stride (in units of bytes) between consecutive batch items.
int64_t DeviceAllocation::batch_stride_bytes() const {
  return DeviceAllocation::bytes(type_, batch_stride_);
}
size_t DeviceAllocation::capacity() const {
return capacity_;
}
/// Gets the size of the whole allocation in bytes, derived from its type and capacity.
size_t DeviceAllocation::bytes() const {
  return DeviceAllocation::bytes(type_, capacity_);
}
/// Copies from an equivalent-sized tensor in device memory
void DeviceAllocation::copy_from_device(void const *ptr) {
cudaError_t result = cudaMemcpy(data(), ptr, bytes(), cudaMemcpyDeviceToDevice);
if (result != cudaSuccess) {
throw std::runtime_error("Failed device-to-device copy");
}
}
/// Copies from an equivalent-sized tensor in host memory: bytes() bytes are transferred
/// from the host address `ptr` into this allocation.
/// \throws std::runtime_error if cudaMemcpy reports an error
void DeviceAllocation::copy_from_host(void const *ptr) {
  cudaError_t result = cudaMemcpy(data(), ptr, bytes(), cudaMemcpyHostToDevice);
  if (result != cudaSuccess) {
    // The message names the actual transfer direction.
    throw std::runtime_error("Failed host-to-device copy");
  }
}
/// Copies to an equivalent-sized tensor in host memory: bytes() bytes are transferred
/// from this allocation to the host address `ptr`.
/// \throws std::runtime_error if cudaMemcpy reports an error
void DeviceAllocation::copy_to_host(void *ptr) {
  cudaError_t result = cudaMemcpy(ptr, data(), bytes(), cudaMemcpyDeviceToHost);
  if (result != cudaSuccess) {
    // The message names the actual transfer direction.
    throw std::runtime_error("Failed device-to-host copy");
  }
}
void DeviceAllocation::initialize_random_device(int seed, Distribution dist) {
if (!good()) {
throw std::runtime_error("Attempting to initialize invalid allocation.");
}
// Instantiate calls to CURAND here. This file takes a long time to compile for
// this reason.
switch (type_) {
case library::NumericTypeID::kF16:
cutlass::reference::device::BlockFillRandom<cutlass::half_t>(
reinterpret_cast<cutlass::half_t *>(pointer_),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kBF16:
cutlass::reference::device::BlockFillRandom<cutlass::bfloat16_t>(
reinterpret_cast<cutlass::bfloat16_t *>(pointer_),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kTF32:
cutlass::reference::device::BlockFillRandom<cutlass::tfloat32_t>(
reinterpret_cast<cutlass::tfloat32_t *>(pointer_),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kF32:
cutlass::reference::device::BlockFillRandom<float>(
reinterpret_cast<float *>(pointer_),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kCBF16:
cutlass::reference::device::BlockFillRandom<complex<bfloat16_t>>(
reinterpret_cast<complex<bfloat16_t> *>(pointer_),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kCTF32:
cutlass::reference::device::BlockFillRandom<cutlass::complex<cutlass::tfloat32_t>>(
reinterpret_cast<cutlass::complex<cutlass::tfloat32_t> *>(pointer_),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kCF32:
cutlass::reference::device::BlockFillRandom<cutlass::complex<float>>(
reinterpret_cast<cutlass::complex<float> *>(pointer_),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kF64:
cutlass::reference::device::BlockFillRandom<double>(
reinterpret_cast<double *>(pointer_),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kCF64:
cutlass::reference::device::BlockFillRandom<complex<double>>(
reinterpret_cast<complex<double> *>(pointer_),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kS2:
cutlass::reference::device::BlockFillRandom<int2b_t>(
reinterpret_cast<int2b_t *>(pointer_),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kS4:
cutlass::reference::device::BlockFillRandom<int4b_t>(
reinterpret_cast<int4b_t *>(pointer_),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kS8:
cutlass::reference::device::BlockFillRandom<int8_t>(
reinterpret_cast<int8_t *>(pointer_),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kS16:
cutlass::reference::device::BlockFillRandom<int16_t>(
reinterpret_cast<int16_t *>(pointer_),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kS32:
cutlass::reference::device::BlockFillRandom<int32_t>(
reinterpret_cast<int32_t *>(pointer_),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kS64:
cutlass::reference::device::BlockFillRandom<int64_t>(
reinterpret_cast<int64_t *>(pointer_),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kB1:
cutlass::reference::device::BlockFillRandom<uint1b_t>(
reinterpret_cast<uint1b_t *>(pointer_),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kU2:
cutlass::reference::device::BlockFillRandom<uint2b_t>(
reinterpret_cast<uint2b_t *>(pointer_),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kU4:
cutlass::reference::device::BlockFillRandom<uint4b_t>(
reinterpret_cast<uint4b_t *>(pointer_),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kU8:
cutlass::reference::device::BlockFillRandom<uint8_t>(
reinterpret_cast<uint8_t *>(pointer_),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kU16:
cutlass::reference::device::BlockFillRandom<uint16_t>(
reinterpret_cast<uint16_t *>(pointer_),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kU32:
cutlass::reference::device::BlockFillRandom<uint32_t>(
reinterpret_cast<uint32_t *>(pointer_),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kU64:
cutlass::reference::device::BlockFillRandom<uint64_t>(
reinterpret_cast<uint64_t *>(pointer_),
capacity_,
seed,
dist
);
break;
default: break;
}
}
void DeviceAllocation::initialize_random_host(int seed, Distribution dist) {
if (!good()) {
throw std::runtime_error("Attempting to initialize invalid allocation.");
}
std::vector<uint8_t> host_data(bytes());
switch (type_) {
case library::NumericTypeID::kF16:
cutlass::reference::host::BlockFillRandom<cutlass::half_t>(
reinterpret_cast<cutlass::half_t *>(host_data.data()),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kBF16:
cutlass::reference::host::BlockFillRandom<cutlass::bfloat16_t>(
reinterpret_cast<cutlass::bfloat16_t *>(host_data.data()),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kTF32:
cutlass::reference::host::BlockFillRandom<cutlass::tfloat32_t>(
reinterpret_cast<cutlass::tfloat32_t *>(host_data.data()),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kF32:
cutlass::reference::host::BlockFillRandom<float>(
reinterpret_cast<float *>(host_data.data()),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kCF16:
cutlass::reference::host::BlockFillRandom<cutlass::complex<cutlass::half_t>>(
reinterpret_cast<cutlass::complex<cutlass::half_t> *>(host_data.data()),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kCBF16:
cutlass::reference::host::BlockFillRandom<cutlass::complex<cutlass::bfloat16_t>>(
reinterpret_cast<cutlass::complex<cutlass::bfloat16_t> *>(host_data.data()),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kCTF32:
cutlass::reference::host::BlockFillRandom<cutlass::complex<cutlass::tfloat32_t>>(
reinterpret_cast<cutlass::complex<cutlass::tfloat32_t> *>(host_data.data()),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kCF32:
cutlass::reference::host::BlockFillRandom<cutlass::complex<float>>(
reinterpret_cast<cutlass::complex<float> *>(host_data.data()),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kF64:
cutlass::reference::host::BlockFillRandom<double>(
reinterpret_cast<double *>(host_data.data()),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kCF64:
cutlass::reference::host::BlockFillRandom<cutlass::complex<double>>(
reinterpret_cast<cutlass::complex<double> *>(host_data.data()),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kS2:
cutlass::reference::host::BlockFillRandom<int2b_t>(
reinterpret_cast<int2b_t *>(host_data.data()),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kS4:
cutlass::reference::host::BlockFillRandom<int4b_t>(
reinterpret_cast<int4b_t *>(host_data.data()),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kS8:
cutlass::reference::host::BlockFillRandom<int8_t>(
reinterpret_cast<int8_t *>(host_data.data()),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kS16:
cutlass::reference::host::BlockFillRandom<int16_t>(
reinterpret_cast<int16_t *>(host_data.data()),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kS32:
cutlass::reference::host::BlockFillRandom<int32_t>(
reinterpret_cast<int32_t *>(host_data.data()),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kS64:
cutlass::reference::host::BlockFillRandom<int64_t>(
reinterpret_cast<int64_t *>(host_data.data()),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kB1:
cutlass::reference::host::BlockFillRandom<uint1b_t>(
reinterpret_cast<uint1b_t *>(host_data.data()),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kU2:
cutlass::reference::host::BlockFillRandom<uint2b_t>(
reinterpret_cast<uint2b_t *>(host_data.data()),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kU4:
cutlass::reference::host::BlockFillRandom<uint4b_t>(
reinterpret_cast<uint4b_t *>(host_data.data()),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kU8:
cutlass::reference::host::BlockFillRandom<uint8_t>(
reinterpret_cast<uint8_t *>(host_data.data()),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kU16:
cutlass::reference::host::BlockFillRandom<uint16_t>(
reinterpret_cast<uint16_t *>(host_data.data()),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kU32:
cutlass::reference::host::BlockFillRandom<uint32_t>(
reinterpret_cast<uint32_t *>(host_data.data()),
capacity_,
seed,
dist
);
break;
case library::NumericTypeID::kU64:
cutlass::reference::host::BlockFillRandom<uint64_t>(
reinterpret_cast<uint64_t *>(host_data.data()),
capacity_,
seed,
dist
);
break;
default: break;
}
copy_from_host(host_data.data());
}
/// Fills the allocation with random sparse metadata using the device-side routine.
/// Only 16-bit and 32-bit unsigned element types are handled; other types are left untouched.
///
/// \param seed            seed forwarded to the fill routine
/// \param MetaSizeInBits  metadata size, in bits, forwarded to the fill routine
/// \throws std::runtime_error if the allocation is not good()
void DeviceAllocation::initialize_random_sparsemeta_device(int seed, int MetaSizeInBits) {
  if (!good()) {
    throw std::runtime_error("Attempting to initialize invalid allocation.");
  }

  // Instantiate calls to CURAND here. This file takes a long time to compile for
  // this reason.

  if (type_ == library::NumericTypeID::kU16) {
    uint16_t *meta = reinterpret_cast<uint16_t *>(pointer_);
    cutlass::reference::device::BlockFillRandomSparseMeta<uint16_t>(
      meta, capacity_, seed, MetaSizeInBits);
  }
  else if (type_ == library::NumericTypeID::kU32) {
    uint32_t *meta = reinterpret_cast<uint32_t *>(pointer_);
    cutlass::reference::device::BlockFillRandomSparseMeta<uint32_t>(
      meta, capacity_, seed, MetaSizeInBits);
  }
}
/// Fills a host staging buffer with random sparse metadata and copies it into the allocation.
///
/// The metadata words are generated as uint16_t / uint32_t. The device-side counterpart
/// dispatches on kU16 / kU32, so those IDs are accepted here as well; previously only
/// kS16 / kS32 were matched, which meant unsigned metadata allocations were uploaded
/// as an all-zero buffer. The signed IDs remain accepted for existing callers.
///
/// \param seed            seed forwarded to the fill routine
/// \param MetaSizeInBits  metadata size, in bits, forwarded to the fill routine
/// \throws std::runtime_error if the allocation is not good(), or if the upload fails
void DeviceAllocation::initialize_random_sparsemeta_host(int seed, int MetaSizeInBits) {
  if (!good()) {
    throw std::runtime_error("Attempting to initialize invalid allocation.");
  }

  std::vector<uint8_t> host_data(bytes());

  switch (type_) {
  case library::NumericTypeID::kS16:  // kept for backward compatibility
  case library::NumericTypeID::kU16:
    cutlass::reference::host::BlockFillRandomSparseMeta<uint16_t>(
      reinterpret_cast<uint16_t *>(host_data.data()),
      capacity_,
      seed,
      MetaSizeInBits
    );
    break;
  case library::NumericTypeID::kS32:  // kept for backward compatibility
  case library::NumericTypeID::kU32:
    cutlass::reference::host::BlockFillRandomSparseMeta<uint32_t>(
      reinterpret_cast<uint32_t *>(host_data.data()),
      capacity_,
      seed,
      MetaSizeInBits
    );
    break;
  default:
    break;
  }

  // Upload the staged metadata to the device allocation.
  copy_from_host(host_data.data());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Returns true if two blocks have exactly the same value
bool DeviceAllocation::block_compare_equal(
library::NumericTypeID numeric_type,
void const *ptr_A,
void const *ptr_B,
size_t capacity) {
switch (numeric_type) {
case library::NumericTypeID::kF16:
return reference::device::BlockCompareEqual<half_t>(
reinterpret_cast<half_t const *>(ptr_A),
reinterpret_cast<half_t const *>(ptr_B),
capacity);
case library::NumericTypeID::kBF16:
return reference::device::BlockCompareEqual<bfloat16_t>(
reinterpret_cast<bfloat16_t const *>(ptr_A),
reinterpret_cast<bfloat16_t const *>(ptr_B),
capacity);
case library::NumericTypeID::kTF32:
return reference::device::BlockCompareEqual<tfloat32_t>(
reinterpret_cast<tfloat32_t const *>(ptr_A),
reinterpret_cast<tfloat32_t const *>(ptr_B),
capacity);
case library::NumericTypeID::kF32:
return reference::device::BlockCompareEqual<float>(
reinterpret_cast<float const *>(ptr_A),
reinterpret_cast<float const *>(ptr_B),
capacity);
case library::NumericTypeID::kCF32:
return reference::device::BlockCompareEqual<cutlass::complex<float> >(
reinterpret_cast<complex<float> const *>(ptr_A),
reinterpret_cast<complex<float> const *>(ptr_B),
capacity);
case library::NumericTypeID::kCF16:
return reference::device::BlockCompareEqual<complex<half_t>>(
reinterpret_cast<complex<half_t> const *>(ptr_A),
reinterpret_cast<complex<half_t> const *>(ptr_B),
capacity);
case library::NumericTypeID::kCBF16:
return reference::device::BlockCompareEqual<complex<bfloat16_t>>(
reinterpret_cast<complex<bfloat16_t> const *>(ptr_A),
reinterpret_cast<complex<bfloat16_t> const *>(ptr_B),
capacity);
case library::NumericTypeID::kCTF32:
return reference::device::BlockCompareEqual<complex<tfloat32_t>>(
reinterpret_cast<complex<tfloat32_t> const *>(ptr_A),
reinterpret_cast<complex<tfloat32_t> const *>(ptr_B),
capacity);
case library::NumericTypeID::kF64:
return reference::device::BlockCompareEqual<double>(
reinterpret_cast<double const *>(ptr_A),
reinterpret_cast<double const *>(ptr_B),
capacity);
case library::NumericTypeID::kCF64:
return reference::device::BlockCompareEqual<complex<double>>(
reinterpret_cast<complex<double> const *>(ptr_A),
reinterpret_cast<complex<double> const *>(ptr_B),
capacity);
case library::NumericTypeID::kS2:
return reference::device::BlockCompareEqual<int2b_t>(
reinterpret_cast<int2b_t const *>(ptr_A),
reinterpret_cast<int2b_t const *>(ptr_B),
capacity);
case library::NumericTypeID::kS4:
return reference::device::BlockCompareEqual<int4b_t>(
reinterpret_cast<int4b_t const *>(ptr_A),
reinterpret_cast<int4b_t const *>(ptr_B),
capacity);
case library::NumericTypeID::kS8:
return reference::device::BlockCompareEqual<int8_t>(
reinterpret_cast<int8_t const *>(ptr_A),
reinterpret_cast<int8_t const *>(ptr_B),
capacity);
case library::NumericTypeID::kS16:
return reference::device::BlockCompareEqual<int16_t>(
reinterpret_cast<int16_t const *>(ptr_A),
reinterpret_cast<int16_t const *>(ptr_B),
capacity);
case library::NumericTypeID::kS32:
return reference::device::BlockCompareEqual<int32_t>(
reinterpret_cast<int32_t const *>(ptr_A),
reinterpret_cast<int32_t const *>(ptr_B),
capacity);
case library::NumericTypeID::kS64:
return reference::device::BlockCompareEqual<int64_t>(
reinterpret_cast<int64_t const *>(ptr_A),
reinterpret_cast<int64_t const *>(ptr_B),
capacity);
case library::NumericTypeID::kB1:
return reference::device::BlockCompareEqual<uint1b_t>(
reinterpret_cast<uint1b_t const *>(ptr_A),
reinterpret_cast<uint1b_t const *>(ptr_B),
capacity);
case library::NumericTypeID::kU2:
return reference::device::BlockCompareEqual<uint2b_t>(
reinterpret_cast<uint2b_t const *>(ptr_A),
reinterpret_cast<uint2b_t const *>(ptr_B),
capacity);
case library::NumericTypeID::kU4:
return reference::device::BlockCompareEqual<uint4b_t>(
reinterpret_cast<uint4b_t const *>(ptr_A),
reinterpret_cast<uint4b_t const *>(ptr_B),
capacity);
case library::NumericTypeID::kU8:
return reference::device::BlockCompareEqual<uint8_t>(
reinterpret_cast<uint8_t const *>(ptr_A),
reinterpret_cast<uint8_t const *>(ptr_B),
capacity);
case library::NumericTypeID::kU16:
return reference::device::BlockCompareEqual<uint16_t>(
reinterpret_cast<uint16_t const *>(ptr_A),
reinterpret_cast<uint16_t const *>(ptr_B),
capacity);
case library::NumericTypeID::kU32:
return reference::device::BlockCompareEqual<uint32_t>(
reinterpret_cast<uint32_t const *>(ptr_A),
reinterpret_cast<uint32_t const *>(ptr_B),
capacity);
case library::NumericTypeID::kU64:
return reference::device::BlockCompareEqual<uint64_t>(
reinterpret_cast<uint64_t const *>(ptr_A),
reinterpret_cast<uint64_t const *>(ptr_B),
capacity);
default:
throw std::runtime_error("Unsupported numeric type");
}
}
/// Returns true if two blocks have approximately the same value
bool DeviceAllocation::block_compare_relatively_equal(
library::NumericTypeID numeric_type,
void const *ptr_A,
void const *ptr_B,
size_t capacity,
double epsilon,
double nonzero_floor) {
switch (numeric_type) {
case library::NumericTypeID::kF16:
return reference::device::BlockCompareRelativelyEqual<half_t>(
reinterpret_cast<half_t const *>(ptr_A),
reinterpret_cast<half_t const *>(ptr_B),
capacity,
static_cast<half_t>(epsilon),
static_cast<half_t>(nonzero_floor));
case library::NumericTypeID::kBF16:
return reference::device::BlockCompareRelativelyEqual<bfloat16_t>(
reinterpret_cast<bfloat16_t const *>(ptr_A),
reinterpret_cast<bfloat16_t const *>(ptr_B),
capacity,
static_cast<bfloat16_t>(epsilon),
static_cast<bfloat16_t>(nonzero_floor));
case library::NumericTypeID::kTF32:
return reference::device::BlockCompareRelativelyEqual<tfloat32_t>(
reinterpret_cast<tfloat32_t const *>(ptr_A),
reinterpret_cast<tfloat32_t const *>(ptr_B),
capacity,
static_cast<tfloat32_t>(epsilon),
static_cast<tfloat32_t>(nonzero_floor));
case library::NumericTypeID::kF32:
return reference::device::BlockCompareRelativelyEqual<float>(
reinterpret_cast<float const *>(ptr_A),
reinterpret_cast<float const *>(ptr_B),
capacity,
static_cast<float>(epsilon),
static_cast<float>(nonzero_floor));
case library::NumericTypeID::kF64:
return reference::device::BlockCompareRelativelyEqual<double>(
reinterpret_cast<double const *>(ptr_A),
reinterpret_cast<double const *>(ptr_B),
capacity,
static_cast<double>(epsilon),
static_cast<double>(nonzero_floor));
case library::NumericTypeID::kS2:
return reference::device::BlockCompareRelativelyEqual<int2b_t>(
reinterpret_cast<int2b_t const *>(ptr_A),
reinterpret_cast<int2b_t const *>(ptr_B),
capacity,
static_cast<int2b_t>(epsilon),
static_cast<int2b_t>(nonzero_floor));
case library::NumericTypeID::kS4:
return reference::device::BlockCompareRelativelyEqual<int4b_t>(
reinterpret_cast<int4b_t const *>(ptr_A),
reinterpret_cast<int4b_t const *>(ptr_B),
capacity,
static_cast<int4b_t>(epsilon),
static_cast<int4b_t>(nonzero_floor));
case library::NumericTypeID::kS8:
return reference::device::BlockCompareRelativelyEqual<int8_t>(
reinterpret_cast<int8_t const *>(ptr_A),
reinterpret_cast<int8_t const *>(ptr_B),
capacity,
static_cast<int8_t>(epsilon),
static_cast<int8_t>(nonzero_floor));
case library::NumericTypeID::kS16:
return reference::device::BlockCompareRelativelyEqual<int16_t>(
reinterpret_cast<int16_t const *>(ptr_A),
reinterpret_cast<int16_t const *>(ptr_B),
capacity,
static_cast<int16_t>(epsilon),
static_cast<int16_t>(nonzero_floor));
case library::NumericTypeID::kS32:
return reference::device::BlockCompareRelativelyEqual<int32_t>(
reinterpret_cast<int32_t const *>(ptr_A),
reinterpret_cast<int32_t const *>(ptr_B),
capacity,
static_cast<int32_t>(epsilon),
static_cast<int32_t>(nonzero_floor));
case library::NumericTypeID::kS64:
return reference::device::BlockCompareRelativelyEqual<int64_t>(
reinterpret_cast<int64_t const *>(ptr_A),
reinterpret_cast<int64_t const *>(ptr_B),
capacity,
static_cast<int64_t>(epsilon),
static_cast<int64_t>(nonzero_floor));
case library::NumericTypeID::kB1:
return reference::device::BlockCompareRelativelyEqual<uint1b_t>(
reinterpret_cast<uint1b_t const *>(ptr_A),
reinterpret_cast<uint1b_t const *>(ptr_B),
capacity,
static_cast<uint1b_t>(epsilon),
static_cast<uint1b_t>(nonzero_floor));
case library::NumericTypeID::kU2:
return reference::device::BlockCompareRelativelyEqual<uint2b_t>(
reinterpret_cast<uint2b_t const *>(ptr_A),
reinterpret_cast<uint2b_t const *>(ptr_B),
capacity,
static_cast<uint2b_t>(epsilon),
static_cast<uint2b_t>(nonzero_floor));
case library::NumericTypeID::kU4:
return reference::device::BlockCompareRelativelyEqual<uint4b_t>(
reinterpret_cast<uint4b_t const *>(ptr_A),
reinterpret_cast<uint4b_t const *>(ptr_B),
capacity,
static_cast<uint4b_t>(epsilon),
static_cast<uint4b_t>(nonzero_floor));
case library::NumericTypeID::kU8:
return reference::device::BlockCompareRelativelyEqual<uint8_t>(
reinterpret_cast<uint8_t const *>(ptr_A),
reinterpret_cast<uint8_t const *>(ptr_B),
capacity,
static_cast<uint8_t>(epsilon),
static_cast<uint8_t>(nonzero_floor));
case library::NumericTypeID::kU16:
return reference::device::BlockCompareRelativelyEqual<uint16_t>(
reinterpret_cast<uint16_t const *>(ptr_A),
reinterpret_cast<uint16_t const *>(ptr_B),
capacity,
static_cast<uint16_t>(epsilon),
static_cast<uint16_t>(nonzero_floor));
case library::NumericTypeID::kU32:
return reference::device::BlockCompareRelativelyEqual<uint32_t>(
reinterpret_cast<uint32_t const *>(ptr_A),
reinterpret_cast<uint32_t const *>(ptr_B),
capacity,
static_cast<uint32_t>(epsilon),
static_cast<uint32_t>(nonzero_floor));
case library::NumericTypeID::kU64:
return reference::device::BlockCompareRelativelyEqual<uint64_t>(
reinterpret_cast<uint64_t const *>(ptr_A),
reinterpret_cast<uint64_t const *>(ptr_B),
capacity,
static_cast<uint64_t>(epsilon),
static_cast<uint64_t>(nonzero_floor));
// No relatively equal comparison for complex numbers.
//
// As a simplification, we can require bitwise equality. This avoids false positives.
// (i.e. "pass" really means passing. "Fail" may not actually mean failure given appropriate epsilon.)
//
case library::NumericTypeID::kCF16:
return reference::device::BlockCompareEqual<cutlass::complex<half_t> >(
reinterpret_cast<complex<half_t> const *>(ptr_A),
reinterpret_cast<complex<half_t> const *>(ptr_B),
capacity);
case library::NumericTypeID::kCF32:
return reference::device::BlockCompareEqual<cutlass::complex<float> >(
reinterpret_cast<complex<float> const *>(ptr_A),
reinterpret_cast<complex<float> const *>(ptr_B),
capacity);
case library::NumericTypeID::kCF64:
return reference::device::BlockCompareEqual<cutlass::complex<double> >(
reinterpret_cast<complex<double> const *>(ptr_A),
reinterpret_cast<complex<double> const *>(ptr_B),
capacity);
default:
{
throw std::runtime_error(std::string("Unsupported numeric type: ") + to_string(numeric_type));
}
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Copies the first `Rank` entries of a dynamically sized vector into a
/// statically sized coordinate.
///
/// Construction performs the copy: element `Rank - 1` is written first, then the
/// remaining lower-indexed elements are filled by constructing the next smaller
/// rank. Reads use `std::vector::at`, so a vector with fewer than `Rank` entries
/// raises `std::out_of_range`.
template <typename TensorCoord, int Rank>
struct vector_to_coord {

  vector_to_coord(TensorCoord &dst, std::vector<int> const &src) {

    int const last = Rank - 1;

    dst[last] = src.at(last);

    if (last > 0) {
      typedef vector_to_coord<TensorCoord, Rank - 1> Remaining;

      // Constructing the smaller-rank helper fills indices [0, last).
      Remaining fill_lower(dst, src);
      (void)fill_lower;
    }
  }
};

/// Rank-1 case: copies only element 0 and ends the recursion.
template <typename TensorCoord>
struct vector_to_coord<TensorCoord, 1> {

  vector_to_coord(TensorCoord &dst, std::vector<int> const &src) {
    dst[0] = src.at(0);
  }
};

/// Rank-0 case: there is nothing to copy, so both arguments are left untouched.
template <typename TensorCoord>
struct vector_to_coord<TensorCoord, 0> {

  vector_to_coord(TensorCoord &, std::vector<int> const &) { }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Element, typename Layout>
static void write_tensor_csv_static_tensor_view(
std::ostream &out,
DeviceAllocation &allocation) {
Coord<Layout::kRank> extent;
Coord<Layout::kStrideRank> stride;
if (allocation.extent().size() != Layout::kRank) {
throw std::runtime_error("Allocation extent has invalid rank");
}
if (allocation.stride().size() != Layout::kStrideRank) {
throw std::runtime_error("Allocation stride has invalid rank");
}
vector_to_coord<Coord<Layout::kRank>, Layout::kRank>(extent, allocation.extent());
vector_to_coord<Coord<Layout::kStrideRank>, Layout::kStrideRank>(stride, allocation.stride());
Layout layout(stride);
HostTensor<Element, Layout> host_tensor(extent, layout, false);
if (host_tensor.capacity() != allocation.batch_stride()) {
throw std::runtime_error("Unexpected capacity to equal.");
}
host_tensor.copy_in_device_to_host(
static_cast<Element const *>(allocation.data()),
allocation.batch_stride());
TensorViewWrite(out, host_tensor.host_view());
out << "\n\n";
}
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T>
static void write_tensor_csv_static_type(
std::ostream &out,
DeviceAllocation &allocation) {
switch (allocation.layout()) {
case library::LayoutTypeID::kRowMajor:
write_tensor_csv_static_tensor_view<T, layout::RowMajor>(out, allocation);
break;
case library::LayoutTypeID::kColumnMajor:
write_tensor_csv_static_tensor_view<T, layout::ColumnMajor>(out, allocation);
break;
case library::LayoutTypeID::kRowMajorInterleavedK2:
write_tensor_csv_static_tensor_view<T, layout::RowMajorInterleaved<2>>(out, allocation);
break;
case library::LayoutTypeID::kColumnMajorInterleavedK2:
write_tensor_csv_static_tensor_view<T, layout::ColumnMajorInterleaved<2>>(out, allocation);
break;
case library::LayoutTypeID::kRowMajorInterleavedK4:
write_tensor_csv_static_tensor_view<T, layout::RowMajorInterleaved<4>>(out, allocation);
break;
case library::LayoutTypeID::kColumnMajorInterleavedK4:
write_tensor_csv_static_tensor_view<T, layout::ColumnMajorInterleaved<4>>(out, allocation);
break;
case library::LayoutTypeID::kRowMajorInterleavedK16:
write_tensor_csv_static_tensor_view<T, layout::RowMajorInterleaved<16>>(out, allocation);
break;
case library::LayoutTypeID::kColumnMajorInterleavedK16:
write_tensor_csv_static_tensor_view<T, layout::ColumnMajorInterleaved<16>>(out, allocation);
break;
case library::LayoutTypeID::kRowMajorInterleavedK32:
write_tensor_csv_static_tensor_view<T, layout::RowMajorInterleaved<32>>(out, allocation);
break;
case library::LayoutTypeID::kColumnMajorInterleavedK32:
write_tensor_csv_static_tensor_view<T, layout::ColumnMajorInterleaved<32>>(out, allocation);
break;
case library::LayoutTypeID::kRowMajorInterleavedK64:
write_tensor_csv_static_tensor_view<T, layout::RowMajorInterleaved<64>>(out, allocation);
break;
case library::LayoutTypeID::kColumnMajorInterleavedK64:
write_tensor_csv_static_tensor_view<T, layout::ColumnMajorInterleaved<64>>(out, allocation);
break;
case library::LayoutTypeID::kTensorNHWC:
write_tensor_csv_static_tensor_view<T, layout::TensorNHWC>(out, allocation);
break;
case library::LayoutTypeID::kTensorNDHWC:
write_tensor_csv_static_tensor_view<T, layout::TensorNDHWC>(out, allocation);
break;
case library::LayoutTypeID::kTensorNC32HW32:
write_tensor_csv_static_tensor_view<T, layout::TensorNCxHWx<32>>(out, allocation);
break;
case library::LayoutTypeID::kTensorNC64HW64:
write_tensor_csv_static_tensor_view<T, layout::TensorNCxHWx<64>>(out, allocation);
break;
case library::LayoutTypeID::kTensorC32RSK32:
write_tensor_csv_static_tensor_view<T, layout::TensorCxRSKx<32>>(out, allocation);
break;
case library::LayoutTypeID::kTensorC64RSK64:
write_tensor_csv_static_tensor_view<T, layout::TensorCxRSKx<64>>(out, allocation);
break;
default:
throw std::runtime_error("Unhandled layout");
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Writes a tensor to csv
void DeviceAllocation::write_tensor_csv(
std::ostream &out) {
switch (this->type()) {
case library::NumericTypeID::kF16:
write_tensor_csv_static_type<half_t>(out, *this);
break;
case library::NumericTypeID::kBF16:
write_tensor_csv_static_type<bfloat16_t>(out, *this);
break;
case library::NumericTypeID::kTF32:
write_tensor_csv_static_type<tfloat32_t>(out, *this);
break;
case library::NumericTypeID::kF32:
write_tensor_csv_static_type<float>(out, *this);
break;
case library::NumericTypeID::kF64:
write_tensor_csv_static_type<double>(out, *this);
break;
case library::NumericTypeID::kS2:
write_tensor_csv_static_type<int2b_t>(out, *this);
break;
case library::NumericTypeID::kS4:
write_tensor_csv_static_type<int4b_t>(out, *this);
break;
case library::NumericTypeID::kS8:
write_tensor_csv_static_type<int8_t>(out, *this);
break;
case library::NumericTypeID::kS16:
write_tensor_csv_static_type<int16_t>(out, *this);
break;
case library::NumericTypeID::kS32:
write_tensor_csv_static_type<int32_t>(out, *this);
break;
case library::NumericTypeID::kS64:
write_tensor_csv_static_type<int64_t>(out, *this);
break;
case library::NumericTypeID::kB1:
write_tensor_csv_static_type<uint1b_t>(out, *this);
break;
case library::NumericTypeID::kU2:
write_tensor_csv_static_type<uint2b_t>(out, *this);
break;
case library::NumericTypeID::kU4:
write_tensor_csv_static_type<uint4b_t>(out, *this);
break;
case library::NumericTypeID::kU8:
write_tensor_csv_static_type<uint8_t>(out, *this);
break;
case library::NumericTypeID::kU16:
write_tensor_csv_static_type<uint16_t>(out, *this);
break;
case library::NumericTypeID::kU32:
write_tensor_csv_static_type<uint32_t>(out, *this);
break;
case library::NumericTypeID::kU64:
write_tensor_csv_static_type<uint64_t>(out, *this);
break;
case library::NumericTypeID::kCF16:
write_tensor_csv_static_type<cutlass::complex<half_t> >(out, *this);
break;
case library::NumericTypeID::kCF32:
write_tensor_csv_static_type<cutlass::complex<float> >(out, *this);
break;
case library::NumericTypeID::kCF64:
write_tensor_csv_static_type<cutlass::complex<double> >(out, *this);
break;
default:
throw std::runtime_error("Unsupported numeric type");
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace profiler
} // namespace cutlass