/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief Defines basic numeric operators with specializations for Array<T, N>. SIMD-izes where
      possible.

    This is inspired by the Standard Library's <functional> header.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/complex.h"
#include "cutlass/quaternion.h"
#include "cutlass/array.h"
#include "cutlass/half.h"

namespace cutlass {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T>
struct absolute_value_op {
  CUTLASS_HOST_DEVICE
  T operator()(T lhs) const {
    return abs(lhs);
  }
};

template <typename T>
struct plus {
  CUTLASS_HOST_DEVICE
  T operator()(T lhs, T const &rhs) const {
    lhs += rhs;
    return lhs;
  }
};

template <typename T>
struct minus {
  CUTLASS_HOST_DEVICE
  T operator()(T lhs, T const &rhs) const {
    lhs -= rhs;
    return lhs;
  }
};

template <typename T>
struct multiplies {
  CUTLASS_HOST_DEVICE
  T operator()(T lhs, T const &rhs) const {
    lhs *= rhs;
    return lhs;
  }
};

template <typename T>
struct multiplies<Quaternion<T>> {
  CUTLASS_HOST_DEVICE
  Quaternion<T> operator()(Quaternion<T> lhs, Quaternion<T> const &rhs) const {
    lhs = lhs * rhs;
    return lhs;
  }
};

/// Squares with optional conversion
template <typename T, typename Output = T>
struct square {
  CUTLASS_HOST_DEVICE
  Output operator()(T lhs) const {
    multiplies<Output> mul_op;

    Output y = Output(lhs);
    return mul_op(y, y);
  }
};
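/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Usage sketch (illustrative, not part of the original header): these functors mirror
// std::plus and friends from <functional> but are callable in both host and device code.
// The optional Output parameter of square performs the conversion before multiplying:
//
//   cutlass::plus<float> add;
//   cutlass::square<cutlass::half_t, float> sq;   // half_t input, float accumulation
//   float s = add(1.5f, 2.5f);                    // s == 4.0f
//   float q = sq(cutlass::half_t(3.0f));          // q == 9.0f, computed in float
//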
/// Returns the magnitude squared of an element.
template <typename T, typename Output = T>
struct magnitude_squared {
  CUTLASS_HOST_DEVICE
  Output operator()(T lhs) const {
    multiplies<Output> mul_op;

    Output y = Output(lhs);
    return mul_op(y, y);
  }
};

/// Returns the magnitude squared of a complex number, with optional conversion
template <typename T, typename Output>
struct magnitude_squared<complex<T>, Output> {
  CUTLASS_HOST_DEVICE
  Output operator()(complex<T> lhs) const {
    multiplies<Output> mul_op;

    Output y_r = Output(lhs.real());
    Output y_i = Output(lhs.imag());

    return mul_op(y_r, y_r) + mul_op(y_i, y_i);
  }
};

/// Returns the magnitude squared of a quaternion, with optional conversion
template <typename T, typename Output>
struct magnitude_squared<Quaternion<T>, Output> {
  CUTLASS_HOST_DEVICE
  Output operator()(Quaternion<T> lhs) const {
    multiplies<Output> mul_op;

    Output y_w = Output(lhs.w());
    Output y_x = Output(lhs.x());
    Output y_y = Output(lhs.y());
    Output y_z = Output(lhs.z());

    return mul_op(y_w, y_w) + mul_op(y_x, y_x) + mul_op(y_y, y_y) + mul_op(y_z, y_z);
  }
};

/// Computes the square of a difference with optional conversion
template <typename T, typename Output = T>
struct square_difference {
  CUTLASS_HOST_DEVICE
  Output operator()(T lhs, T rhs) const {
    multiplies<Output> mul_op;

    Output y = Output(lhs) - Output(rhs);
    return mul_op(y, y);
  }
};

/// Computes the magnitude squared of a difference with optional conversion
template <typename T, typename Output = T>
struct magnitude_squared_difference {
  CUTLASS_HOST_DEVICE
  Output operator()(T lhs, T rhs) const {
    multiplies<Output> mul_op;

    Output y = Output(lhs) - Output(rhs);
    return mul_op(y, y);
  }
};

/// Computes the magnitude squared of a difference of complex numbers, with optional conversion
template <typename T, typename Output>
struct magnitude_squared_difference<complex<T>, Output> {
  CUTLASS_HOST_DEVICE
  Output operator()(complex<T> lhs, complex<T> rhs) const {
    multiplies<Output> mul_op;

    Output y_r = Output(lhs.real()) - Output(rhs.real());
    Output y_i = Output(lhs.imag()) - Output(rhs.imag());

    return mul_op(y_r, y_r) + mul_op(y_i, y_i);
  }
};

template <typename T>
struct divides {
  CUTLASS_HOST_DEVICE
  T operator()(T lhs, T const &rhs) const {
    lhs /= rhs;
    return lhs;
  }
};

template <typename T>
struct negate {
  CUTLASS_HOST_DEVICE
  T operator()(T lhs) const {
    return -lhs;
  }
};

/// Greater equal
template <typename T>
struct greater_equal {
  CUTLASS_HOST_DEVICE
  bool operator()(T const &lhs, T const &rhs) const {
    return (lhs >= rhs);
  }
};

/// Greater
template <typename T>
struct greater {
  CUTLASS_HOST_DEVICE
  bool operator()(T const &lhs, T const &rhs) const {
    return (lhs > rhs);
  }
};

/// Less equal
template <typename T>
struct less_equal {
  CUTLASS_HOST_DEVICE
  bool operator()(T const &lhs, T const &rhs) const {
    return (lhs <= rhs);
  }
};

/// Less
template <typename T>
struct less {
  CUTLASS_HOST_DEVICE
  bool operator()(T const &lhs, T const &rhs) const {
    return (lhs < rhs);
  }
};

template <typename T>
struct maximum {
  CUTLASS_HOST_DEVICE
  T operator()(T const &lhs, T const &rhs) const {
    return (lhs < rhs ? rhs : lhs);
  }
};

template <>
struct maximum<float> {
  CUTLASS_HOST_DEVICE
  float operator()(float const &lhs, float const &rhs) const {
    return fmaxf(lhs, rhs);
  }
};
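/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Usage sketch (illustrative, not part of the original header): the float specialization of
// maximum lowers to fmaxf(), which per IEEE-754 returns the non-NaN operand when exactly one
// operand is NaN, whereas the generic (lhs < rhs ? rhs : lhs) form may propagate the NaN:
//
//   cutlass::maximum<float> mx;
//   float m = mx(1.0f, 2.0f);   // m == 2.0f, via fmaxf
//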
template <typename T>
struct minimum {
  CUTLASS_HOST_DEVICE
  T operator()(T const &lhs, T const &rhs) const {
    return (rhs < lhs ? rhs : lhs);
  }
};

template <>
struct minimum<float> {
  CUTLASS_HOST_DEVICE
  float operator()(float const &lhs, float const &rhs) const {
    return fminf(lhs, rhs);
  }
};

/// Fused multiply-add
template <typename A, typename B = A, typename C = A>
struct multiply_add {
  CUTLASS_HOST_DEVICE
  C operator()(A const &a, B const &b, C const &c) const {
    return C(a) * C(b) + c;
  }
};

/// Fused multiply-add followed by a ReLU clamp at zero
template <typename A, typename B = A, typename C = A>
struct multiply_add_relu0 {
  CUTLASS_HOST_DEVICE
  C operator()(A const &a, B const &b, C const &c) const {
    maximum<C> mx;
    return mx(C(a) * C(b) + c, C(0));
  }
};
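/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Usage sketch (illustrative, not part of the original header): multiply_add is the scalar
// building block the Array and complex specializations below are expressed in terms of. Mixed
// operand and accumulator precisions are expressed through the three template parameters:
//
//   cutlass::multiply_add<cutlass::half_t, cutlass::half_t, float> fma_op;
//   float acc = 1.0f;
//   acc = fma_op(cutlass::half_t(2.0f), cutlass::half_t(3.0f), acc);   // acc == 7.0f
//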
/// Fused bitwise AND-add
template <typename T>
struct and_add {
  CUTLASS_HOST_DEVICE
  T operator()(T const &a, T const &b, T const &c) const {
    return ((a & b) + c);
  }
};

/// Fused bitwise XOR-add
template <typename T>
struct xor_add {
  CUTLASS_HOST_DEVICE
  T operator()(T const &a, T const &b, T const &c) const {
    return ((a ^ b) + c);
  }
};

template <typename T>
struct conjugate {
  CUTLASS_HOST_DEVICE
  T operator()(T const &a) const {
    return a;
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T>
struct logical_and {
  CUTLASS_HOST_DEVICE
  T operator()(T const &a, T const &b) const {
    return ((a && b) ? T(1) : T());
  }
};

template <typename T>
struct logical_or {
  CUTLASS_HOST_DEVICE
  T operator()(T const &a, T const &b) const {
    return ((a || b) ? T(1) : T());
  }
};

template <typename T>
struct logical_not {
  CUTLASS_HOST_DEVICE
  T operator()(T const &a) const {
    return T(!(a));
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T>
struct bit_and {
  CUTLASS_HOST_DEVICE
  T operator()(T const &a, T const &b) const {
    return a & b;
  }
};

template <typename T>
struct bit_or {
  CUTLASS_HOST_DEVICE
  T operator()(T const &a, T const &b) const {
    return a | b;
  }
};

template <typename T>
struct bit_not {
  CUTLASS_HOST_DEVICE
  T operator()(T const &a) const {
    return ~a;
  }
};

template <typename T>
struct bit_xor {
  CUTLASS_HOST_DEVICE
  T operator()(T const &a, T const &b) const {
    return a ^ b;
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

// Partial specializations for Arrays
template <int N>
struct bit_and<Array<uint1b_t, N>> {
  CUTLASS_HOST_DEVICE
  Array<uint1b_t, N> operator()(Array<uint1b_t, N> const &a, Array<uint1b_t, N> const &b) const {
    using ArrayType = Array<uint1b_t, N>;
    using Storage = typename ArrayType::Storage;
    ArrayType result;

    Storage *result_data = result.raw_data();
    Storage const *a_data = a.raw_data();
    Storage const *b_data = b.raw_data();

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < ArrayType::kStorageElements; ++i) {
      result_data[i] = (a_data[i] & b_data[i]);
    }

    return result;
  }
};

// Partial specializations for Arrays
template <int N>
struct bit_or<Array<uint1b_t, N>> {
  CUTLASS_HOST_DEVICE
  Array<uint1b_t, N> operator()(Array<uint1b_t, N> const &a, Array<uint1b_t, N> const &b) const {
    using ArrayType = Array<uint1b_t, N>;
    using Storage = typename ArrayType::Storage;
    ArrayType result;

    Storage *result_data = result.raw_data();
    Storage const *a_data = a.raw_data();
    Storage const *b_data = b.raw_data();

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < ArrayType::kStorageElements; ++i) {
      result_data[i] = (a_data[i] | b_data[i]);
    }

    return result;
  }
};

// Partial specializations for Arrays
template <int N>
struct bit_not<Array<uint1b_t, N>> {
  CUTLASS_HOST_DEVICE
  Array<uint1b_t, N> operator()(Array<uint1b_t, N> const &a) const {
    using ArrayType = Array<uint1b_t, N>;
    using Storage = typename ArrayType::Storage;
    ArrayType result;

    Storage *result_data = result.raw_data();
    Storage const *a_data = a.raw_data();

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < ArrayType::kStorageElements; ++i) {
      result_data[i] = (~a_data[i]);
    }

    return result;
  }
};

// Partial specializations for Arrays
template <int N>
struct bit_xor<Array<uint1b_t, N>> {
  CUTLASS_HOST_DEVICE
  Array<uint1b_t, N> operator()(Array<uint1b_t, N> const &a, Array<uint1b_t, N> const &b) const {
    using ArrayType = Array<uint1b_t, N>;
    using Storage = typename ArrayType::Storage;
    ArrayType result;

    Storage *result_data = result.raw_data();
    Storage const *a_data = a.raw_data();
    Storage const *b_data = b.raw_data();

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < ArrayType::kStorageElements; ++i) {
      result_data[i] = (a_data[i] ^ b_data[i]);
    }

    return result;
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T>
struct conjugate<complex<T>> {
  CUTLASS_HOST_DEVICE
  complex<T> operator()(complex<T> const &a) const {
    return conj(a);
  }
};

template <typename T, int N>
struct conjugate<Array<T, N>> {
  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &a) const {
    conjugate<T> conj_op;

    Array<T, N> ca;
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      ca[i] = conj_op(a[i]);
    }
    return ca;
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Partial specialization for complex<T> to target four scalar fused multiply-adds.
//
/////////////////////////////////////////////////////////////////////////////////////////////////

/// Fused multiply-add
template <typename T>
struct multiply_add<complex<T>, complex<T>, complex<T>> {
  CUTLASS_HOST_DEVICE
  complex<T> operator()(
    complex<T> const &a,
    complex<T> const &b,
    complex<T> const &c) const {

    T real = c.real();
    T imag = c.imag();

    real += a.real() * b.real();
    real += -a.imag() * b.imag();
    imag += a.real() * b.imag();
    imag += a.imag() * b.real();

    return complex<T>{real, imag};
  }
};

/// Fused multiply-add
template <typename T>
struct multiply_add<complex<T>, T, complex<T>> {
  CUTLASS_HOST_DEVICE
  complex<T> operator()(
    complex<T> const &a,
    T const &b,
    complex<T> const &c) const {

    T real = c.real();
    T imag = c.imag();

    real += a.real() * b;
    imag += a.imag() * b;

    return complex<T>{real, imag};
  }
};

/// Fused multiply-add
template <typename T>
struct multiply_add<T, complex<T>, complex<T>> {
  CUTLASS_HOST_DEVICE
  complex<T> operator()(
    T const &a,
    complex<T> const &b,
    complex<T> const &c) const {

    T real = c.real();
    T imag = c.imag();

    real += a * b.real();
    imag += a * b.imag();

    return complex<T>{real, imag};
  }
};
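/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Usage sketch (illustrative, not part of the original header): the complex<T> specialization
// above expands one complex multiply-accumulate into four real fused multiply-adds:
//
//   cutlass::complex<float> a(1.0f, 2.0f), b(3.0f, 4.0f), c(0.0f, 0.0f);
//   cutlass::multiply_add<cutlass::complex<float>, cutlass::complex<float>,
//                         cutlass::complex<float>> fma_op;
//   cutlass::complex<float> d = fma_op(a, b, c);   // d == (-5, 10), i.e. a * b + c
//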
/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Partial specializations for Array<T, N>
//
/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T, int N>
struct absolute_value_op<Array<T, N>> {
  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &lhs) const {
    Array<T, N> result;
    absolute_value_op<T> scalar_op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = scalar_op(lhs[i]);
    }

    return result;
  }
};

template <typename T, int N>
struct plus<Array<T, N>> {
  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
    Array<T, N> result;
    plus<T> scalar_op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = scalar_op(lhs[i], rhs[i]);
    }

    return result;
  }

  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &lhs, T const &scalar) const {
    Array<T, N> result;
    plus<T> scalar_op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = scalar_op(lhs[i], scalar);
    }

    return result;
  }

  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(T const &scalar, Array<T, N> const &rhs) const {
    Array<T, N> result;
    plus<T> scalar_op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = scalar_op(scalar, rhs[i]);
    }

    return result;
  }
};

template <typename T, int N>
struct minus<Array<T, N>> {
  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
    Array<T, N> result;
    minus<T> scalar_op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = scalar_op(lhs[i], rhs[i]);
    }

    return result;
  }

  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &lhs, T const &scalar) const {
    Array<T, N> result;
    minus<T> scalar_op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = scalar_op(lhs[i], scalar);
    }

    return result;
  }

  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(T const &scalar, Array<T, N> const &rhs) const {
    Array<T, N> result;
    minus<T> scalar_op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = scalar_op(scalar, rhs[i]);
    }

    return result;
  }
};

template <typename T, int N>
struct multiplies<Array<T, N>> {
  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
    Array<T, N> result;
    multiplies<T> scalar_op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = scalar_op(lhs[i], rhs[i]);
    }

    return result;
  }

  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &lhs, T const &scalar) const {
    Array<T, N> result;
    multiplies<T> scalar_op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = scalar_op(lhs[i], scalar);
    }

    return result;
  }

  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(T const &scalar, Array<T, N> const &rhs) const {
    Array<T, N> result;
    multiplies<T> scalar_op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = scalar_op(scalar, rhs[i]);
    }

    return result;
  }
};

template <typename T, int N>
struct divides<Array<T, N>> {
  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
    Array<T, N> result;
    divides<T> scalar_op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = scalar_op(lhs[i], rhs[i]);
    }

    return result;
  }

  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &lhs, T const &scalar) const {
    Array<T, N> result;
    divides<T> scalar_op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = scalar_op(lhs[i], scalar);
    }

    return result;
  }

  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(T const &scalar, Array<T, N> const &rhs) const {
    Array<T, N> result;
    divides<T> scalar_op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = scalar_op(scalar, rhs[i]);
    }

    return result;
  }
};

template <typename T, int N>
struct maximum<Array<T, N>> {
  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
    Array<T, N> result;
    maximum<T> scalar_op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = scalar_op(lhs[i], rhs[i]);
    }

    return result;
  }

  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &lhs, T const &scalar) const {
    Array<T, N> result;
    maximum<T> scalar_op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = scalar_op(lhs[i], scalar);
    }

    return result;
  }

  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(T const &scalar, Array<T, N> const &rhs) const {
    Array<T, N> result;
    maximum<T> scalar_op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = scalar_op(scalar, rhs[i]);
    }

    return result;
  }
};
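/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Usage sketch (illustrative, not part of the original header): each Array<T, N> specialization
// applies the scalar functor elementwise and also accepts a broadcast scalar on either side:
//
//   cutlass::Array<float, 4> x, y;
//   x.fill(2.0f);
//   y.fill(3.0f);
//   cutlass::multiplies<cutlass::Array<float, 4>> mul;
//   cutlass::Array<float, 4> z = mul(x, y);      // every element == 6.0f
//   cutlass::Array<float, 4> w = mul(x, 0.5f);   // every element == 1.0f
//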
template <typename T, int N>
struct minimum<Array<T, N>> {
  CUTLASS_HOST_DEVICE
  static T scalar_op(T const &lhs, T const &rhs) {
    return (rhs < lhs ? rhs : lhs);
  }

  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
    Array<T, N> result;
    minimum<T> scalar_op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = scalar_op(lhs[i], rhs[i]);
    }

    return result;
  }

  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &lhs, T const &scalar) const {
    Array<T, N> result;
    minimum<T> scalar_op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = scalar_op(lhs[i], scalar);
    }

    return result;
  }

  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(T const &scalar, Array<T, N> const &rhs) const {
    Array<T, N> result;
    minimum<T> scalar_op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = scalar_op(scalar, rhs[i]);
    }

    return result;
  }
};

template <typename T, int N>
struct negate<Array<T, N>> {
  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &lhs) const {
    Array<T, N> result;
    negate<T> scalar_op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = scalar_op(lhs[i]);
    }

    return result;
  }
};

/// Fused multiply-add
template <typename T, int N>
struct multiply_add<Array<T, N>, Array<T, N>, Array<T, N>> {
  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &a, Array<T, N> const &b, Array<T, N> const &c) const {
    Array<T, N> result;
    multiply_add<T> scalar_op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = scalar_op(a[i], b[i], c[i]);
    }

    return result;
  }

  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &a, T const &scalar, Array<T, N> const &c) const {
    Array<T, N> result;
    multiply_add<T> scalar_op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = scalar_op(a[i], scalar, c[i]);
    }

    return result;
  }

  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(T const &scalar, Array<T, N> const &b, Array<T, N> const &c) const {
    Array<T, N> result;
    multiply_add<T> scalar_op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = scalar_op(scalar, b[i], c[i]);
    }

    return result;
  }
};

/// Fused multiply-add-relu0
template <typename T, int N>
struct multiply_add_relu0<Array<T, N>, Array<T, N>, Array<T, N>> {
  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &a, Array<T, N> const &b, Array<T, N> const &c) const {
    Array<T, N> result;
    multiply_add<T> scalar_op;
    maximum<T> mx;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = mx(scalar_op(a[i], b[i], c[i]), T(0));
    }

    return result;
  }

  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &a, T const &scalar, Array<T, N> const &c) const {
    Array<T, N> result;
    multiply_add<T> scalar_op;
    maximum<T> mx;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = mx(scalar_op(a[i], scalar, c[i]), T(0));
    }

    return result;
  }

  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(T const &scalar, Array<T, N> const &b, Array<T, N> const &c) const {
    Array<T, N> result;
    multiply_add<T> scalar_op;
    maximum<T> mx;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = mx(scalar_op(scalar, b[i], c[i]), T(0));
    }

    return result;
  }
};
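/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Usage sketch (illustrative, not part of the original header): the Array form of multiply_add
// is the shape epilogues typically consume; the scalar-broadcast overloads cover alpha-scaling:
//
//   cutlass::Array<float, 4> a, c;
//   a.fill(2.0f);
//   c.fill(1.0f);
//   cutlass::multiply_add<cutlass::Array<float, 4>, cutlass::Array<float, 4>,
//                         cutlass::Array<float, 4>> fma_op;
//   cutlass::Array<float, 4> d = fma_op(a, 3.0f, c);   // every element == 7.0f
//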
/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Partial specializations for Array<half_t, N> targeting SIMD instructions in device code.
//
/////////////////////////////////////////////////////////////////////////////////////////////////

template <int N>
struct plus<Array<half_t, N>> {
  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(Array<half_t, N> const &lhs, Array<half_t, N> const &rhs) const {
    Array<half_t, N> result;
    #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
    __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __hadd2(lhs_ptr[i], rhs_ptr[i]);
    }

    if (N % 2) {
      __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);
      __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);
      __half d_residual = __hadd(a_residual_ptr[N - 1], b_residual_ptr[N - 1]);

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

    #else

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = lhs[i] + rhs[i];
    }
    #endif

    return result;
  }

  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(half_t const &lhs, Array<half_t, N> const &rhs) const {
    Array<half_t, N> result;
    #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs));
    __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __hadd2(lhs_pair, rhs_ptr[i]);
    }

    if (N % 2) {
      __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);
      __half d_residual = __hadd(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]);

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

    #else

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = lhs + rhs[i];
    }
    #endif

    return result;
  }

  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(Array<half_t, N> const &lhs, half_t const &rhs) const {
    Array<half_t, N> result;
    #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
    __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs));

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __hadd2(lhs_ptr[i], rhs_pair);
    }

    if (N % 2) {
      __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);
      __half d_residual = __hadd(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs));

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

    #else

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = lhs[i] + rhs;
    }
    #endif

    return result;
  }
};
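/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Usage sketch (illustrative, not part of the original header): on SM53 and newer, the
// specialization above issues one __hadd2 per pair of elements plus a scalar __hadd for an odd
// tail. A hypothetical device-side use over eight halves lowers to four paired adds:
//
//   __device__ cutlass::Array<cutlass::half_t, 8> add8(
//       cutlass::Array<cutlass::half_t, 8> const &a,
//       cutlass::Array<cutlass::half_t, 8> const &b) {
//     cutlass::plus<cutlass::Array<cutlass::half_t, 8>> op;
//     return op(a, b);   // four __hadd2 operations, no residual since 8 is even
//   }
//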
template <int N>
struct minus<Array<half_t, N>> {
  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(Array<half_t, N> const &lhs, Array<half_t, N> const &rhs) const {
    Array<half_t, N> result;
    #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
    __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __hsub2(lhs_ptr[i], rhs_ptr[i]);
    }

    if (N % 2) {
      __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);
      __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);
      __half d_residual = __hsub(a_residual_ptr[N - 1], b_residual_ptr[N - 1]);

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

    #else

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = lhs[i] - rhs[i];
    }
    #endif

    return result;
  }

  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(half_t const &lhs, Array<half_t, N> const &rhs) const {
    Array<half_t, N> result;
    #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs));
    __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __hsub2(lhs_pair, rhs_ptr[i]);
    }

    if (N % 2) {
      __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);
      __half d_residual = __hsub(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]);

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

    #else

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = lhs - rhs[i];
    }
    #endif

    return result;
  }

  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(Array<half_t, N> const &lhs, half_t const &rhs) const {
    Array<half_t, N> result;
    #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
    __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs));

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __hsub2(lhs_ptr[i], rhs_pair);
    }

    if (N % 2) {
      __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);
      __half d_residual = __hsub(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs));

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

    #else

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = lhs[i] - rhs;
    }
    #endif

    return result;
  }
};

template <int N>
struct multiplies<Array<half_t, N>> {
  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(Array<half_t, N> const &lhs, Array<half_t, N> const &rhs) const {
    Array<half_t, N> result;
    #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
    __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __hmul2(lhs_ptr[i], rhs_ptr[i]);
    }

    if (N % 2) {
      __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);
      __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);
      __half d_residual = __hmul(a_residual_ptr[N - 1], b_residual_ptr[N - 1]);

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

    #else

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = lhs[i] * rhs[i];
    }
    #endif

    return result;
  }

  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(half_t const &lhs, Array<half_t, N> const &rhs) const {
    Array<half_t, N> result;
    #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs));
    __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __hmul2(lhs_pair, rhs_ptr[i]);
    }

    if (N % 2) {
      __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);
      __half d_residual = __hmul(
        reinterpret_cast<__half const &>(lhs),
        b_residual_ptr[N - 1]);

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

    #else

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = lhs * rhs[i];
    }
    #endif

    return result;
  }

  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(Array<half_t, N> const &lhs, half_t const &rhs) const {
    Array<half_t, N> result;
    #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
    __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs));

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __hmul2(lhs_ptr[i], rhs_pair);
    }

    if (N % 2) {
      __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);
      __half d_residual = __hmul(
        a_residual_ptr[N - 1],
        reinterpret_cast<__half const &>(rhs));

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

    #else

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = lhs[i] * rhs;
    }
    #endif

    return result;
  }
};
template <int N>
struct divides<Array<half_t, N>> {
  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(Array<half_t, N> const &lhs, Array<half_t, N> const &rhs) const {
    Array<half_t, N> result;
    #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
    __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __h2div(lhs_ptr[i], rhs_ptr[i]);
    }

    if (N % 2) {
      __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);
      __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);
      __half d_residual = __hdiv(
        a_residual_ptr[N - 1],
        b_residual_ptr[N - 1]);

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

    #else

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = lhs[i] / rhs[i];
    }
    #endif

    return result;
  }

  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(half_t const &lhs, Array<half_t, N> const &rhs) const {
    Array<half_t, N> result;
    #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs));
    __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __h2div(lhs_pair, rhs_ptr[i]);
    }

    if (N % 2) {
      __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);
      __half d_residual = __hdiv(
        reinterpret_cast<__half const &>(lhs),
        b_residual_ptr[N - 1]);

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

    #else

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = lhs / rhs[i];
    }
    #endif

    return result;
  }

  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(Array<half_t, N> const &lhs, half_t const &rhs) const {
    Array<half_t, N> result;
    #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
    __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs));

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __h2div(lhs_ptr[i], rhs_pair);
    }

    if (N % 2) {
      __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);
      __half d_residual = __hdiv(
        a_residual_ptr[N - 1],
        reinterpret_cast<__half const &>(rhs));

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

    #else

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = lhs[i] / rhs;
    }
    #endif

    return result;
  }
};

template <int N>
struct negate<Array<half_t, N>> {
  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(Array<half_t, N> const &lhs) const {
    Array<half_t, N> result;
    #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 const *source_ptr = reinterpret_cast<__half2 const *>(&lhs);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __hneg2(source_ptr[i]);
    }

    if (N % 2) {
      half_t x = lhs[N - 1];
      __half lhs_val = -reinterpret_cast<__half const &>(x);
      result[N - 1] = reinterpret_cast<half_t const &>(lhs_val);
    }

    #else

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = -lhs[i];
    }
    #endif

    return result;
  }
};
/// Fused multiply-add
template <int N>
struct multiply_add<Array<half_t, N>, Array<half_t, N>, Array<half_t, N>> {
  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(
    Array<half_t, N> const &a,
    Array<half_t, N> const &b,
    Array<half_t, N> const &c) const {

    Array<half_t, N> result;
    #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a);
    __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b);
    __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __hfma2(a_ptr[i], b_ptr[i], c_ptr[i]);
    }

    if (N % 2) {
      __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a);
      __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b);
      __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c);

      __half d_residual = __hfma(
        a_residual_ptr[N - 1],
        b_residual_ptr[N - 1],
        c_residual_ptr[N - 1]);

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

    #else

    multiply_add<half_t> op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = op(a[i], b[i], c[i]);
    }
    #endif

    return result;
  }

  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(
    half_t const &a,
    Array<half_t, N> const &b,
    Array<half_t, N> const &c) const {

    Array<half_t, N> result;
    #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 a_pair = __half2half2(reinterpret_cast<__half const &>(a));
    __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b);
    __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __hfma2(a_pair, b_ptr[i], c_ptr[i]);
    }

    if (N % 2) {
      __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b);
      __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c);
      __half d_residual = __hfma(
        reinterpret_cast<__half const &>(a),
        b_residual_ptr[N - 1],
        c_residual_ptr[N - 1]);

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

    #else

    multiply_add<half_t> op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = op(a, b[i], c[i]);
    }
    #endif

    return result;
  }

  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(
    Array<half_t, N> const &a,
    half_t const &b,
    Array<half_t, N> const &c) const {

    Array<half_t, N> result;
    #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a);
    __half2 b_pair = __half2half2(reinterpret_cast<__half const &>(b));
    __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __hfma2(a_ptr[i], b_pair, c_ptr[i]);
    }

    if (N % 2) {
      __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a);
      __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c);

      __half d_residual = __hfma(
        a_residual_ptr[N - 1],
        reinterpret_cast<__half const &>(b),
        c_residual_ptr[N - 1]);

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

    #else

    multiply_add<half_t> op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = op(a[i], b, c[i]);
    }
    #endif

    return result;
  }

  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(
    Array<half_t, N> const &a,
    Array<half_t, N> const &b,
    half_t const &c) const {

    Array<half_t, N> result;
    #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a);
    __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b);
    __half2 c_pair = __half2half2(reinterpret_cast<__half const &>(c));

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __hfma2(a_ptr[i], b_ptr[i], c_pair);
    }

    if (N % 2) {
      __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a);
      __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b);

      __half d_residual = __hfma(
        a_residual_ptr[N - 1],
        b_residual_ptr[N - 1],
        reinterpret_cast<__half const &>(c));

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

    #else

    multiply_add<half_t> op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = op(a[i], b[i], c);
    }
    #endif

    return result;
  }
};
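/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Usage sketch (illustrative, not part of the original header): the Array<half_t, N> fused
// multiply-add lowers to __hfma2 on SM53 and newer; the scalar-broadcast overloads splat the
// scalar into both lanes of a __half2 with __half2half2 first:
//
//   cutlass::Array<cutlass::half_t, 4> a, b, c;
//   a.fill(cutlass::half_t(2.0f));
//   b.fill(cutlass::half_t(3.0f));
//   c.fill(cutlass::half_t(1.0f));
//   cutlass::multiply_add<cutlass::Array<cutlass::half_t, 4>,
//                         cutlass::Array<cutlass::half_t, 4>,
//                         cutlass::Array<cutlass::half_t, 4>> fma_op;
//   cutlass::Array<cutlass::half_t, 4> d = fma_op(a, b, c);   // every element == 7.0
//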
/// Fused multiply-add-relu0
template <int N>
struct multiply_add_relu0<Array<half_t, N>, Array<half_t, N>, Array<half_t, N>> {
  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(
    Array<half_t, N> const &a,
    Array<half_t, N> const &b,
    Array<half_t, N> const &c) const {

    Array<half_t, N> result;
    #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)

    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a);
    __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b);
    __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __hfma2_relu(a_ptr[i], b_ptr[i], c_ptr[i]);
    }

    if (N % 2) {
      __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a);
      __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b);
      __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c);

      __half d_residual = __hfma_relu(
        a_residual_ptr[N - 1],
        b_residual_ptr[N - 1],
        c_residual_ptr[N - 1]);

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

    #else

    multiply_add<half_t> op;
    maximum<half_t> mx;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = mx(op(a[i], b[i], c[i]), half_t(0));
    }
    #endif

    return result;
  }

  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(
    half_t const &a,
    Array<half_t, N> const &b,
    Array<half_t, N> const &c) const {

    Array<half_t, N> result;
    #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)

    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 a_pair = __half2half2(reinterpret_cast<__half const &>(a));
    __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b);
    __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __hfma2_relu(a_pair, b_ptr[i], c_ptr[i]);
    }

    if (N % 2) {
      __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b);
      __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c);
      __half d_residual = __hfma_relu(
        reinterpret_cast<__half const &>(a),
        b_residual_ptr[N - 1],
        c_residual_ptr[N - 1]);

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

    #else

    multiply_add<half_t> op;
    maximum<half_t> mx;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = mx(op(a, b[i], c[i]), half_t(0));
    }
    #endif

    return result;
  }

  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(
    Array<half_t, N> const &a,
    half_t const &b,
    Array<half_t, N> const &c) const {

    Array<half_t, N> result;
    #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)

    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a);
    __half2 b_pair = __half2half2(reinterpret_cast<__half const &>(b));
    __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __hfma2_relu(a_ptr[i], b_pair, c_ptr[i]);
    }

    if (N % 2) {
      __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a);
      __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c);

      __half d_residual = __hfma_relu(
        a_residual_ptr[N - 1],
        reinterpret_cast<__half const &>(b),
        c_residual_ptr[N - 1]);

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

    #else

    multiply_add<half_t> op;
    maximum<half_t> mx;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = mx(op(a[i], b, c[i]), half_t(0));
    }
    #endif

    return result;
  }

  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(
    Array<half_t, N> const &a,
    Array<half_t, N> const &b,
    half_t const &c) const {

    Array<half_t, N> result;
    #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)

    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a);
    __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b);
    __half2 c_pair = __half2half2(reinterpret_cast<__half const &>(c));

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __hfma2_relu(a_ptr[i], b_ptr[i], c_pair);
    }

    if (N % 2) {
      __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a);
      __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b);

      __half d_residual = __hfma_relu(
        a_residual_ptr[N - 1],
        b_residual_ptr[N - 1],
        reinterpret_cast<__half const &>(c));

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

    #else

    multiply_add<half_t> op;
    maximum<half_t> mx;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = mx(op(a[i], b[i], c), half_t(0));
    }
    #endif

    return result;
  }
};
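/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Usage sketch (illustrative, not part of the original header): multiply_add_relu0 fuses the
// ReLU into the FMA via __hfma2_relu on SM80 and newer, and falls back to an FMA followed by a
// max with zero elsewhere:
//
//   cutlass::multiply_add_relu0<cutlass::Array<cutlass::half_t, 2>,
//                               cutlass::Array<cutlass::half_t, 2>,
//                               cutlass::Array<cutlass::half_t, 2>> fma_relu;
//   // d = fma_relu(a, b, c) computes d[i] == max(a[i] * b[i] + c[i], 0)
//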
template <int N>
struct minimum<Array<half_t, N>> {
  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(Array<half_t, N> const &lhs, Array<half_t, N> const &rhs) const {
    Array<half_t, N> result;
    #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)

    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
    __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __hmin2(lhs_ptr[i], rhs_ptr[i]);
    }

    if (N % 2) {
      __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);
      __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);
      __half d_residual = __hmin(
        a_residual_ptr[N - 1],
        b_residual_ptr[N - 1]);

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

    #else

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = (rhs[i] < lhs[i] ? rhs[i] : lhs[i]);
    }
    #endif

    return result;
  }

  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(half_t const &lhs, Array<half_t, N> const &rhs) const {
    Array<half_t, N> result;
    #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)

    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs));
    __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __hmin2(lhs_pair, rhs_ptr[i]);
    }

    if (N % 2) {
      __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);
      __half d_residual = __hmin(
        reinterpret_cast<__half const &>(lhs),
        b_residual_ptr[N - 1]);

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

    #else

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = (rhs[i] < lhs ? rhs[i] : lhs);
    }
    #endif

    return result;
  }

  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(Array<half_t, N> const &lhs, half_t const &rhs) const {
    Array<half_t, N> result;
    #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)

    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
    __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs));

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __hmin2(lhs_ptr[i], rhs_pair);
    }

    if (N % 2) {
      __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);
      __half d_residual = __hmin(
        a_residual_ptr[N - 1],
        reinterpret_cast<__half const &>(rhs));

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

    #else

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = (rhs < lhs[i] ? rhs : lhs[i]);
    }
    #endif

    return result;
  }
};
template <int N>
struct maximum<Array<half_t, N>> {
  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(Array<half_t, N> const &lhs, Array<half_t, N> const &rhs) const {
    Array<half_t, N> result;
    #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)

    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
    __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __hmax2(lhs_ptr[i], rhs_ptr[i]);
    }

    if (N % 2) {
      __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);
      __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);
      __half d_residual = __hmax(
        a_residual_ptr[N - 1],
        b_residual_ptr[N - 1]);

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

    #else

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = (lhs[i] < rhs[i] ? rhs[i] : lhs[i]);
    }
    #endif

    return result;
  }

  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(half_t const &lhs, Array<half_t, N> const &rhs) const {
    Array<half_t, N> result;
    #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)

    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs));
    __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __hmax2(lhs_pair, rhs_ptr[i]);
    }

    if (N % 2) {
      __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);
      __half d_residual = __hmax(
        reinterpret_cast<__half const &>(lhs),
        b_residual_ptr[N - 1]);

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

    #else

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = (lhs < rhs[i] ? rhs[i] : lhs);
    }
    #endif

    return result;
  }

  CUTLASS_HOST_DEVICE
  Array<half_t, N> operator()(Array<half_t, N> const &lhs, half_t const &rhs) const {
    Array<half_t, N> result;
    #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)

    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
    __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
    __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs));

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      result_ptr[i] = __hmax2(lhs_ptr[i], rhs_pair);
    }

    if (N % 2) {
      __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);
      __half d_residual = __hmax(
        a_residual_ptr[N - 1],
        reinterpret_cast<__half const &>(rhs));

      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
    }

    #else

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = (lhs[i] < rhs ? rhs : lhs[i]);
    }
    #endif

    return result;
  }
};
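/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Usage sketch (illustrative, not part of the original header): minimum and maximum with a
// broadcast scalar compose into a clamp, lowering to __hmin2 / __hmax2 on SM80 and newer:
//
//   cutlass::minimum<cutlass::Array<cutlass::half_t, 4>> min_op;
//   cutlass::maximum<cutlass::Array<cutlass::half_t, 4>> max_op;
//   // clamp x to [0, 6], a ReLU6-style bound:
//   //   y = min_op(max_op(x, cutlass::half_t(0.0f)), cutlass::half_t(6.0f));
//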
/////////////////////////////////////////////////////////////////////////////////////////////////

/// Fused multiply-add
template <int N>
struct multiply_add<Array<bfloat16_t, N>, Array<bfloat16_t, N>, Array<bfloat16_t, N>> {
  CUTLASS_HOST_DEVICE
  Array<bfloat16_t, N> operator()(
    Array<bfloat16_t, N> const &a,
    Array<bfloat16_t, N> const &b,
    Array<bfloat16_t, N> const &c) const {

    Array<bfloat16_t, N> result;
    #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)

    unsigned *result_ptr = reinterpret_cast<unsigned *>(&result);
    unsigned const *a_ptr = reinterpret_cast<unsigned const *>(&a);
    unsigned const *b_ptr = reinterpret_cast<unsigned const *>(&b);
    unsigned const *c_ptr = reinterpret_cast<unsigned const *>(&c);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      asm ("fma.rn.bf16x2 %0, %1, %2, %3;\n"
        : "=r"(result_ptr[i])
        : "r"(a_ptr[i]), "r"(b_ptr[i]), "r"(c_ptr[i])
      );
    }

    if (N % 2) {
      uint16_t *result_ptr = reinterpret_cast<uint16_t *>(&result);
      uint16_t const *a_residual_ptr = reinterpret_cast<uint16_t const *>(&a);
      uint16_t const *b_residual_ptr = reinterpret_cast<uint16_t const *>(&b);
      uint16_t const *c_residual_ptr = reinterpret_cast<uint16_t const *>(&c);

      asm ("fma.rn.bf16 %0, %1, %2, %3;\n"
        : "=h"(result_ptr[N - 1])
        : "h"(a_residual_ptr[N - 1]), "h"(b_residual_ptr[N - 1]), "h"(c_residual_ptr[N - 1])
      );
    }

    #else

    multiply_add<bfloat16_t> op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = op(a[i], b[i], c[i]);
    }
    #endif

    return result;
  }

  CUTLASS_HOST_DEVICE
  Array<bfloat16_t, N> operator()(
    bfloat16_t const &a,
    Array<bfloat16_t, N> const &b,
    Array<bfloat16_t, N> const &c) const {

    Array<bfloat16_t, N> result;
    #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)

    unsigned *result_ptr = reinterpret_cast<unsigned *>(&result);
    unsigned const *b_ptr = reinterpret_cast<unsigned const *>(&b);
    unsigned const *c_ptr = reinterpret_cast<unsigned const *>(&c);

    unsigned a_packed = static_cast<unsigned>(a.raw());
    a_packed = (a_packed | (a_packed << 16));

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      asm ("fma.rn.bf16x2 %0, %1, %2, %3;\n"
        : "=r"(result_ptr[i])
        : "r"(a_packed), "r"(b_ptr[i]), "r"(c_ptr[i])
      );
    }

    if (N % 2) {
      uint16_t *result_ptr = reinterpret_cast<uint16_t *>(&result);
      uint16_t const *a_residual_ptr = reinterpret_cast<uint16_t const *>(&a);
      uint16_t const *b_residual_ptr = reinterpret_cast<uint16_t const *>(&b);
      uint16_t const *c_residual_ptr = reinterpret_cast<uint16_t const *>(&c);

      asm ("fma.rn.bf16 %0, %1, %2, %3;\n"
        : "=h"(result_ptr[N - 1])
        : "h"(a_residual_ptr[0]), "h"(b_residual_ptr[N - 1]), "h"(c_residual_ptr[N - 1])
      );
    }

    #else

    multiply_add<bfloat16_t> op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = op(a, b[i], c[i]);
    }
    #endif

    return result;
  }

  CUTLASS_HOST_DEVICE
  Array<bfloat16_t, N> operator()(
    Array<bfloat16_t, N> const &a,
    bfloat16_t const &b,
    Array<bfloat16_t, N> const &c) const {

    Array<bfloat16_t, N> result;
    #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)

    unsigned *result_ptr = reinterpret_cast<unsigned *>(&result);
    unsigned const *a_ptr = reinterpret_cast<unsigned const *>(&a);
    unsigned const *c_ptr = reinterpret_cast<unsigned const *>(&c);

    unsigned b_packed = static_cast<unsigned>(b.raw());
    b_packed = (b_packed | (b_packed << 16));

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      asm ("fma.rn.bf16x2 %0, %1, %2, %3;\n"
        : "=r"(result_ptr[i])
        : "r"(a_ptr[i]), "r"(b_packed), "r"(c_ptr[i])
      );
    }

    if (N % 2) {
      uint16_t *result_ptr = reinterpret_cast<uint16_t *>(&result);
      uint16_t const *a_residual_ptr = reinterpret_cast<uint16_t const *>(&a);
      uint16_t const *b_residual_ptr = reinterpret_cast<uint16_t const *>(&b);
      uint16_t const *c_residual_ptr = reinterpret_cast<uint16_t const *>(&c);

      asm ("fma.rn.bf16 %0, %1, %2, %3;\n"
        : "=h"(result_ptr[N - 1])
        : "h"(a_residual_ptr[N - 1]), "h"(b_residual_ptr[0]), "h"(c_residual_ptr[N - 1])
      );
    }

    #else

    multiply_add<bfloat16_t> op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = op(a[i], b, c[i]);
    }
    #endif

    return result;
  }

  CUTLASS_HOST_DEVICE
  Array<bfloat16_t, N> operator()(
    Array<bfloat16_t, N> const &a,
    Array<bfloat16_t, N> const &b,
    bfloat16_t const &c) const {

    Array<bfloat16_t, N> result;
    #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)

    unsigned *result_ptr = reinterpret_cast<unsigned *>(&result);
    unsigned const *a_ptr = reinterpret_cast<unsigned const *>(&a);
    unsigned const *b_ptr = reinterpret_cast<unsigned const *>(&b);

    unsigned c_packed = static_cast<unsigned>(c.raw());
    c_packed = (c_packed | (c_packed << 16));

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 2; ++i) {
      asm ("fma.rn.bf16x2 %0, %1, %2, %3;\n"
        : "=r"(result_ptr[i])
        : "r"(a_ptr[i]), "r"(b_ptr[i]), "r"(c_packed)
      );
    }

    if (N % 2) {
      uint16_t *result_ptr = reinterpret_cast<uint16_t *>(&result);
      uint16_t const *a_residual_ptr = reinterpret_cast<uint16_t const *>(&a);
      uint16_t const *b_residual_ptr = reinterpret_cast<uint16_t const *>(&b);
      uint16_t const *c_residual_ptr = reinterpret_cast<uint16_t const *>(&c);

      asm ("fma.rn.bf16 %0, %1, %2, %3;\n"
        : "=h"(result_ptr[N - 1])
        : "h"(a_residual_ptr[N - 1]), "h"(b_residual_ptr[N - 1]), "h"(c_residual_ptr[0])
      );
    }

    #else

    multiply_add<bfloat16_t> op;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = op(a[i], b[i], c);
    }
    #endif

    return result;
  }
};
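/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Note (illustrative, not part of the original header): the bfloat16_t specialization above
// emits fma.rn.bf16x2 PTX directly on SM80 and newer. A broadcast scalar is lane-duplicated
// into both 16-bit halves of a 32-bit register before the paired FMA:
//
//   unsigned packed = static_cast<unsigned>(scalar.raw());   // hypothetical bfloat16_t scalar
//   packed = (packed | (packed << 16));                      // bf16x2 operand, both lanes equal
//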
/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T, int N>
CUTLASS_HOST_DEVICE
Array<T, N> operator+(Array<T, N> const &lhs, Array<T, N> const &rhs) {
  plus<Array<T, N>> op;
  return op(lhs, rhs);
}

template <typename T, int N>
CUTLASS_HOST_DEVICE
Array<T, N> operator-(Array<T, N> const &lhs, Array<T, N> const &rhs) {
  minus<Array<T, N>> op;
  return op(lhs, rhs);
}

template <typename T, int N>
CUTLASS_HOST_DEVICE
Array<T, N> operator-(Array<T, N> const &lhs) {
  negate<Array<T, N>> op;
  return op(lhs);
}

template <typename T, int N>
CUTLASS_HOST_DEVICE
Array<T, N> operator*(Array<T, N> const &lhs, Array<T, N> const &rhs) {
  multiplies<Array<T, N>> op;
  return op(lhs, rhs);
}

template <typename T, int N>
CUTLASS_HOST_DEVICE
Array<T, N> operator*(T lhs, Array<T, N> const &rhs) {
  multiplies<Array<T, N>> op;
  return op(lhs, rhs);
}

template <typename T, int N>
CUTLASS_HOST_DEVICE
Array<T, N> operator*(Array<T, N> const &lhs, T rhs) {
  multiplies<Array<T, N>> op;
  return op(lhs, rhs);
}

template <typename T, int N>
CUTLASS_HOST_DEVICE
Array<T, N> operator/(Array<T, N> const &lhs, Array<T, N> const &rhs) {
  divides<Array<T, N>> op;
  return op(lhs, rhs);
}

template <typename T, int N>
CUTLASS_HOST_DEVICE
Array<T, N> fma(Array<T, N> const &a, Array<T, N> const &b, Array<T, N> const &c) {
  multiply_add<Array<T, N>> op;
  return op(a, b, c);
}

template <typename T, int N>
CUTLASS_HOST_DEVICE
Array<T, N> fma(T a, Array<T, N> const &b, Array<T, N> const &c) {
  multiply_add<Array<T, N>> op;
  return op(a, b, c);
}

template <typename T, int N>
CUTLASS_HOST_DEVICE
Array<T, N> fma(Array<T, N> const &a, T b, Array<T, N> const &c) {
  multiply_add<Array<T, N>> op;
  return op(a, b, c);
}

template <typename T, int N>
CUTLASS_HOST_DEVICE
Array<T, N> fma(Array<T, N> const &a, Array<T, N> const &b, T c) {
  multiply_add<Array<T, N>> op;
  return op(a, b, c);
}

/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Partial specializations for Quaternion<T> fused multiply-add
//
/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T>
struct multiply_add<Quaternion<T>, Quaternion<T>, Quaternion<T>> {
  CUTLASS_HOST_DEVICE
  Quaternion<T> operator()(
    Quaternion<T> const &a,
    Quaternion<T> const &b,
    Quaternion<T> const &c) const {

    T x = c.x();
    T y = c.y();
    T z = c.z();
    T w = c.w();

    x += a.w() * b.x();
    x += b.w() * a.x();
    x += a.y() * b.z();
    x += -a.z() * b.y();

    y += a.w() * b.y();
    y += b.w() * a.y();
    y += a.z() * b.x();
    y += -a.x() * b.z();

    z += a.w() * b.z();
    z += b.w() * a.z();
    z += a.x() * b.y();
    z += -a.y() * b.x();

    w += a.w() * b.w();
    w += -a.x() * b.x();
    w += -a.y() * b.y();
    w += -a.z() * b.z();

    return cutlass::make_Quaternion(x, y, z, w);
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
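/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Usage sketch (illustrative, not part of the original header): the operator overloads and the
// free fma() functions defined above let Array<T, N> arithmetic read like scalar code:
//
//   cutlass::Array<float, 4> a, b, c;
//   a.fill(1.0f); b.fill(2.0f); c.fill(3.0f);
//   cutlass::Array<float, 4> d = a * b + c;              // elementwise multiplies, then plus
//   cutlass::Array<float, 4> e = cutlass::fma(a, b, c);  // single fused multiply-add functor
//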