/***************************************************************************************************
 * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/* \file
   \brief Execution environment
*/

#include <cstring>

#include "cutlass/numeric_types.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/tensor.h"

#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"

#include "cutlass/library/util.h"

#include "device_allocation.h"

namespace cutlass {
namespace profiler {

/////////////////////////////////////////////////////////////////////////////////////////////////

size_t DeviceAllocation::bytes(library::NumericTypeID type, size_t capacity) {
  return size_t(cutlass::library::sizeof_bits(type)) * capacity / 8;
}
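
// The size is computed in bits first so that sub-byte element types pack
// densely. As a sketch: bytes(kS4, 16) = 4 * 16 / 8 = 8 bytes, while
// bytes(kF16, 16) = 16 * 16 / 8 = 32 bytes. Bit counts that are not a
// multiple of 8 round down, so callers are expected to keep sub-byte
// capacities byte-aligned.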

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Layout>
static std::vector<int64_t> get_packed_layout_stride(std::vector<int> const &extent) {

  typename Layout::TensorCoord extent_coord;
  typename Layout::Stride stride_coord;

  if (extent.size() != size_t(Layout::kRank)) {
    throw std::runtime_error("Layout does not have same rank as extent vector.");
  }

  for (int i = 0; i < Layout::kRank; ++i) {
    extent_coord[i] = extent.at(i);
  }

  std::vector<int64_t> stride;
  stride.resize(Layout::kStrideRank, 0);

  Layout layout = Layout::packed(extent_coord);
  stride_coord = layout.stride();

  for (int i = 0; i < Layout::kStrideRank; ++i) {
    stride.at(i) = (int64_t)stride_coord[i];
  }

  return stride;
}
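
// Example (a sketch): for cutlass::layout::ColumnMajor with extent
// {M, N} = {128, 64}, Layout::packed() produces a leading dimension equal to
// the number of rows, so the returned stride vector is {128}; for RowMajor
// with the same extent it would be {64}.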

/// Returns the stride of a packed layout
std::vector<int64_t> DeviceAllocation::get_packed_layout(
  library::LayoutTypeID layout_id,
  std::vector<int> const &extent) {

  std::vector<int64_t> stride;

  switch (layout_id) {
    case library::LayoutTypeID::kColumnMajor:
      stride = get_packed_layout_stride<cutlass::layout::ColumnMajor>(extent);
      break;
    case library::LayoutTypeID::kRowMajor:
      stride = get_packed_layout_stride<cutlass::layout::RowMajor>(extent);
      break;
    case library::LayoutTypeID::kColumnMajorInterleavedK2:
      stride = get_packed_layout_stride<cutlass::layout::ColumnMajorInterleaved<2>>(extent);
      break;
    case library::LayoutTypeID::kRowMajorInterleavedK2:
      stride = get_packed_layout_stride<cutlass::layout::RowMajorInterleaved<2>>(extent);
      break;
    case library::LayoutTypeID::kColumnMajorInterleavedK4:
      stride = get_packed_layout_stride<cutlass::layout::ColumnMajorInterleaved<4>>(extent);
      break;
    case library::LayoutTypeID::kRowMajorInterleavedK4:
      stride = get_packed_layout_stride<cutlass::layout::RowMajorInterleaved<4>>(extent);
      break;
    case library::LayoutTypeID::kColumnMajorInterleavedK16:
      stride = get_packed_layout_stride<cutlass::layout::ColumnMajorInterleaved<16>>(extent);
      break;
    case library::LayoutTypeID::kRowMajorInterleavedK16:
      stride = get_packed_layout_stride<cutlass::layout::RowMajorInterleaved<16>>(extent);
      break;
    case library::LayoutTypeID::kColumnMajorInterleavedK32:
      stride = get_packed_layout_stride<cutlass::layout::ColumnMajorInterleaved<32>>(extent);
      break;
    case library::LayoutTypeID::kRowMajorInterleavedK32:
      stride = get_packed_layout_stride<cutlass::layout::RowMajorInterleaved<32>>(extent);
      break;
    case library::LayoutTypeID::kColumnMajorInterleavedK64:
      stride = get_packed_layout_stride<cutlass::layout::ColumnMajorInterleaved<64>>(extent);
      break;
    case library::LayoutTypeID::kRowMajorInterleavedK64:
      stride = get_packed_layout_stride<cutlass::layout::RowMajorInterleaved<64>>(extent);
      break;
    case library::LayoutTypeID::kTensorNCHW:
      stride = get_packed_layout_stride<cutlass::layout::TensorNCHW>(extent);
      break;
    case library::LayoutTypeID::kTensorNHWC:
      stride = get_packed_layout_stride<cutlass::layout::TensorNHWC>(extent);
      break;
    case library::LayoutTypeID::kTensorNDHWC:
      stride = get_packed_layout_stride<cutlass::layout::TensorNDHWC>(extent);
      break;
    case library::LayoutTypeID::kTensorNC32HW32:
      stride = get_packed_layout_stride<cutlass::layout::TensorNCxHWx<32>>(extent);
      break;
    case library::LayoutTypeID::kTensorNC64HW64:
      stride = get_packed_layout_stride<cutlass::layout::TensorNCxHWx<64>>(extent);
      break;
    case library::LayoutTypeID::kTensorC32RSK32:
      stride = get_packed_layout_stride<cutlass::layout::TensorCxRSKx<32>>(extent);
      break;
    case library::LayoutTypeID::kTensorC64RSK64:
      stride = get_packed_layout_stride<cutlass::layout::TensorCxRSKx<64>>(extent);
      break;
    default: break;
  }

  return stride;
}

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Template to use CUTLASS Layout functions to construct a layout object from
/// extent and stride vectors and to compute the allocation's capacity
template <typename Layout>
static size_t construct_layout_(
  void *bytes,
  library::LayoutTypeID layout_id,
  std::vector<int> const &extent,
  std::vector<int64_t> &stride) {

  if (extent.size() != size_t(Layout::kRank)) {
    throw std::runtime_error(
      "Layout must have same rank as extent vector.");
  }

  if (Layout::kStrideRank && stride.empty()) {

    stride = get_packed_layout_stride<Layout>(extent);

    return construct_layout_<Layout>(
      bytes,
      layout_id,
      extent,
      stride);
  }
  else if (Layout::kStrideRank && stride.size() != size_t(Layout::kStrideRank)) {
    throw std::runtime_error(
      "Layout requires either empty stride or stride vector matching Layout::kStrideRank");
  }

  typename Layout::Stride stride_coord;
  for (int i = 0; i < Layout::kStrideRank; ++i) {
    stride_coord[i] = (int)stride.at(i);
  }

  typename Layout::TensorCoord extent_coord;
  for (int i = 0; i < Layout::kRank; ++i) {
    extent_coord[i] = extent.at(i);
  }

  // Construct the CUTLASS layout object from the stride object
  Layout layout(stride_coord);

  // Pack it into bytes
  if (bytes) {
    *reinterpret_cast<Layout *>(bytes) = layout;
  }

  // Return capacity
  size_t capacity_ = layout.capacity(extent_coord);

  return capacity_;
}
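
// Note: when `stride` arrives empty, the branch above computes packed strides
// and recurses exactly once; the caller's stride vector is updated in place,
// which DeviceAllocation::reset() below relies on to record the final strides.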

/// Returns the capacity in elements needed for the given layout and extent
size_t DeviceAllocation::construct_layout(
  void *bytes,
  library::LayoutTypeID layout_id,
  std::vector<int> const &extent,
  std::vector<int64_t> &stride) {

  switch (layout_id) {
    case library::LayoutTypeID::kColumnMajor:
      return construct_layout_<cutlass::layout::ColumnMajor>(bytes, layout_id, extent, stride);

    case library::LayoutTypeID::kRowMajor:
      return construct_layout_<cutlass::layout::RowMajor>(bytes, layout_id, extent, stride);

    case library::LayoutTypeID::kColumnMajorInterleavedK2:
      return construct_layout_<cutlass::layout::ColumnMajorInterleaved<2>>(bytes, layout_id, extent, stride);

    case library::LayoutTypeID::kRowMajorInterleavedK2:
      return construct_layout_<cutlass::layout::RowMajorInterleaved<2>>(bytes, layout_id, extent, stride);

    case library::LayoutTypeID::kColumnMajorInterleavedK4:
      return construct_layout_<cutlass::layout::ColumnMajorInterleaved<4>>(bytes, layout_id, extent, stride);

    case library::LayoutTypeID::kRowMajorInterleavedK4:
      return construct_layout_<cutlass::layout::RowMajorInterleaved<4>>(bytes, layout_id, extent, stride);

    case library::LayoutTypeID::kColumnMajorInterleavedK16:
      return construct_layout_<cutlass::layout::ColumnMajorInterleaved<16>>(bytes, layout_id, extent, stride);

    case library::LayoutTypeID::kRowMajorInterleavedK16:
      return construct_layout_<cutlass::layout::RowMajorInterleaved<16>>(bytes, layout_id, extent, stride);

    case library::LayoutTypeID::kColumnMajorInterleavedK32:
      return construct_layout_<cutlass::layout::ColumnMajorInterleaved<32>>(bytes, layout_id, extent, stride);

    case library::LayoutTypeID::kRowMajorInterleavedK32:
      return construct_layout_<cutlass::layout::RowMajorInterleaved<32>>(bytes, layout_id, extent, stride);

    case library::LayoutTypeID::kColumnMajorInterleavedK64:
      return construct_layout_<cutlass::layout::ColumnMajorInterleaved<64>>(bytes, layout_id, extent, stride);

    case library::LayoutTypeID::kRowMajorInterleavedK64:
      return construct_layout_<cutlass::layout::RowMajorInterleaved<64>>(bytes, layout_id, extent, stride);

    case library::LayoutTypeID::kTensorNCHW:
      return construct_layout_<cutlass::layout::TensorNCHW>(bytes, layout_id, extent, stride);

    case library::LayoutTypeID::kTensorNHWC:
      return construct_layout_<cutlass::layout::TensorNHWC>(bytes, layout_id, extent, stride);

    case library::LayoutTypeID::kTensorNDHWC:
      return construct_layout_<cutlass::layout::TensorNDHWC>(bytes, layout_id, extent, stride);

    case library::LayoutTypeID::kTensorNC32HW32:
      return construct_layout_<cutlass::layout::TensorNCxHWx<32>>(bytes, layout_id, extent, stride);

    case library::LayoutTypeID::kTensorNC64HW64:
      return construct_layout_<cutlass::layout::TensorNCxHWx<64>>(bytes, layout_id, extent, stride);

    case library::LayoutTypeID::kTensorC32RSK32:
      return construct_layout_<cutlass::layout::TensorCxRSKx<32>>(bytes, layout_id, extent, stride);

    case library::LayoutTypeID::kTensorC64RSK64:
      return construct_layout_<cutlass::layout::TensorCxRSKx<64>>(bytes, layout_id, extent, stride);

    default: break;
  }

  return 0;
}
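
// Usage sketch (hypothetical values): constructing a packed NHWC layout for a
// 4-D activation tensor. An empty stride vector requests packed strides.
//
//   std::vector<int> extent = {1, 7, 7, 64};   // N, H, W, C
//   std::vector<int64_t> stride;               // empty => packed
//   cutlass::layout::TensorNHWC layout_storage;
//   size_t capacity = DeviceAllocation::construct_layout(
//       &layout_storage, library::LayoutTypeID::kTensorNHWC, extent, stride);
//   // capacity == 1 * 7 * 7 * 64; stride now holds {64, 448, 3136}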

/////////////////////////////////////////////////////////////////////////////////////////////////

DeviceAllocation::DeviceAllocation():
  type_(library::NumericTypeID::kInvalid),
  batch_stride_(0),
  capacity_(0),
  pointer_(nullptr),
  layout_(library::LayoutTypeID::kUnknown),
  batch_count_(1) {

}

DeviceAllocation::DeviceAllocation(
  library::NumericTypeID type,
  size_t capacity
):
  type_(type), batch_stride_(capacity), capacity_(capacity), pointer_(nullptr),
  layout_(library::LayoutTypeID::kUnknown), batch_count_(1) {

  cudaError_t result = cudaMalloc((void **)&pointer_, bytes(type, capacity));

  if (result != cudaSuccess) {
    type_ = library::NumericTypeID::kInvalid;
    capacity_ = 0;
    pointer_ = nullptr;
    throw std::bad_alloc();
  }
}

DeviceAllocation::DeviceAllocation(
  library::NumericTypeID type,
  library::LayoutTypeID layout_id,
  std::vector<int> const &extent,
  std::vector<int64_t> const &stride,
  int batch_count
):
  type_(type), batch_stride_(size_t(0)), capacity_(size_t(0)), pointer_(nullptr), batch_count_(1) {

  reset(type, layout_id, extent, stride, batch_count);
}

DeviceAllocation::~DeviceAllocation() {
  if (pointer_) {
    cudaFree(pointer_);
  }
}

DeviceAllocation &DeviceAllocation::reset() {
  if (pointer_) {
    cudaFree(pointer_);
  }

  type_ = library::NumericTypeID::kInvalid;
  batch_stride_ = 0;
  capacity_ = 0;
  pointer_ = nullptr;
  layout_ = library::LayoutTypeID::kUnknown;
  stride_.clear();
  extent_.clear();
  tensor_ref_buffer_.clear();
  batch_count_ = 1;

  return *this;
}

DeviceAllocation &DeviceAllocation::reset(library::NumericTypeID type, size_t capacity) {

  reset();

  type_ = type;
  batch_stride_ = capacity;
  capacity_ = capacity;

  cudaError_t result = cudaMalloc((void **)&pointer_, bytes(type_, capacity_));
  if (result != cudaSuccess) {
    throw std::bad_alloc();
  }

  layout_ = library::LayoutTypeID::kUnknown;
  stride_.clear();
  extent_.clear();
  batch_count_ = 1;

  tensor_ref_buffer_.resize(sizeof(pointer_), 0);
  std::memcpy(tensor_ref_buffer_.data(), &pointer_, sizeof(pointer_));

  return *this;
}

/// Allocates memory for a given layout and tensor
DeviceAllocation &DeviceAllocation::reset(
  library::NumericTypeID type,
  library::LayoutTypeID layout_id,
  std::vector<int> const &extent,
  std::vector<int64_t> const &stride,
  int batch_count) {

  reset();

  tensor_ref_buffer_.resize(sizeof(pointer_) + (sizeof(int64_t) * library::get_layout_stride_rank(layout_id)), 0);

  type_ = type;

  layout_ = layout_id;
  stride_ = stride;
  extent_ = extent;
  batch_count_ = batch_count;

  batch_stride_ = construct_layout(
    tensor_ref_buffer_.data() + sizeof(pointer_),
    layout_id,
    extent,
    stride_);

  capacity_ = batch_stride_ * batch_count_;

  cudaError_t result = cudaMalloc((void **)&pointer_, bytes(type, capacity_));
  if (result != cudaSuccess) {
    throw std::bad_alloc();
  }

  std::memcpy(tensor_ref_buffer_.data(), &pointer_, sizeof(pointer_));

  return *this;
}
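
// Usage sketch (hypothetical values): allocating a batch of 8 packed
// column-major 128x64 FP16 matrices. batch_data(i) then addresses the i-th
// item, batch_stride() elements apart.
//
//   DeviceAllocation A;
//   A.reset(
//       library::NumericTypeID::kF16,
//       library::LayoutTypeID::kColumnMajor,
//       {128, 64},   // extent
//       {},          // empty stride => packed
//       8);          // batch count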

bool DeviceAllocation::good() const {
  return (capacity_ && pointer_);
}

library::NumericTypeID DeviceAllocation::type() const {
  return type_;
}

void *DeviceAllocation::data() const {
  return pointer_;
}

void *DeviceAllocation::batch_data(int batch_idx) const {
  return static_cast<char *>(data()) + batch_stride_bytes() * batch_idx;
}

library::LayoutTypeID DeviceAllocation::layout() const {
  return layout_;
}

std::vector<int64_t> const & DeviceAllocation::stride() const {
  return stride_;
}

/// Gets the extent vector
std::vector<int> const & DeviceAllocation::extent() const {
  return extent_;
}

/// Gets the number of adjacent tensors in memory
int DeviceAllocation::batch_count() const {
  return batch_count_;
}

/// Gets the stride (in units of elements) between items
int64_t DeviceAllocation::batch_stride() const {
  return batch_stride_;
}

/// Gets the stride (in units of bytes) between items
int64_t DeviceAllocation::batch_stride_bytes() const {
  return bytes(type_, batch_stride_);
}

size_t DeviceAllocation::capacity() const {
  return capacity_;
}

size_t DeviceAllocation::bytes() const {
  return bytes(type_, capacity_);
}

/// Copies from an equivalent-sized tensor in device memory
void DeviceAllocation::copy_from_device(void const *ptr) {
  cudaError_t result = cudaMemcpy(data(), ptr, bytes(), cudaMemcpyDeviceToDevice);
  if (result != cudaSuccess) {
    throw std::runtime_error("Failed device-to-device copy");
  }
}

/// Copies from an equivalent-sized tensor in host memory
void DeviceAllocation::copy_from_host(void const *ptr) {
  cudaError_t result = cudaMemcpy(data(), ptr, bytes(), cudaMemcpyHostToDevice);
  if (result != cudaSuccess) {
    throw std::runtime_error("Failed host-to-device copy");
  }
}

/// Copies to an equivalent-sized tensor in host memory
void DeviceAllocation::copy_to_host(void *ptr) {
  cudaError_t result = cudaMemcpy(ptr, data(), bytes(), cudaMemcpyDeviceToHost);
  if (result != cudaSuccess) {
    throw std::runtime_error("Failed device-to-host copy");
  }
}
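
// Note: these copies use synchronous cudaMemcpy on the default stream. The
// device-to-device variant may return before the copy completes but remains
// ordered with subsequent default-stream work. Each peer buffer is assumed
// to hold at least bytes() bytes.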

void DeviceAllocation::initialize_random_device(int seed, Distribution dist) {
  if (!good()) {
    throw std::runtime_error("Attempting to initialize invalid allocation.");
  }

  // Instantiate calls to CURAND here. This file takes a long time to compile
  // for this reason.

  switch (type_) {
  case library::NumericTypeID::kF16:
    cutlass::reference::device::BlockFillRandom<cutlass::half_t>(
      reinterpret_cast<cutlass::half_t *>(pointer_), capacity_, seed, dist);
    break;
  case library::NumericTypeID::kBF16:
    cutlass::reference::device::BlockFillRandom<cutlass::bfloat16_t>(
      reinterpret_cast<cutlass::bfloat16_t *>(pointer_), capacity_, seed, dist);
    break;
  case library::NumericTypeID::kTF32:
    cutlass::reference::device::BlockFillRandom<cutlass::tfloat32_t>(
      reinterpret_cast<cutlass::tfloat32_t *>(pointer_), capacity_, seed, dist);
    break;
  case library::NumericTypeID::kF32:
    cutlass::reference::device::BlockFillRandom<float>(
      reinterpret_cast<float *>(pointer_), capacity_, seed, dist);
    break;
  case library::NumericTypeID::kCBF16:
    cutlass::reference::device::BlockFillRandom<complex<bfloat16_t>>(
      reinterpret_cast<complex<bfloat16_t> *>(pointer_), capacity_, seed, dist);
    break;
  case library::NumericTypeID::kCTF32:
    cutlass::reference::device::BlockFillRandom<cutlass::complex<cutlass::tfloat32_t>>(
      reinterpret_cast<cutlass::complex<cutlass::tfloat32_t> *>(pointer_), capacity_, seed, dist);
    break;
  case library::NumericTypeID::kCF32:
    cutlass::reference::device::BlockFillRandom<cutlass::complex<float>>(
      reinterpret_cast<cutlass::complex<float> *>(pointer_), capacity_, seed, dist);
    break;
  case library::NumericTypeID::kF64:
    cutlass::reference::device::BlockFillRandom<double>(
      reinterpret_cast<double *>(pointer_), capacity_, seed, dist);
    break;
  case library::NumericTypeID::kCF64:
    cutlass::reference::device::BlockFillRandom<complex<double>>(
      reinterpret_cast<complex<double> *>(pointer_), capacity_, seed, dist);
    break;
  case library::NumericTypeID::kS2:
    cutlass::reference::device::BlockFillRandom<int2b_t>(
      reinterpret_cast<int2b_t *>(pointer_), capacity_, seed, dist);
    break;
  case library::NumericTypeID::kS4:
    cutlass::reference::device::BlockFillRandom<int4b_t>(
      reinterpret_cast<int4b_t *>(pointer_), capacity_, seed, dist);
    break;
  case library::NumericTypeID::kS8:
    cutlass::reference::device::BlockFillRandom<int8_t>(
      reinterpret_cast<int8_t *>(pointer_), capacity_, seed, dist);
    break;
  case library::NumericTypeID::kS16:
    cutlass::reference::device::BlockFillRandom<int16_t>(
      reinterpret_cast<int16_t *>(pointer_), capacity_, seed, dist);
    break;
  case library::NumericTypeID::kS32:
    cutlass::reference::device::BlockFillRandom<int32_t>(
      reinterpret_cast<int32_t *>(pointer_), capacity_, seed, dist);
    break;
  case library::NumericTypeID::kS64:
    cutlass::reference::device::BlockFillRandom<int64_t>(
      reinterpret_cast<int64_t *>(pointer_), capacity_, seed, dist);
    break;
  case library::NumericTypeID::kB1:
    cutlass::reference::device::BlockFillRandom<uint1b_t>(
      reinterpret_cast<uint1b_t *>(pointer_), capacity_, seed, dist);
    break;
  case library::NumericTypeID::kU2:
    cutlass::reference::device::BlockFillRandom<uint2b_t>(
      reinterpret_cast<uint2b_t *>(pointer_), capacity_, seed, dist);
    break;
  case library::NumericTypeID::kU4:
    cutlass::reference::device::BlockFillRandom<uint4b_t>(
      reinterpret_cast<uint4b_t *>(pointer_), capacity_, seed, dist);
    break;
  case library::NumericTypeID::kU8:
    cutlass::reference::device::BlockFillRandom<uint8_t>(
      reinterpret_cast<uint8_t *>(pointer_), capacity_, seed, dist);
    break;
  case library::NumericTypeID::kU16:
    cutlass::reference::device::BlockFillRandom<uint16_t>(
      reinterpret_cast<uint16_t *>(pointer_), capacity_, seed, dist);
    break;
  case library::NumericTypeID::kU32:
    cutlass::reference::device::BlockFillRandom<uint32_t>(
      reinterpret_cast<uint32_t *>(pointer_), capacity_, seed, dist);
    break;
  case library::NumericTypeID::kU64:
    cutlass::reference::device::BlockFillRandom<uint64_t>(
      reinterpret_cast<uint64_t *>(pointer_), capacity_, seed, dist);
    break;
  default: break;
  }
}
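
// Usage sketch (hypothetical values): filling an allocation with uniform
// random values in [-4, 4]. Distribution (cutlass/util/distribution.h)
// describes the requested distribution; set_uniform() selects a uniform kind
// and its bounds.
//
//   Distribution dist;
//   dist.set_uniform(-4, 4);
//   allocation.initialize_random_device(/*seed=*/2021, dist);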

void DeviceAllocation::initialize_random_host(int seed, Distribution dist) {
  if (!good()) {
    throw std::runtime_error("Attempting to initialize invalid allocation.");
  }

  std::vector<uint8_t> host_data(bytes());

  switch (type_) {
  case library::NumericTypeID::kF16:
    cutlass::reference::host::BlockFillRandom<cutlass::half_t>(
      reinterpret_cast<cutlass::half_t *>(host_data.data()), capacity_, seed, dist);
    break;
  case library::NumericTypeID::kBF16:
    cutlass::reference::host::BlockFillRandom<cutlass::bfloat16_t>(
      reinterpret_cast<cutlass::bfloat16_t *>(host_data.data()), capacity_, seed, dist);
    break;
  case library::NumericTypeID::kTF32:
    cutlass::reference::host::BlockFillRandom<cutlass::tfloat32_t>(
      reinterpret_cast<cutlass::tfloat32_t *>(host_data.data()), capacity_, seed, dist);
    break;
  case library::NumericTypeID::kF32:
    cutlass::reference::host::BlockFillRandom<float>(
      reinterpret_cast<float *>(host_data.data()), capacity_, seed, dist);
    break;
  case library::NumericTypeID::kCF16:
    cutlass::reference::host::BlockFillRandom<cutlass::complex<cutlass::half_t>>(
      reinterpret_cast<cutlass::complex<cutlass::half_t> *>(host_data.data()), capacity_, seed, dist);
    break;
  case library::NumericTypeID::kCBF16:
    cutlass::reference::host::BlockFillRandom<cutlass::complex<cutlass::bfloat16_t>>(
      reinterpret_cast<cutlass::complex<cutlass::bfloat16_t> *>(host_data.data()), capacity_, seed, dist);
    break;
  case library::NumericTypeID::kCTF32:
    cutlass::reference::host::BlockFillRandom<cutlass::complex<cutlass::tfloat32_t>>(
      reinterpret_cast<cutlass::complex<cutlass::tfloat32_t> *>(host_data.data()), capacity_, seed, dist);
    break;
  case library::NumericTypeID::kCF32:
    cutlass::reference::host::BlockFillRandom<cutlass::complex<float>>(
      reinterpret_cast<cutlass::complex<float> *>(host_data.data()), capacity_, seed, dist);
    break;
  case library::NumericTypeID::kF64:
    cutlass::reference::host::BlockFillRandom<double>(
      reinterpret_cast<double *>(host_data.data()), capacity_, seed, dist);
    break;
  case library::NumericTypeID::kCF64:
    cutlass::reference::host::BlockFillRandom<cutlass::complex<double>>(
      reinterpret_cast<cutlass::complex<double> *>(host_data.data()), capacity_, seed, dist);
    break;
  case library::NumericTypeID::kS2:
    cutlass::reference::host::BlockFillRandom<int2b_t>(
      reinterpret_cast<int2b_t *>(host_data.data()), capacity_, seed, dist);
    break;
  case library::NumericTypeID::kS4:
    cutlass::reference::host::BlockFillRandom<int4b_t>(
      reinterpret_cast<int4b_t *>(host_data.data()), capacity_, seed, dist);
    break;
  case library::NumericTypeID::kS8:
    cutlass::reference::host::BlockFillRandom<int8_t>(
      reinterpret_cast<int8_t *>(host_data.data()), capacity_, seed, dist);
    break;
  case library::NumericTypeID::kS16:
    cutlass::reference::host::BlockFillRandom<int16_t>(
      reinterpret_cast<int16_t *>(host_data.data()), capacity_, seed, dist);
    break;
  case library::NumericTypeID::kS32:
    cutlass::reference::host::BlockFillRandom<int32_t>(
      reinterpret_cast<int32_t *>(host_data.data()), capacity_, seed, dist);
    break;
  case library::NumericTypeID::kS64:
    cutlass::reference::host::BlockFillRandom<int64_t>(
      reinterpret_cast<int64_t *>(host_data.data()), capacity_, seed, dist);
    break;
  case library::NumericTypeID::kB1:
    cutlass::reference::host::BlockFillRandom<uint1b_t>(
      reinterpret_cast<uint1b_t *>(host_data.data()), capacity_, seed, dist);
    break;
  case library::NumericTypeID::kU2:
    cutlass::reference::host::BlockFillRandom<uint2b_t>(
      reinterpret_cast<uint2b_t *>(host_data.data()), capacity_, seed, dist);
    break;
  case library::NumericTypeID::kU4:
    cutlass::reference::host::BlockFillRandom<uint4b_t>(
      reinterpret_cast<uint4b_t *>(host_data.data()), capacity_, seed, dist);
    break;
  case library::NumericTypeID::kU8:
    cutlass::reference::host::BlockFillRandom<uint8_t>(
      reinterpret_cast<uint8_t *>(host_data.data()), capacity_, seed, dist);
    break;
  case library::NumericTypeID::kU16:
    cutlass::reference::host::BlockFillRandom<uint16_t>(
      reinterpret_cast<uint16_t *>(host_data.data()), capacity_, seed, dist);
    break;
  case library::NumericTypeID::kU32:
    cutlass::reference::host::BlockFillRandom<uint32_t>(
      reinterpret_cast<uint32_t *>(host_data.data()), capacity_, seed, dist);
    break;
  case library::NumericTypeID::kU64:
    cutlass::reference::host::BlockFillRandom<uint64_t>(
      reinterpret_cast<uint64_t *>(host_data.data()), capacity_, seed, dist);
    break;
  default: break;
  }

  copy_from_host(host_data.data());
}

void DeviceAllocation::initialize_random_sparsemeta_device(int seed, int MetaSizeInBits) {
  if (!good()) {
    throw std::runtime_error("Attempting to initialize invalid allocation.");
  }

  // Instantiate calls to CURAND here. This file takes a long time to compile
  // for this reason.

  switch (type_) {
  case library::NumericTypeID::kU16:
    cutlass::reference::device::BlockFillRandomSparseMeta<uint16_t>(
      reinterpret_cast<uint16_t *>(pointer_), capacity_, seed, MetaSizeInBits);
    break;
  case library::NumericTypeID::kU32:
    cutlass::reference::device::BlockFillRandomSparseMeta<uint32_t>(
      reinterpret_cast<uint32_t *>(pointer_), capacity_, seed, MetaSizeInBits);
    break;
  default:
    break;
  }
}

void DeviceAllocation::initialize_random_sparsemeta_host(int seed, int MetaSizeInBits) {
  if (!good()) {
    throw std::runtime_error("Attempting to initialize invalid allocation.");
  }

  std::vector<uint8_t> host_data(bytes());

  switch (type_) {
  case library::NumericTypeID::kU16:
    cutlass::reference::host::BlockFillRandomSparseMeta<uint16_t>(
      reinterpret_cast<uint16_t *>(host_data.data()), capacity_, seed, MetaSizeInBits);
    break;
  case library::NumericTypeID::kU32:
    cutlass::reference::host::BlockFillRandomSparseMeta<uint32_t>(
      reinterpret_cast<uint32_t *>(host_data.data()), capacity_, seed, MetaSizeInBits);
    break;
  default:
    break;
  }

  copy_from_host(host_data.data());
}

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Returns true if two blocks have exactly the same value
bool DeviceAllocation::block_compare_equal(
  library::NumericTypeID numeric_type,
  void const *ptr_A,
  void const *ptr_B,
  size_t capacity) {

  switch (numeric_type) {
  case library::NumericTypeID::kF16:
    return reference::device::BlockCompareEqual<half_t>(
      reinterpret_cast<half_t const *>(ptr_A),
      reinterpret_cast<half_t const *>(ptr_B), capacity);
  case library::NumericTypeID::kBF16:
    return reference::device::BlockCompareEqual<bfloat16_t>(
      reinterpret_cast<bfloat16_t const *>(ptr_A),
      reinterpret_cast<bfloat16_t const *>(ptr_B), capacity);
  case library::NumericTypeID::kTF32:
    return reference::device::BlockCompareEqual<tfloat32_t>(
      reinterpret_cast<tfloat32_t const *>(ptr_A),
      reinterpret_cast<tfloat32_t const *>(ptr_B), capacity);
  case library::NumericTypeID::kF32:
    return reference::device::BlockCompareEqual<float>(
      reinterpret_cast<float const *>(ptr_A),
      reinterpret_cast<float const *>(ptr_B), capacity);
  case library::NumericTypeID::kCF32:
    return reference::device::BlockCompareEqual<cutlass::complex<float>>(
      reinterpret_cast<complex<float> const *>(ptr_A),
      reinterpret_cast<complex<float> const *>(ptr_B), capacity);
  case library::NumericTypeID::kCF16:
    return reference::device::BlockCompareEqual<complex<half_t>>(
      reinterpret_cast<complex<half_t> const *>(ptr_A),
      reinterpret_cast<complex<half_t> const *>(ptr_B), capacity);
  case library::NumericTypeID::kCBF16:
    return reference::device::BlockCompareEqual<complex<bfloat16_t>>(
      reinterpret_cast<complex<bfloat16_t> const *>(ptr_A),
      reinterpret_cast<complex<bfloat16_t> const *>(ptr_B), capacity);
  case library::NumericTypeID::kCTF32:
    return reference::device::BlockCompareEqual<complex<tfloat32_t>>(
      reinterpret_cast<complex<tfloat32_t> const *>(ptr_A),
      reinterpret_cast<complex<tfloat32_t> const *>(ptr_B), capacity);
  case library::NumericTypeID::kF64:
    return reference::device::BlockCompareEqual<double>(
      reinterpret_cast<double const *>(ptr_A),
      reinterpret_cast<double const *>(ptr_B), capacity);
  case library::NumericTypeID::kCF64:
    return reference::device::BlockCompareEqual<complex<double>>(
      reinterpret_cast<complex<double> const *>(ptr_A),
      reinterpret_cast<complex<double> const *>(ptr_B), capacity);
  case library::NumericTypeID::kS2:
    return reference::device::BlockCompareEqual<int2b_t>(
      reinterpret_cast<int2b_t const *>(ptr_A),
      reinterpret_cast<int2b_t const *>(ptr_B), capacity);
  case library::NumericTypeID::kS4:
    return reference::device::BlockCompareEqual<int4b_t>(
      reinterpret_cast<int4b_t const *>(ptr_A),
      reinterpret_cast<int4b_t const *>(ptr_B), capacity);
  case library::NumericTypeID::kS8:
    return reference::device::BlockCompareEqual<int8_t>(
      reinterpret_cast<int8_t const *>(ptr_A),
      reinterpret_cast<int8_t const *>(ptr_B), capacity);
  case library::NumericTypeID::kS16:
    return reference::device::BlockCompareEqual<int16_t>(
      reinterpret_cast<int16_t const *>(ptr_A),
      reinterpret_cast<int16_t const *>(ptr_B), capacity);
  case library::NumericTypeID::kS32:
    return reference::device::BlockCompareEqual<int32_t>(
      reinterpret_cast<int32_t const *>(ptr_A),
      reinterpret_cast<int32_t const *>(ptr_B), capacity);
  case library::NumericTypeID::kS64:
    return reference::device::BlockCompareEqual<int64_t>(
      reinterpret_cast<int64_t const *>(ptr_A),
      reinterpret_cast<int64_t const *>(ptr_B), capacity);
  case library::NumericTypeID::kB1:
    return reference::device::BlockCompareEqual<uint1b_t>(
      reinterpret_cast<uint1b_t const *>(ptr_A),
      reinterpret_cast<uint1b_t const *>(ptr_B), capacity);
  case library::NumericTypeID::kU2:
    return reference::device::BlockCompareEqual<uint2b_t>(
      reinterpret_cast<uint2b_t const *>(ptr_A),
      reinterpret_cast<uint2b_t const *>(ptr_B), capacity);
  case library::NumericTypeID::kU4:
    return reference::device::BlockCompareEqual<uint4b_t>(
      reinterpret_cast<uint4b_t const *>(ptr_A),
      reinterpret_cast<uint4b_t const *>(ptr_B), capacity);
  case library::NumericTypeID::kU8:
    return reference::device::BlockCompareEqual<uint8_t>(
      reinterpret_cast<uint8_t const *>(ptr_A),
      reinterpret_cast<uint8_t const *>(ptr_B), capacity);
  case library::NumericTypeID::kU16:
    return reference::device::BlockCompareEqual<uint16_t>(
      reinterpret_cast<uint16_t const *>(ptr_A),
      reinterpret_cast<uint16_t const *>(ptr_B), capacity);
  case library::NumericTypeID::kU32:
    return reference::device::BlockCompareEqual<uint32_t>(
      reinterpret_cast<uint32_t const *>(ptr_A),
      reinterpret_cast<uint32_t const *>(ptr_B), capacity);
  case library::NumericTypeID::kU64:
    return reference::device::BlockCompareEqual<uint64_t>(
      reinterpret_cast<uint64_t const *>(ptr_A),
      reinterpret_cast<uint64_t const *>(ptr_B), capacity);
  default:
    throw std::runtime_error("Unsupported numeric type");
  }
}

/// Returns true if two blocks have approximately the same value
bool DeviceAllocation::block_compare_relatively_equal(
  library::NumericTypeID numeric_type,
  void const *ptr_A,
  void const *ptr_B,
  size_t capacity,
  double epsilon,
  double nonzero_floor) {

  switch (numeric_type) {
  case library::NumericTypeID::kF16:
    return reference::device::BlockCompareRelativelyEqual<half_t>(
      reinterpret_cast<half_t const *>(ptr_A),
      reinterpret_cast<half_t const *>(ptr_B), capacity,
      static_cast<half_t>(epsilon), static_cast<half_t>(nonzero_floor));
  case library::NumericTypeID::kBF16:
    return reference::device::BlockCompareRelativelyEqual<bfloat16_t>(
      reinterpret_cast<bfloat16_t const *>(ptr_A),
      reinterpret_cast<bfloat16_t const *>(ptr_B), capacity,
      static_cast<bfloat16_t>(epsilon), static_cast<bfloat16_t>(nonzero_floor));
  case library::NumericTypeID::kTF32:
    return reference::device::BlockCompareRelativelyEqual<tfloat32_t>(
      reinterpret_cast<tfloat32_t const *>(ptr_A),
      reinterpret_cast<tfloat32_t const *>(ptr_B), capacity,
      static_cast<tfloat32_t>(epsilon), static_cast<tfloat32_t>(nonzero_floor));
  case library::NumericTypeID::kF32:
    return reference::device::BlockCompareRelativelyEqual<float>(
      reinterpret_cast<float const *>(ptr_A),
      reinterpret_cast<float const *>(ptr_B), capacity,
      static_cast<float>(epsilon), static_cast<float>(nonzero_floor));
  case library::NumericTypeID::kF64:
    return reference::device::BlockCompareRelativelyEqual<double>(
      reinterpret_cast<double const *>(ptr_A),
      reinterpret_cast<double const *>(ptr_B), capacity,
      static_cast<double>(epsilon), static_cast<double>(nonzero_floor));
  case library::NumericTypeID::kS2:
    return reference::device::BlockCompareRelativelyEqual<int2b_t>(
      reinterpret_cast<int2b_t const *>(ptr_A),
      reinterpret_cast<int2b_t const *>(ptr_B), capacity,
      static_cast<int2b_t>(epsilon), static_cast<int2b_t>(nonzero_floor));
  case library::NumericTypeID::kS4:
    return reference::device::BlockCompareRelativelyEqual<int4b_t>(
      reinterpret_cast<int4b_t const *>(ptr_A),
      reinterpret_cast<int4b_t const *>(ptr_B), capacity,
      static_cast<int4b_t>(epsilon), static_cast<int4b_t>(nonzero_floor));
  case library::NumericTypeID::kS8:
    return reference::device::BlockCompareRelativelyEqual<int8_t>(
      reinterpret_cast<int8_t const *>(ptr_A),
      reinterpret_cast<int8_t const *>(ptr_B), capacity,
      static_cast<int8_t>(epsilon), static_cast<int8_t>(nonzero_floor));
  case library::NumericTypeID::kS16:
    return reference::device::BlockCompareRelativelyEqual<int16_t>(
      reinterpret_cast<int16_t const *>(ptr_A),
      reinterpret_cast<int16_t const *>(ptr_B), capacity,
      static_cast<int16_t>(epsilon), static_cast<int16_t>(nonzero_floor));
  case library::NumericTypeID::kS32:
    return reference::device::BlockCompareRelativelyEqual<int32_t>(
      reinterpret_cast<int32_t const *>(ptr_A),
      reinterpret_cast<int32_t const *>(ptr_B), capacity,
      static_cast<int32_t>(epsilon), static_cast<int32_t>(nonzero_floor));
  case library::NumericTypeID::kS64:
    return reference::device::BlockCompareRelativelyEqual<int64_t>(
      reinterpret_cast<int64_t const *>(ptr_A),
      reinterpret_cast<int64_t const *>(ptr_B), capacity,
      static_cast<int64_t>(epsilon), static_cast<int64_t>(nonzero_floor));
  case library::NumericTypeID::kB1:
    return reference::device::BlockCompareRelativelyEqual<uint1b_t>(
      reinterpret_cast<uint1b_t const *>(ptr_A),
      reinterpret_cast<uint1b_t const *>(ptr_B), capacity,
      static_cast<uint1b_t>(epsilon), static_cast<uint1b_t>(nonzero_floor));
  case library::NumericTypeID::kU2:
    return reference::device::BlockCompareRelativelyEqual<uint2b_t>(
      reinterpret_cast<uint2b_t const *>(ptr_A),
      reinterpret_cast<uint2b_t const *>(ptr_B), capacity,
      static_cast<uint2b_t>(epsilon), static_cast<uint2b_t>(nonzero_floor));
  case library::NumericTypeID::kU4:
    return reference::device::BlockCompareRelativelyEqual<uint4b_t>(
      reinterpret_cast<uint4b_t const *>(ptr_A),
      reinterpret_cast<uint4b_t const *>(ptr_B), capacity,
      static_cast<uint4b_t>(epsilon), static_cast<uint4b_t>(nonzero_floor));
  case library::NumericTypeID::kU8:
    return reference::device::BlockCompareRelativelyEqual<uint8_t>(
      reinterpret_cast<uint8_t const *>(ptr_A),
      reinterpret_cast<uint8_t const *>(ptr_B), capacity,
      static_cast<uint8_t>(epsilon), static_cast<uint8_t>(nonzero_floor));
  case library::NumericTypeID::kU16:
    return reference::device::BlockCompareRelativelyEqual<uint16_t>(
      reinterpret_cast<uint16_t const *>(ptr_A),
      reinterpret_cast<uint16_t const *>(ptr_B), capacity,
      static_cast<uint16_t>(epsilon), static_cast<uint16_t>(nonzero_floor));
  case library::NumericTypeID::kU32:
    return reference::device::BlockCompareRelativelyEqual<uint32_t>(
      reinterpret_cast<uint32_t const *>(ptr_A),
      reinterpret_cast<uint32_t const *>(ptr_B), capacity,
      static_cast<uint32_t>(epsilon), static_cast<uint32_t>(nonzero_floor));
  case library::NumericTypeID::kU64:
    return reference::device::BlockCompareRelativelyEqual<uint64_t>(
      reinterpret_cast<uint64_t const *>(ptr_A),
      reinterpret_cast<uint64_t const *>(ptr_B), capacity,
      static_cast<uint64_t>(epsilon), static_cast<uint64_t>(nonzero_floor));

  // No relatively-equal comparison is defined for complex numbers, so as a
  // simplification we require bitwise equality. This avoids false positives:
  // "pass" really means passing, but "fail" may not mean failure under an
  // appropriate epsilon.
  case library::NumericTypeID::kCF16:
    return reference::device::BlockCompareEqual<cutlass::complex<half_t>>(
      reinterpret_cast<complex<half_t> const *>(ptr_A),
      reinterpret_cast<complex<half_t> const *>(ptr_B), capacity);
  case library::NumericTypeID::kCF32:
    return reference::device::BlockCompareEqual<cutlass::complex<float>>(
      reinterpret_cast<complex<float> const *>(ptr_A),
      reinterpret_cast<complex<float> const *>(ptr_B), capacity);
  case library::NumericTypeID::kCF64:
    return reference::device::BlockCompareEqual<cutlass::complex<double>>(
      reinterpret_cast<complex<double> const *>(ptr_A),
      reinterpret_cast<complex<double> const *>(ptr_B), capacity);

  default:
    throw std::runtime_error(std::string("Unsupported numeric type: ") + to_string(numeric_type));
  }
}
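
// Note: for the floating-point cases above, epsilon and nonzero_floor are
// cast to the element type, so tolerances below the type's smallest
// representable value round to zero in low-precision types such as half_t.
// See cutlass::relatively_equal for the exact per-element comparison rule.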

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Permits copying dynamic vectors into static-length vectors
template <typename TensorCoord, int Rank>
struct vector_to_coord {

  vector_to_coord(TensorCoord &coord, std::vector<int> const &vec) {

    coord[Rank - 1] = vec.at(Rank - 1);

    if (Rank > 1) {
      vector_to_coord<TensorCoord, Rank - 1>(coord, vec);
    }
  }

  vector_to_coord(TensorCoord &coord, std::vector<int64_t> const &vec) {

    coord[Rank - 1] = (int)vec.at(Rank - 1);

    if (Rank > 1) {
      vector_to_coord<TensorCoord, Rank - 1>(coord, vec);
    }
  }
};

/// Permits copying dynamic vectors into static-length vectors
template <typename TensorCoord>
struct vector_to_coord<TensorCoord, 1> {

  vector_to_coord(TensorCoord &coord, std::vector<int> const &vec) {

    coord[0] = vec.at(0);
  }

  vector_to_coord(TensorCoord &coord, std::vector<int64_t> const &vec) {

    coord[0] = (int)vec.at(0);
  }
};

/// Permits copying dynamic vectors into static-length vectors
template <typename TensorCoord>
struct vector_to_coord<TensorCoord, 0> {

  vector_to_coord(TensorCoord &coord, std::vector<int> const &vec) {

  }
};
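
// The constructors above unroll at compile time through the partial
// specializations. A sketch:
//
//   Coord<3> c;
//   std::vector<int> v = {1, 7, 64};     // hypothetical values
//   vector_to_coord<Coord<3>, 3>(c, v);  // assigns c[2], c[1], then c[0]
//                                        // c == make_Coord(1, 7, 64)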

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Element, typename Layout>
static void write_tensor_csv_static_tensor_view(
  std::ostream &out,
  DeviceAllocation &allocation) {

  Coord<Layout::kRank> extent;
  Coord<Layout::kStrideRank, typename Layout::Stride::Index> stride;

  if (allocation.extent().size() != Layout::kRank) {
    throw std::runtime_error("Allocation extent has invalid rank");
  }

  if (allocation.stride().size() != Layout::kStrideRank) {
    throw std::runtime_error("Allocation stride has invalid rank");
  }

  vector_to_coord<Coord<Layout::kRank>, Layout::kRank>(extent, allocation.extent());
  vector_to_coord<Coord<Layout::kStrideRank, typename Layout::Stride::Index>,
    Layout::kStrideRank>(stride, allocation.stride());

  Layout layout(stride);
  HostTensor<Element, Layout> host_tensor(extent, layout, false);

  if (host_tensor.capacity() != allocation.batch_stride()) {
    throw std::runtime_error("Expected host tensor capacity to equal allocation batch stride.");
  }

  host_tensor.copy_in_device_to_host(
    static_cast<Element const *>(allocation.data()),
    allocation.batch_stride());

  TensorViewWrite(out, host_tensor.host_view());

  out << "\n\n";
}

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T>
static void write_tensor_csv_static_type(
  std::ostream &out,
  DeviceAllocation &allocation) {

  switch (allocation.layout()) {
  case library::LayoutTypeID::kRowMajor:
    write_tensor_csv_static_tensor_view<T, layout::RowMajor>(out, allocation);
    break;
  case library::LayoutTypeID::kColumnMajor:
    write_tensor_csv_static_tensor_view<T, layout::ColumnMajor>(out, allocation);
    break;
  case library::LayoutTypeID::kRowMajorInterleavedK2:
    write_tensor_csv_static_tensor_view<T, layout::RowMajorInterleaved<2>>(out, allocation);
    break;
  case library::LayoutTypeID::kColumnMajorInterleavedK2:
    write_tensor_csv_static_tensor_view<T, layout::ColumnMajorInterleaved<2>>(out, allocation);
    break;
  case library::LayoutTypeID::kRowMajorInterleavedK4:
    write_tensor_csv_static_tensor_view<T, layout::RowMajorInterleaved<4>>(out, allocation);
    break;
  case library::LayoutTypeID::kColumnMajorInterleavedK4:
    write_tensor_csv_static_tensor_view<T, layout::ColumnMajorInterleaved<4>>(out, allocation);
    break;
  case library::LayoutTypeID::kRowMajorInterleavedK16:
    write_tensor_csv_static_tensor_view<T, layout::RowMajorInterleaved<16>>(out, allocation);
    break;
  case library::LayoutTypeID::kColumnMajorInterleavedK16:
    write_tensor_csv_static_tensor_view<T, layout::ColumnMajorInterleaved<16>>(out, allocation);
    break;
  case library::LayoutTypeID::kRowMajorInterleavedK32:
    write_tensor_csv_static_tensor_view<T, layout::RowMajorInterleaved<32>>(out, allocation);
    break;
  case library::LayoutTypeID::kColumnMajorInterleavedK32:
    write_tensor_csv_static_tensor_view<T, layout::ColumnMajorInterleaved<32>>(out, allocation);
    break;
  case library::LayoutTypeID::kRowMajorInterleavedK64:
    write_tensor_csv_static_tensor_view<T, layout::RowMajorInterleaved<64>>(out, allocation);
    break;
  case library::LayoutTypeID::kColumnMajorInterleavedK64:
    write_tensor_csv_static_tensor_view<T, layout::ColumnMajorInterleaved<64>>(out, allocation);
    break;
  case library::LayoutTypeID::kTensorNHWC:
    write_tensor_csv_static_tensor_view<T, layout::TensorNHWC>(out, allocation);
    break;
  case library::LayoutTypeID::kTensorNDHWC:
    write_tensor_csv_static_tensor_view<T, layout::TensorNDHWC>(out, allocation);
    break;
  case library::LayoutTypeID::kTensorNC32HW32:
    write_tensor_csv_static_tensor_view<T, layout::TensorNCxHWx<32>>(out, allocation);
    break;
  case library::LayoutTypeID::kTensorNC64HW64:
    write_tensor_csv_static_tensor_view<T, layout::TensorNCxHWx<64>>(out, allocation);
    break;
  case library::LayoutTypeID::kTensorC32RSK32:
    write_tensor_csv_static_tensor_view<T, layout::TensorCxRSKx<32>>(out, allocation);
    break;
  case library::LayoutTypeID::kTensorC64RSK64:
    write_tensor_csv_static_tensor_view<T, layout::TensorCxRSKx<64>>(out, allocation);
    break;
  default:
    throw std::runtime_error("Unhandled layout");
  }
}

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Writes a tensor to csv
void DeviceAllocation::write_tensor_csv(
  std::ostream &out) {

  switch (this->type()) {
  case library::NumericTypeID::kF16:
    write_tensor_csv_static_type<half_t>(out, *this);
    break;
  case library::NumericTypeID::kBF16:
    write_tensor_csv_static_type<bfloat16_t>(out, *this);
    break;
  case library::NumericTypeID::kTF32:
    write_tensor_csv_static_type<tfloat32_t>(out, *this);
    break;
  case library::NumericTypeID::kF32:
    write_tensor_csv_static_type<float>(out, *this);
    break;
  case library::NumericTypeID::kF64:
    write_tensor_csv_static_type<double>(out, *this);
    break;
  case library::NumericTypeID::kS2:
    write_tensor_csv_static_type<int2b_t>(out, *this);
    break;
  case library::NumericTypeID::kS4:
    write_tensor_csv_static_type<int4b_t>(out, *this);
    break;
  case library::NumericTypeID::kS8:
    write_tensor_csv_static_type<int8_t>(out, *this);
    break;
  case library::NumericTypeID::kS16:
    write_tensor_csv_static_type<int16_t>(out, *this);
    break;
  case library::NumericTypeID::kS32:
    write_tensor_csv_static_type<int32_t>(out, *this);
    break;
  case library::NumericTypeID::kS64:
    write_tensor_csv_static_type<int64_t>(out, *this);
    break;
  case library::NumericTypeID::kB1:
    write_tensor_csv_static_type<uint1b_t>(out, *this);
    break;
  case library::NumericTypeID::kU2:
    write_tensor_csv_static_type<uint2b_t>(out, *this);
    break;
  case library::NumericTypeID::kU4:
    write_tensor_csv_static_type<uint4b_t>(out, *this);
    break;
  case library::NumericTypeID::kU8:
    write_tensor_csv_static_type<uint8_t>(out, *this);
    break;
  case library::NumericTypeID::kU16:
    write_tensor_csv_static_type<uint16_t>(out, *this);
    break;
  case library::NumericTypeID::kU32:
    write_tensor_csv_static_type<uint32_t>(out, *this);
    break;
  case library::NumericTypeID::kU64:
    write_tensor_csv_static_type<uint64_t>(out, *this);
    break;
  case library::NumericTypeID::kCF16:
    write_tensor_csv_static_type<cutlass::complex<half_t>>(out, *this);
    break;
  case library::NumericTypeID::kCF32:
    write_tensor_csv_static_type<cutlass::complex<float>>(out, *this);
    break;
  case library::NumericTypeID::kCF64:
    write_tensor_csv_static_type<cutlass::complex<double>>(out, *this);
    break;
  default:
    throw std::runtime_error("Unsupported numeric type");
  }
}
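
// Usage sketch: dumping an allocation for offline inspection. Only the first
// batch item (batch_stride() elements) is written; <fstream> is needed at the
// call site.
//
//   std::ofstream file("tensor_a.csv");   // hypothetical file name
//   allocation.write_tensor_csv(file);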

template <typename Element, typename Layout>
static void tensor_fill_tensor_view(DeviceAllocation &allocation, Element val = Element()) {
  Coord<Layout::kRank> extent;
  Coord<Layout::kStrideRank, typename Layout::LongIndex> stride;

  if (allocation.extent().size() != Layout::kRank) {
    throw std::runtime_error("Allocation extent has invalid rank");
  }

  if (allocation.stride().size() != Layout::kStrideRank) {
    throw std::runtime_error("Allocation stride has invalid rank");
  }

  vector_to_coord<Coord<Layout::kRank>, Layout::kRank>(extent, allocation.extent());
  vector_to_coord<Coord<Layout::kStrideRank, typename Layout::LongIndex>,
    Layout::kStrideRank>(stride, allocation.stride());

  TensorView<Element, Layout> view(
    static_cast<Element *>(allocation.data()),
    Layout(stride),
    extent
  );

  cutlass::reference::device::TensorFill<Element, Layout>(
    view,
    val
  );
}

template <typename Element>
static void tensor_fill(DeviceAllocation &allocation, Element val = Element()) {
  switch (allocation.layout()) {
  case library::LayoutTypeID::kRowMajor:
    tensor_fill_tensor_view<Element, layout::RowMajor>(allocation, val);
    break;
  case library::LayoutTypeID::kColumnMajor:
    tensor_fill_tensor_view<Element, layout::ColumnMajor>(allocation, val);
    break;
  case library::LayoutTypeID::kTensorNHWC:
    tensor_fill_tensor_view<Element, layout::TensorNHWC>(allocation, val);
    break;
  case library::LayoutTypeID::kTensorNDHWC:
    tensor_fill_tensor_view<Element, layout::TensorNDHWC>(allocation, val);
    break;
  case library::LayoutTypeID::kTensorNC32HW32:
    tensor_fill_tensor_view<Element, layout::TensorNCxHWx<32>>(allocation, val);
    break;
  case library::LayoutTypeID::kTensorNC64HW64:
    tensor_fill_tensor_view<Element, layout::TensorNCxHWx<64>>(allocation, val);
    break;
  case library::LayoutTypeID::kTensorC32RSK32:
    tensor_fill_tensor_view<Element, layout::TensorCxRSKx<32>>(allocation, val);
    break;
  case library::LayoutTypeID::kTensorC64RSK64:
    tensor_fill_tensor_view<Element, layout::TensorCxRSKx<64>>(allocation, val);
    break;
  default:
    throw std::runtime_error("Unsupported layout");
    break;
  }
}

/// Fills a tensor uniformly with a value (most frequently used to clear the tensor)
void DeviceAllocation::fill(double val = 0.0) {

  switch (this->type()) {
  case library::NumericTypeID::kF16:
    tensor_fill<half_t>(*this, static_cast<half_t>(val));
    break;
  case library::NumericTypeID::kBF16:
    tensor_fill<bfloat16_t>(*this, static_cast<bfloat16_t>(val));
    break;
  case library::NumericTypeID::kTF32:
    tensor_fill<tfloat32_t>(*this, static_cast<tfloat32_t>(val));
    break;
  case library::NumericTypeID::kF32:
    tensor_fill<float>(*this, static_cast<float>(val));
    break;
  case library::NumericTypeID::kF64:
    tensor_fill<double>(*this, static_cast<double>(val));
    break;
  case library::NumericTypeID::kS2:
    tensor_fill<int2b_t>(*this, static_cast<int2b_t>(val));
    break;
  case library::NumericTypeID::kS4:
    tensor_fill<int4b_t>(*this, static_cast<int4b_t>(val));
    break;
  case library::NumericTypeID::kS8:
    tensor_fill<int8_t>(*this, static_cast<int8_t>(val));
    break;
  case library::NumericTypeID::kS16:
    tensor_fill<int16_t>(*this, static_cast<int16_t>(val));
    break;
  case library::NumericTypeID::kS32:
    tensor_fill<int32_t>(*this, static_cast<int32_t>(val));
    break;
  case library::NumericTypeID::kS64:
    tensor_fill<int64_t>(*this, static_cast<int64_t>(val));
    break;
  case library::NumericTypeID::kB1:
    tensor_fill<uint1b_t>(*this, static_cast<uint1b_t>(val));
    break;
  case library::NumericTypeID::kU2:
    tensor_fill<uint2b_t>(*this, static_cast<uint2b_t>(val));
    break;
  case library::NumericTypeID::kU4:
    tensor_fill<uint4b_t>(*this, static_cast<uint4b_t>(val));
    break;
  case library::NumericTypeID::kU8:
    tensor_fill<uint8_t>(*this, static_cast<uint8_t>(val));
    break;
  case library::NumericTypeID::kU16:
    tensor_fill<uint16_t>(*this, static_cast<uint16_t>(val));
    break;
  case library::NumericTypeID::kU32:
    tensor_fill<uint32_t>(*this, static_cast<uint32_t>(val));
    break;
  case library::NumericTypeID::kU64:
    tensor_fill<uint64_t>(*this, static_cast<uint64_t>(val));
    break;
  case library::NumericTypeID::kCF16:
    tensor_fill<cutlass::complex<half_t>>(*this, from_real<half_t>(val));
    break;
  case library::NumericTypeID::kCF32:
    tensor_fill<cutlass::complex<float>>(*this, from_real<float>(val));
    break;
  case library::NumericTypeID::kCF64:
    tensor_fill<cutlass::complex<double>>(*this, from_real<double>(val));
    break;
  default:
    throw std::runtime_error("Unsupported numeric type");
  }
}
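
// Usage sketch: clearing an allocation before accumulation. For complex
// element types, from_real<> places the value in the real part and leaves
// the imaginary part zero.
//
//   allocation.fill(0.0);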

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace profiler
} // namespace cutlass