Remove old code in utils.h (#511)
This commit is contained in:
parent
866a9d33f9
commit
dd8a754915
@ -87,46 +87,6 @@ inline __device__ uint32_t convert_relu2<cutlass::bfloat16_t>(const float2 x) {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename T>
|
||||
inline __device__ float2 half2_unpack(uint32_t a);
|
||||
|
||||
template <>
|
||||
inline __device__ float2 half2_unpack<__half>(uint32_t a) {
|
||||
return __half22float2(reinterpret_cast<__half2 (&)>(a));
|
||||
}
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
template <>
|
||||
inline __device__ float2 half2_unpack<__nv_bfloat16>(uint32_t a) {
|
||||
return __bfloat1622float2(reinterpret_cast<__nv_bfloat162 (&)>(a));
|
||||
}
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Convert two half2's or bf162's into float, then take their dot product.
|
||||
template <typename T>
|
||||
inline __device__ float hfma2_to_float(const uint32_t a, const uint32_t b) {
|
||||
float2 af = flash::half2_unpack<T>(a);
|
||||
float2 bf = flash::half2_unpack<T>(b);
|
||||
return af.x * bf.x + af.y * bf.y;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Converted two vectors of 8 half's or bf16's into float, then take their dot product.
|
||||
template<typename T>
|
||||
inline __device__ float hmulsum8(const uint4 a, const uint4 b) {
|
||||
float sum;
|
||||
sum = flash::hfma2_to_float<T>(a.x, b.x);
|
||||
sum += flash::hfma2_to_float<T>(a.y, b.y);
|
||||
sum += flash::hfma2_to_float<T>(a.z, b.z);
|
||||
sum += flash::hfma2_to_float<T>(a.w, b.w);
|
||||
return sum;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename T>
|
||||
struct MaxOp {
|
||||
__device__ inline T operator()(T const & x, T const & y) { return x > y ? x : y; }
|
||||
|
||||
Loading…
Reference in New Issue
Block a user