/*************************************************************************************************** * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: * * Redistributions of source code must retain the above copyright notice, this list of * conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright notice, this list of * conditions and the following disclaimer in the documentation and/or other materials * provided with the distribution. * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used * to endorse or promote products derived from this software without specific prior written * permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file \brief Helpers for printing cutlass/core objects */ #pragma once #include #include #include "cutlass/array.h" #include "cutlass/coord.h" #include "cutlass/numeric_types.h" #include "cutlass/matrix_shape.h" #include "cutlass/layout/pitch_linear.h" #include "cutlass/tensor_view.h" #include "cutlass/gemm/gemm.h" #include "cutlass/conv/convolution.h" #include "cutlass/conv/conv2d_problem_size.h" #include "cutlass/conv/conv3d_problem_size.h" /////////////////////////////////////////////////////////////////////////////////////////////////// /// Output operator for CUDA built-in dim3 type inline std::ostream &operator<<(std::ostream &out, dim3 d) { return out << d.x << ", " << d.y << ", " << d.z; } /// Output operator for CUDA built-in error type inline std::ostream &operator<<(std::ostream &out, cudaError_t error) { return out << cudaGetErrorString(error); } /////////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass { /////////////////////////////////////////////////////////////////////////////////////////////////// // stream operators for cutlass namespace // /////////////////////////////////////////////////////////////////////////////////////////////////// template inline std::ostream& operator<<(std::ostream& out, Array const& v) { for (int i = 0; i < Rank; ++i) { out << (i ? ", " : "") << v[i]; } return out; } template inline std::ostream& operator<<(std::ostream& out, Coord const& coord) { for (int i = 0; i < Rank; ++i) { out << (i ? ", " : "") << coord[i]; } return out; } inline std::istream & operator>>(std::istream &stream, half_t &x) { float tmp; stream >> tmp; x = static_cast(tmp); return stream; } inline std::ostream & operator<<(std::ostream &out, half_t const &x) { return out << float(x); } inline std::ostream & operator<<(std::ostream &out, bfloat16_t const &x) { return out << float(x); } inline std::ostream & operator<<(std::ostream &out, tfloat32_t const &x) { return out << float(x); } /////////////////////////////////////////////////////////////////////////////////////////////////// /// Helper to enable formatted printing of CUTLASS scalar types to an ostream template struct ScalarIO { /// Value to print T value; /// Default ctor ScalarIO() { } /// Constructs from a value ScalarIO(T value): value(value) {} }; /////////////////////////////////////////////////////////////////////////////////////////////////// /// Default printing to ostream template inline std::ostream &operator<<(std::ostream &out, ScalarIO const &scalar) { return out << scalar.value; } /// Printing to ostream of int8_t as integer rather than character template <> inline std::ostream &operator<<(std::ostream &out, ScalarIO const &scalar) { return out << int(scalar.value); } /// Printing to ostream of uint8_t as integer rather than character template <> inline std::ostream &operator<<(std::ostream &out, ScalarIO const &scalar) { return out << unsigned(scalar.value); } /// Default printing to ostream for MatrixShape template inline std::ostream & operator<<(std::ostream &out, MatrixShape const &matrix_shape) { out << "cutlass::MatrixShape::(kRow, kColumn) {" << cutlass::MatrixShape::kRow <<"," << cutlass::MatrixShape::kColumn <<"}"; return out; } /////////////////////////////////////////////////////////////////////////////////////////////////// // stream operators for cutlass::gemm namespace // /////////////////////////////////////////////////////////////////////////////////////////////////// namespace gemm { /// Default printing to ostream for GemmShape template inline std::ostream & operator<<(std::ostream &out, GemmShape const &gemm_shape) { out << "cutlass::gemm::GemmShape::(kM, kN, kK) {" << cutlass::gemm::GemmShape::kM <<"," << cutlass::gemm::GemmShape::kN <<"," << cutlass::gemm::GemmShape::kK << "}"; return out; } /// Default printing to ostream for GemmCoord inline std::ostream & operator<<(std::ostream &out, GemmCoord const &gemm_coord) { out << "cutlass::gemm::GemmCoord:: {" << gemm_coord.m() <<"," << gemm_coord.n() <<"," << gemm_coord.k() << "}"; return out; } } //namespace gemm /////////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////// // stream operators for cutlass::layout namespace // /////////////////////////////////////////////////////////////////////////////////////////////////// namespace layout { /// Default printing to ostream for PitchLinearShape template < int Contiguous, int Strided> inline std::ostream & operator<<(std::ostream &out, PitchLinearShape const &pitch_linear_shape) { out << "cutlass::layout::PitchLinearShape::(kContiguous, kStrided) {" << cutlass::layout::PitchLinearShape::kContiguous <<"," << cutlass::layout::PitchLinearShape::kStrided <<"}"; return out; } } //namespace layout /////////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////// // stream operators for cutlass::conv namespace // /////////////////////////////////////////////////////////////////////////////////////////////////// namespace conv { /// Default printing to ostream for Conv2dProblemSize inline std::ostream& operator<<(std::ostream& out, Conv2dProblemSize const& problem) { out << "NHWC: (" << problem.N << ", " << problem.H << ", " << problem.W << ", " << problem.C << ")" << std::endl << "KRSC: (" << problem.K << ", " << problem.R << ", " << problem.S << ", " << problem.C << ")" << std::endl << "NPQK: (" << problem.N << ", " << problem.P << ", " << problem.Q << ", " << problem.K << ")" << std::endl << "Pad_h, Pad_w: (" << problem.pad_h << ", " << problem.pad_w << ")" << std::endl << "Stride_h, Stride_w: (" << problem.stride_h << ", " << problem.stride_w << ")" << std::endl << "Dilation_h, Dilation_w: (" << problem.dilation_h << ", " << problem.dilation_w << ")" << std::endl << "split_k_slices: (" << problem.split_k_slices << ")" << std::endl << "mode: (" << ((problem.mode==conv::Mode::kConvolution) ? "conv" : "xcross") << ")"; return out; } /// Default printing to ostream for Conv3dProblemSize inline std::ostream& operator<<(std::ostream& out, Conv3dProblemSize const& problem) { out << "NDHWC: (" << problem.N << ", " << problem.D << ", " << problem.H << ", " << problem.W << ", " << problem.C << ")" << std::endl << "KTRSC: (" << problem.K << ", " << problem.T << ", " << problem.R << ", " << problem.S << ", " << problem.C << ")" << std::endl << "NZPQK: (" << problem.N << ", " << problem.Z << ", " << problem.P << ", " << problem.Q << ", " << problem.K << ")" << std::endl << "pad_d, pad_h, pad_w: (" << problem.pad_d << ", " << problem.pad_h << ", " << problem.pad_w << ")" << std::endl << "stride_d, stride_h, stride_w: (" << problem.stride_d << ", " << problem.stride_h << ", " << problem.stride_w << ")" << std::endl << "dilation_d, dilation_h, dilation_w: (" << problem.dilation_d << ", " << problem.dilation_h << ", " << problem.dilation_w << ")" << std::endl << "split_k_slices: (" << problem.split_k_slices << ") " << std::endl << "mode: (" << ((problem.mode==conv::Mode::kConvolution) ? "conv" : "xcross") << ")"; return out; } } // namespace conv /////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////////