Remove old code in utils.h (#511)

This commit is contained in:
Sophia Wisdom 2023-09-01 15:32:09 -07:00 committed by GitHub
parent 866a9d33f9
commit dd8a754915
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -87,46 +87,6 @@ inline __device__ uint32_t convert_relu2<cutlass::bfloat16_t>(const float2 x) {
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
inline __device__ float2 half2_unpack(uint32_t a);
template <>
inline __device__ float2 half2_unpack<__half>(uint32_t a) {
return __half22float2(reinterpret_cast<__half2 (&)>(a));
}
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template <>
inline __device__ float2 half2_unpack<__nv_bfloat16>(uint32_t a) {
return __bfloat1622float2(reinterpret_cast<__nv_bfloat162 (&)>(a));
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
// Convert two half2's or bf162's into float, then take their dot product.
template <typename T>
inline __device__ float hfma2_to_float(const uint32_t a, const uint32_t b) {
float2 af = flash::half2_unpack<T>(a);
float2 bf = flash::half2_unpack<T>(b);
return af.x * bf.x + af.y * bf.y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Converted two vectors of 8 half's or bf16's into float, then take their dot product.
template<typename T>
inline __device__ float hmulsum8(const uint4 a, const uint4 b) {
float sum;
sum = flash::hfma2_to_float<T>(a.x, b.x);
sum += flash::hfma2_to_float<T>(a.y, b.y);
sum += flash::hfma2_to_float<T>(a.z, b.z);
sum += flash::hfma2_to_float<T>(a.w, b.w);
return sum;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T> template<typename T>
struct MaxOp { struct MaxOp {
__device__ inline T operator()(T const & x, T const & y) { return x > y ? x : y; } __device__ inline T operator()(T const & x, T const & y) { return x > y ? x : y; }