cutlass/include/cutlass/quaternion.h
ANIKET SHIVAM 4575443d44
CUTLASS 3.2 (#1024)
* CUTLASS 3.2
2023-08-07 20:50:32 -04:00

753 lines
20 KiB
C++

/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines a densely packed quaternion object intended for storing data in registers and
executing quaternion operations within a CUDA or host thread.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/functional.h"
#include "cutlass/array.h"
#include "cutlass/real.h"
#include "cutlass/coord.h"
#include "cutlass/matrix.h"
#include "cutlass/fast_math.h"
#include "cutlass/layout/vector.h"
namespace cutlass {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Quaternion: xi + yj + zk + w
///
/// Four-element value type layered on Array<Element_, 4>. Components are stored in the order
/// (x, y, z, w): the three imaginary components first, the real component last.
template <
  typename Element_ = float     ///< element type
>
class Quaternion : public Array<Element_, 4> {
public:

  /// Logical rank of tensor index space
  static int const kRank = 1;

  /// Number of elements
  static int const kExtent = 4;

  /// Base class is a four-element array
  using Base = Array<Element_, kExtent>;

  /// Element type
  using Element = typename Base::Element;

  /// Reference type to an element
  using Reference = typename Base::reference;

  /// Read-only reference type to an element
  using ConstReference = typename Base::const_reference;

  /// Index type
  using Index = int;

  /// Quaternion storage - imaginary part
  static int const kX = 0;

  /// Quaternion storage - imaginary part
  static int const kY = 1;

  /// Quaternion storage - imaginary part
  static int const kZ = 2;

  /// Quaternion storage - real part
  static int const kW = 3;

public:

  //
  // Methods
  //

  /// Constructs a quaternion q = 0
  CUTLASS_HOST_DEVICE
  Quaternion() {
    Base::at(kX) = Element();
    Base::at(kY) = Element();
    Base::at(kZ) = Element();
    Base::at(kW) = Element();
  }

  /// Constructs a quaternion q = w + 0*i + 0*j + 0*k
  CUTLASS_HOST_DEVICE
  Quaternion(
    Element w_
  ) {
    Base::at(kX) = Element();
    Base::at(kY) = Element();
    Base::at(kZ) = Element();
    Base::at(kW) = w_;
  }

  /// Constructs a quaternion q = w + x*i + y*j + z*k
  CUTLASS_HOST_DEVICE
  Quaternion(
    Element x_,
    Element y_,
    Element z_,
    Element w_
  ) {
    Base::at(kX) = x_;
    Base::at(kY) = y_;
    Base::at(kZ) = z_;
    Base::at(kW) = w_;
  }

  /// Constructs a quaternion from a vector representing the imaginary part and a real number
  /// (the real part defaults to a value-initialized Element)
  CUTLASS_HOST_DEVICE
  Quaternion(
    Matrix3x1<Element> const &imag_,
    Element w_ = Element()
  ) {
    Base::at(kX) = imag_[0];
    Base::at(kY) = imag_[1];
    Base::at(kZ) = imag_[2];
    Base::at(kW) = w_;
  }

  /// Returns a read-only reference to the element at index idx (kX, kY, kZ or kW).
  /// A const object must not hand out the mutable Reference type, so the const overload
  /// of Base::at() is forwarded with its own return type.
  CUTLASS_HOST_DEVICE
  ConstReference at(Index idx) const {
    return Base::at(idx);
  }

  /// Returns a mutable reference to the element at index idx (kX, kY, kZ or kW)
  CUTLASS_HOST_DEVICE
  Reference at(Index idx) {
    return Base::at(idx);
  }

  /// Accesses the x element of the imaginary part of the quaternion
  CUTLASS_HOST_DEVICE
  Element x() const {
    return Base::at(kX);
  }

  /// Accesses the x element of the imaginary part of the quaternion
  CUTLASS_HOST_DEVICE
  Reference x() {
    return Base::at(kX);
  }

  /// Accesses the y element of the imaginary part of the quaternion
  CUTLASS_HOST_DEVICE
  Element y() const {
    return Base::at(kY);
  }

  /// Accesses the y element of the imaginary part of the quaternion
  CUTLASS_HOST_DEVICE
  Reference y() {
    return Base::at(kY);
  }

  /// Accesses the z element of the imaginary part of the quaternion
  CUTLASS_HOST_DEVICE
  Element z() const {
    return Base::at(kZ);
  }

  /// Accesses the z element of the imaginary part of the quaternion
  CUTLASS_HOST_DEVICE
  Reference z() {
    return Base::at(kZ);
  }

  /// Accesses the real part of the quaternion
  CUTLASS_HOST_DEVICE
  Element w() const {
    return Base::at(kW);
  }

  /// Accesses the real part of the quaternion
  CUTLASS_HOST_DEVICE
  Reference w() {
    return Base::at(kW);
  }

  /// Returns the pure imaginary part of the quaternion as a 3-vector
  CUTLASS_HOST_DEVICE
  Matrix3x1<Element> pure() const {
    return Matrix3x1<Element>(x(), y(), z());
  }

  /// Returns a quaternion representation of a spatial rotation given a unit-length axis and
  /// a rotation in radians: imaginary part sin(theta/2) * axis, real part cos(theta/2).
  CUTLASS_HOST_DEVICE
  static Quaternion<Element> rotation(
    Matrix3x1<Element> const &axis_unit,    ///< axis of rotation (assumed to be unit length)
    Element theta) {                        ///< angular rotation in radians

    Element s = fast_sin(theta / Element(2));

    return Quaternion(
      s * axis_unit[0],
      s * axis_unit[1],
      s * axis_unit[2],
      fast_cos(theta / Element(2))
    );
  }

  /// Returns a quaternion representation of a spatial rotation represented as a
  /// unit-length rotation axis (r_x, r_y, r_z) and an angular rotation in radians
  CUTLASS_HOST_DEVICE
  static Quaternion<Element> rotation(
    Element r_x,
    Element r_y,
    Element r_z,
    Element theta) {                        ///< angular rotation in radians

    return rotation({r_x, r_y, r_z}, theta);
  }

  /// Geometric rotation of a 3-element vector: imaginary part of q * (rhs, 0) * q^-1
  CUTLASS_HOST_DEVICE
  Matrix3x1<Element> rotate(Matrix3x1<Element> const &rhs) const {
    // The zero real part is spelled Element(0) so that no implicit int -> Element
    // conversion is required of the element type.
    return (*this * Quaternion<Element>(rhs, Element(0)) * reciprocal(*this)).pure();
  }

  /// Inverse rotation operation: imaginary part of q^-1 * (rhs, 0) * q
  CUTLASS_HOST_DEVICE
  Matrix3x1<Element> rotate_inv(Matrix3x1<Element> const &rhs) const {
    return (reciprocal(*this) * Quaternion<Element>(rhs, Element(0)) * *this).pure();
  }

  /// Rotates a 3-vector assuming this is a unit quaternion (a spinor); uses the conjugate
  /// in place of the reciprocal.
  CUTLASS_HOST_DEVICE
  Matrix3x1<Element> spinor(Matrix3x1<Element> const &rhs) const {
    return (*this * Quaternion<Element>(rhs, Element(0)) * conj(*this)).pure();
  }

  /// Inverse rotation of 3-vector assuming this is a unit quaternion (a spinor)
  CUTLASS_HOST_DEVICE
  Matrix3x1<Element> spinor_inv(Matrix3x1<Element> const &rhs) const {
    return (conj(*this) * Quaternion<Element>(rhs, Element(0)) * *this).pure();
  }

  //
  // Compound assignment. These are ordinary (non-template) members: a member-template
  // parameter that appears in no function parameter can never be deduced, which would make
  // expressions such as `q *= r` impossible to call with operator syntax.
  //

  /// In-place addition
  CUTLASS_HOST_DEVICE
  Quaternion &operator+=(Quaternion const &rhs) {
    *this = (*this + rhs);
    return *this;
  }

  /// In-place subtraction
  CUTLASS_HOST_DEVICE
  Quaternion &operator-=(Quaternion const &rhs) {
    *this = (*this - rhs);
    return *this;
  }

  /// In-place multiplication (quaternion product, this on the left)
  CUTLASS_HOST_DEVICE
  Quaternion &operator*=(Quaternion const &rhs) {
    *this = (*this * rhs);
    return *this;
  }

  /// Scalar multiplication
  CUTLASS_HOST_DEVICE
  Quaternion &operator*=(Element s) {
    *this = (*this * s);
    return *this;
  }

  /// In-place Division (multiplies by the reciprocal of rhs on the right)
  CUTLASS_HOST_DEVICE
  Quaternion &operator/=(Quaternion const &rhs) {
    *this = (*this / rhs);
    return *this;
  }

  /// In-place Division by a scalar
  CUTLASS_HOST_DEVICE
  Quaternion &operator/=(Element s) {
    *this = (*this / s);
    return *this;
  }

  /// Computes a 3x3 rotation matrix (row-major representation)
  CUTLASS_HOST_DEVICE
  Matrix3x3<Element> as_rotation_matrix_3x3() const {
    Matrix3x3<Element> m(
      w() * w() + x() * x() - y() * y() - z() * z(),
      2 * x() * y() - 2 * w() * z(),
      2 * x() * z() + 2 * w() * y(),

      2 * x() * y() + 2 * w() * z(),
      w() * w() - x() * x() + y() * y() - z() * z(),
      2 * y() * z() - 2 * w() * x(),

      2 * x() * z() - 2 * w() * y(),
      2 * y() * z() + 2 * w() * x(),
      w() * w() - x() * x() - y() * y() + z() * z()
    );
    return m;
  }

  /// Computes a 4x4 rotation matrix (row-major representation): an identity matrix whose
  /// 3x3 slice is overwritten with the rotation returned by as_rotation_matrix_3x3().
  CUTLASS_HOST_DEVICE
  Matrix4x4<Element> as_rotation_matrix_4x4() const {
    Matrix4x4<Element> m = Matrix4x4<Element>::identity();
    m.set_slice_3x3(as_rotation_matrix_3x3());
    return m;
  }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Builds a quaternion whose only (possibly) non-zero component is the real element `w`.
template <typename Element>
CUTLASS_HOST_DEVICE
Quaternion<Element> make_Quaternion(
  Element w) {                                ///< real part

  Quaternion<Element> result(w);
  return result;
}
/// Builds a quaternion from an imaginary 3-vector and a real scalar.
template <typename Element>
CUTLASS_HOST_DEVICE
Quaternion<Element> make_Quaternion(
  Matrix3x1<Element> const &imag,             ///< imaginary part as a vector
  Element w) {                                ///< real part

  Quaternion<Element> result(imag, w);
  return result;
}
/// Builds the rotation quaternion for a unit-length rotation axis and an angle in radians
/// by forwarding to Quaternion<Element>::rotation().
template <typename Element>
CUTLASS_HOST_DEVICE
Quaternion<Element> make_QuaternionRotation(
  Matrix3x1<Element> const &axis_unit,        ///< rotation axis (unit-length)
  Element w) {                                ///< rotation angle in radians

  using Rotor = Quaternion<Element>;
  return Rotor::rotation(axis_unit, w);
}
/// Builds the quaternion q = xi + yj + zk + w from its four components.
template <typename Element>
CUTLASS_HOST_DEVICE
Quaternion<Element> make_Quaternion(Element x, Element y, Element z, Element w) {
  Quaternion<Element> result(x, y, z, w);
  return result;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Returns a read-only reference to the real part of the quaternion number.
///
/// Quaternion::w() const returns its value by copy, so returning `q.w()` here would bind the
/// result to a temporary that is destroyed when this function returns. The element is read
/// through the const at() of the underlying array instead.
/// NOTE(review): relies on the base Array's const at() returning a reference into its
/// storage - confirm against cutlass/array.h.
template <typename Element>
CUTLASS_HOST_DEVICE
Element const &real(Quaternion<Element> const &q) {
  using Base = typename Quaternion<Element>::Base;
  return static_cast<Base const &>(q).at(Quaternion<Element>::kW);
}
/// Gives mutable access to the real part (w) of the quaternion.
template <typename Element>
CUTLASS_HOST_DEVICE
Element &real(Quaternion<Element> &q) {
  Element &real_part = q.w();
  return real_part;
}
/// Magnitude of the quaternion: the square root of norm(q).
template <typename Element>
CUTLASS_HOST_DEVICE
Element abs(Quaternion<Element> const &q) {
  Element const magnitude_squared_ = norm(q);
  return fast_sqrt(magnitude_squared_);
}
/// Conjugate of a quaternion: the imaginary components are negated, the real part is kept.
template <typename Element>
CUTLASS_HOST_DEVICE
Quaternion<Element> conj(Quaternion<Element> const &q) {
  Element const neg_x = -q.x();
  Element const neg_y = -q.y();
  Element const neg_z = -q.z();
  Element const real_part = q.w();

  return make_Quaternion(neg_x, neg_y, neg_z, real_part);
}
/// Squared magnitude of the quaternion: x^2 + y^2 + z^2 + w^2.
template <typename Element>
CUTLASS_HOST_DEVICE
Element norm(Quaternion<Element> const &q) {
  Element const x = q.x();
  Element const y = q.y();
  Element const z = q.z();
  Element const w = q.w();

  // Summed in the order x, y, z, w.
  return x * x + y * y + z * z + w * w;
}
/// Multiplicative inverse of a quaternion: each component of the conjugate divided by the
/// squared magnitude norm(q).
template <typename Element>
CUTLASS_HOST_DEVICE
Quaternion<Element> reciprocal(Quaternion<Element> const &q) {

  Element const norm_sq = norm(q);

  Element const inv_x = -q.x() / norm_sq;
  Element const inv_y = -q.y() / norm_sq;
  Element const inv_z = -q.z() / norm_sq;
  Element const inv_w = q.w() / norm_sq;

  return make_Quaternion(inv_x, inv_y, inv_z, inv_w);
}
/// Scales the quaternion by 1 / abs(q), producing a quaternion of unit length.
template <typename Element>
CUTLASS_HOST_DEVICE
Quaternion<Element> unit(Quaternion<Element> const &q) {

  Element const inv_length = Element(1) / abs(q);

  Element const ux = q.x() * inv_length;
  Element const uy = q.y() * inv_length;
  Element const uz = q.z() * inv_length;
  Element const uw = q.w() * inv_length;

  return make_Quaternion(ux, uy, uz, uw);
}
/// Quaternion exponential
///
/// For q = w + v (v the imaginary 3-vector):
///   exp(q) = e^w * ( cos|v| + (v / |v|) * sin|v| )
///
/// When |v| evaluates to zero the factor sin|v| / |v| is replaced by its limit, 1, so that a
/// purely real quaternion yields (e^w * x, e^w * y, e^w * z, e^w * cos 0) instead of dividing
/// zero by zero.
template <typename Element>
CUTLASS_HOST_DEVICE
Quaternion<Element> exp(Quaternion<Element> const &q) {

  Element exp_ = fast_exp(q.w());
  Element imag_norm = fast_sqrt(q.x() * q.x() + q.y() * q.y() + q.z() * q.z());
  Element real_part = exp_ * fast_cos(imag_norm);

  // Guard the 0 / 0 case. The imaginary components are still scaled by e^w so that
  // components too small for their squares to register are not discarded.
  if (imag_norm == Element(0)) {
    return make_Quaternion(
      exp_ * q.x(),
      exp_ * q.y(),
      exp_ * q.z(),
      real_part
    );
  }

  Element sin_norm = fast_sin(imag_norm);

  return make_Quaternion(
    exp_ * q.x() * sin_norm / imag_norm,
    exp_ * q.y() * sin_norm / imag_norm,
    exp_ * q.z() * sin_norm / imag_norm,
    real_part
  );
}
/// Quaternion natural logarithm
///
/// For q = w + v (v the imaginary 3-vector):
///   log(q) = ln|q| + (v / |v|) * acos(w / |q|)
///
/// The real part is the logarithm of the magnitude |q|, not of the real component w; only
/// then does exp(log(q)) reproduce q.
///
/// The imaginary scale factor divides by |v|, so a quaternion whose imaginary part has zero
/// length is not handled here and produces non-finite imaginary components.
template <typename Element>
CUTLASS_HOST_DEVICE
Quaternion<Element> log(Quaternion<Element> const &q) {

  Element v = fast_sqrt(q.x() * q.x() + q.y() * q.y() + q.z() * q.z());

  // |q| is needed for both the angle and the real part; compute it once.
  Element mag = abs(q);
  Element s = fast_acos(q.w() / mag) / v;

  return make_Quaternion(
    q.x() * s,
    q.y() * s,
    q.z() * s,
    fast_log(mag)
  );
}
/// Recovers the rotation angle from a unit-length quaternion as 2 * acos(w).
template <typename Element>
CUTLASS_HOST_DEVICE
Element get_rotation_angle(Quaternion<Element> const &q_unit) {
  auto const half_angle = fast_acos(q_unit.w());
  return half_angle * Element(2);
}
/// Recovers the rotation axis from a unit-length quaternion: the imaginary 3-vector passed
/// through Matrix3x1::unit().
template <typename Element>
CUTLASS_HOST_DEVICE
Matrix3x1<Element> get_rotation_axis(Quaternion<Element> const &q_unit) {
  Matrix3x1<Element> imaginary = q_unit.pure();
  return imaginary.unit();
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Component-wise equality: true only when x, y, z and w all compare equal.
/// Components are tested in the order x, y, z, w and testing stops at the first mismatch.
template <typename Element>
CUTLASS_HOST_DEVICE
bool operator==(Quaternion<Element> const &lhs, Quaternion<Element> const &rhs) {
  if (!(lhs.x() == rhs.x())) {
    return false;
  }
  if (!(lhs.y() == rhs.y())) {
    return false;
  }
  if (!(lhs.z() == rhs.z())) {
    return false;
  }
  return lhs.w() == rhs.w();
}
/// Inequality: the logical negation of operator==.
template <typename Element>
CUTLASS_HOST_DEVICE
bool operator!=(Quaternion<Element> const &lhs, Quaternion<Element> const &rhs) {
  bool const same = (lhs == rhs);
  return !same;
}
/// Scales every component of a quaternion by a scalar (scalar on the right).
template <typename Element>
CUTLASS_HOST_DEVICE
Quaternion<Element> operator*(Quaternion<Element> q, Element s) {
  Element const scaled_x = q.x() * s;
  Element const scaled_y = q.y() * s;
  Element const scaled_z = q.z() * s;
  Element const scaled_w = q.w() * s;

  return make_Quaternion(scaled_x, scaled_y, scaled_z, scaled_w);
}
/// Scales every component of a quaternion by a scalar (scalar on the left).
template <typename Element>
CUTLASS_HOST_DEVICE
Quaternion<Element> operator*(Element s, Quaternion<Element> const &q) {
  Element const scaled_x = s * q.x();
  Element const scaled_y = s * q.y();
  Element const scaled_z = s * q.z();
  Element const scaled_w = s * q.w();

  return make_Quaternion(scaled_x, scaled_y, scaled_z, scaled_w);
}
/// Divides every component of a quaternion by a scalar.
template <typename Element>
CUTLASS_HOST_DEVICE
Quaternion<Element> operator/(Quaternion<Element> const &q, Element s) {
  Element const quot_x = q.x() / s;
  Element const quot_y = q.y() / s;
  Element const quot_z = q.z() / s;
  Element const quot_w = q.w() / s;

  return make_Quaternion(quot_x, quot_y, quot_z, quot_w);
}
/// Unary negation: negates all four components.
template <typename Element>
CUTLASS_HOST_DEVICE
Quaternion<Element> operator-(Quaternion<Element> const &q) {
  Element const neg_x = -q.x();
  Element const neg_y = -q.y();
  Element const neg_z = -q.z();
  Element const neg_w = -q.w();

  return make_Quaternion(neg_x, neg_y, neg_z, neg_w);
}
/// Component-wise sum of two quaternions.
template <typename Element>
CUTLASS_HOST_DEVICE
Quaternion<Element> operator+(Quaternion<Element> const &lhs, Quaternion<Element> const &rhs) {
  Element const sum_x = lhs.x() + rhs.x();
  Element const sum_y = lhs.y() + rhs.y();
  Element const sum_z = lhs.z() + rhs.z();
  Element const sum_w = lhs.w() + rhs.w();

  return make_Quaternion(sum_x, sum_y, sum_z, sum_w);
}
/// Component-wise difference of two quaternions (lhs - rhs).
template <typename Element>
CUTLASS_HOST_DEVICE
Quaternion<Element> operator-(Quaternion<Element> const &lhs, Quaternion<Element> const &rhs) {
  Element const diff_x = lhs.x() - rhs.x();
  Element const diff_y = lhs.y() - rhs.y();
  Element const diff_z = lhs.z() - rhs.z();
  Element const diff_w = lhs.w() - rhs.w();

  return make_Quaternion(diff_x, diff_y, diff_z, diff_w);
}
/// Quaternion (Hamilton) product lhs * rhs. The product is not commutative.
///
/// With q = (w, v): imaginary part = w_l * v_r + w_r * v_l + v_l x v_r,
///                  real part      = w_l * w_r - v_l . v_r
template <typename Element>
CUTLASS_HOST_DEVICE
Quaternion<Element> operator*(Quaternion<Element> const &lhs, Quaternion<Element> const &rhs) {

  // Each component is kept as one expression so the order of operations is unchanged.
  Element const x =
    lhs.w() * rhs.x() + rhs.w() * lhs.x() + lhs.y() * rhs.z() - lhs.z() * rhs.y();

  Element const y =
    lhs.w() * rhs.y() + rhs.w() * lhs.y() + lhs.z() * rhs.x() - lhs.x() * rhs.z();

  Element const z =
    lhs.w() * rhs.z() + rhs.w() * lhs.z() + lhs.x() * rhs.y() - lhs.y() * rhs.x();

  Element const w =
    lhs.w() * rhs.w() - lhs.x() * rhs.x() - lhs.y() * rhs.y() - lhs.z() * rhs.z();

  return make_Quaternion(x, y, z, w);
}
/// Quaternion division: lhs multiplied on the right by the reciprocal of rhs.
template <typename Element>
CUTLASS_HOST_DEVICE
Quaternion<Element> operator/(Quaternion<Element> const &lhs, Quaternion<Element> const &rhs) {
  Quaternion<Element> const rhs_inverse = reciprocal(rhs);
  return lhs * rhs_inverse;
}
/// Divides a scalar by a quaternion: the scalar times the reciprocal of q.
template <typename Element>
CUTLASS_HOST_DEVICE
Quaternion<Element> operator/(Element s, Quaternion<Element> const &q) {
  Quaternion<Element> const q_inverse = reciprocal(q);
  return s * q_inverse;
}
/// Comparison. Always returns true, whatever the operands are; it does not define an
/// ordering of quaternions.
/// NOTE(review): presumably present only so that generic code requiring operator< compiles -
/// confirm before relying on it for sorting or min/max.
template <typename Element>
CUTLASS_HOST_DEVICE
bool operator<(Quaternion<Element> const &lhs, Quaternion<Element> const &rhs) {
  // Both operands are deliberately unused.
  static_cast<void>(lhs);
  static_cast<void>(rhs);
  return true;
}
/// Rotates a 3-vector by a unit quaternion (a spinor): the imaginary part of
/// spinor * (rhs, 0) * conj(spinor). Using the conjugate avoids computing a reciprocal.
template <typename Element>
CUTLASS_HOST_DEVICE
Matrix3x1<Element> spinor_rotation(
  Quaternion<Element> const &spinor,      /// unit-length quaternion
  Matrix3x1<Element> const &rhs) {        /// arbitrary 3-vector

  // Embed the vector as a quaternion with zero real part.
  Quaternion<Element> const embedded(rhs, 0);
  Quaternion<Element> const rotated = spinor * embedded * conj(spinor);

  return rotated.pure();
}
/// Applies the inverse rotation of a unit quaternion (a spinor) to a 3-vector: the imaginary
/// part of conj(spinor) * (rhs, 0) * spinor. Using the conjugate avoids computing a reciprocal.
template <typename Element>
CUTLASS_HOST_DEVICE
Matrix3x1<Element> spinor_rotation_inv(
  Quaternion<Element> const &spinor,      /// unit-length quaternion
  Matrix3x1<Element> const &rhs) {        /// arbitrary 3-vector

  // Embed the vector as a quaternion with zero real part.
  Quaternion<Element> const embedded(rhs, 0);
  Quaternion<Element> const rotated = conj(spinor) * embedded * spinor;

  return rotated.pure();
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// RealType partial specialization for quaternion-valued types: exposes the scalar element
/// type and builds a quaternion from a real number.
template <typename T>
struct RealType< Quaternion<T> > {

  /// Underlying scalar type
  using Type = T;

  /// Number of elements
  static int const kExtent = Quaternion<T>::kExtent;

  /// Converts x to T and returns the quaternion having that value as its real part and
  /// zero imaginary components.
  CUTLASS_HOST_DEVICE
  static Quaternion<T> from_real(double x) {
    T const real_part = static_cast<T>(x);
    return Quaternion<T>(real_part);
  }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Factories
////////////////////////////////////////////////////////////////////////////////////////////////////
/// from_real specialization: half-precision quaternion with real part r (converted to half_t)
/// and zero imaginary components.
template <>
CUTLASS_HOST_DEVICE
cutlass::Quaternion<half_t> from_real<cutlass::Quaternion<half_t> >(double r) {
  half_t const real_part = static_cast<half_t>(r);
  return cutlass::Quaternion<half_t>(real_part);
}
/// from_real specialization: single-precision quaternion with real part r (converted to
/// float) and zero imaginary components.
template <>
CUTLASS_HOST_DEVICE
cutlass::Quaternion<float> from_real<cutlass::Quaternion<float> >(double r) {
  float const real_part = static_cast<float>(r);
  return cutlass::Quaternion<float>(real_part);
}
/// from_real specialization: double-precision quaternion with real part r and zero
/// imaginary components.
template <>
CUTLASS_HOST_DEVICE
cutlass::Quaternion<double> from_real<cutlass::Quaternion<double> >(double r) {
  cutlass::Quaternion<double> result(r);
  return result;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////
// functional.h numeric specializations
/////////////////////////////////////////////////////////////////////////////////////////////////
/// multiplies functor specialized for quaternions: returns the quaternion product lhs * rhs.
template <typename T>
struct multiplies<Quaternion<T>> {
  CUTLASS_HOST_DEVICE
  Quaternion<T> operator()(Quaternion<T> lhs, Quaternion<T> const &rhs) const {
    return lhs * rhs;
  }
};
/// Squared magnitude of a quaternion with optional conversion: every component is converted
/// to Output first, then squared with multiplies<Output> and summed.
template <typename T, typename Output>
struct magnitude_squared<Quaternion<T>, Output> {
  CUTLASS_HOST_DEVICE
  Output operator()(Quaternion<T> lhs) const {
    multiplies<Output> square;

    Output const w = Output(lhs.w());
    Output const x = Output(lhs.x());
    Output const y = Output(lhs.y());
    Output const z = Output(lhs.z());

    // Summed in the order w, x, y, z.
    return square(w, w) + square(x, x) + square(y, y) + square(z, z);
  }
};
/// multiply_add functor specialized for quaternions: returns a * b + c, where a * b is the
/// quaternion product and the result is accumulated term by term onto the components of c.
template <typename T>
struct multiply_add<Quaternion<T>, Quaternion<T>, Quaternion<T>> {
  CUTLASS_HOST_DEVICE
  Quaternion<T> operator()(
    Quaternion<T> const &a,
    Quaternion<T> const &b,
    Quaternion<T> const &c) const {

    // Accumulators start from the addend c.
    T x = c.x();
    T y = c.y();
    T z = c.z();
    T w = c.w();

    // x: a.w * b.x + b.w * a.x + (a.y * b.z - a.z * b.y)
    x += a.w() * b.x();
    x += b.w() * a.x();
    x += a.y() * b.z();
    x += -a.z() * b.y();

    // y: a.w * b.y + b.w * a.y + (a.z * b.x - a.x * b.z)
    y += a.w() * b.y();
    y += b.w() * a.y();
    y += a.z() * b.x();
    y += -a.x() * b.z();

    // z: a.w * b.z + b.w * a.z + (a.x * b.y - a.y * b.x)
    z += a.w() * b.z();
    z += b.w() * a.z();
    z += a.x() * b.y();
    z += -a.y() * b.x();

    // w: a.w * b.w - (a.x * b.x + a.y * b.y + a.z * b.z)
    w += a.w() * b.w();
    w += -a.x() * b.x();
    w += -a.y() * b.y();
    w += -a.z() * b.z();

    return cutlass::make_Quaternion(x, y, z, w);
  }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////