Remove old code in utils.h (#511)
This commit is contained in:
parent
866a9d33f9
commit
dd8a754915
@ -87,46 +87,6 @@ inline __device__ uint32_t convert_relu2<cutlass::bfloat16_t>(const float2 x) {
|
|||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
template<typename T>
|
|
||||||
inline __device__ float2 half2_unpack(uint32_t a);
|
|
||||||
|
|
||||||
template <>
|
|
||||||
inline __device__ float2 half2_unpack<__half>(uint32_t a) {
|
|
||||||
return __half22float2(reinterpret_cast<__half2 (&)>(a));
|
|
||||||
}
|
|
||||||
|
|
||||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
|
||||||
template <>
|
|
||||||
inline __device__ float2 half2_unpack<__nv_bfloat16>(uint32_t a) {
|
|
||||||
return __bfloat1622float2(reinterpret_cast<__nv_bfloat162 (&)>(a));
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
// Convert two half2's or bf162's into float, then take their dot product.
|
|
||||||
template <typename T>
|
|
||||||
inline __device__ float hfma2_to_float(const uint32_t a, const uint32_t b) {
|
|
||||||
float2 af = flash::half2_unpack<T>(a);
|
|
||||||
float2 bf = flash::half2_unpack<T>(b);
|
|
||||||
return af.x * bf.x + af.y * bf.y;
|
|
||||||
}
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
// Converted two vectors of 8 half's or bf16's into float, then take their dot product.
|
|
||||||
template<typename T>
|
|
||||||
inline __device__ float hmulsum8(const uint4 a, const uint4 b) {
|
|
||||||
float sum;
|
|
||||||
sum = flash::hfma2_to_float<T>(a.x, b.x);
|
|
||||||
sum += flash::hfma2_to_float<T>(a.y, b.y);
|
|
||||||
sum += flash::hfma2_to_float<T>(a.z, b.z);
|
|
||||||
sum += flash::hfma2_to_float<T>(a.w, b.w);
|
|
||||||
return sum;
|
|
||||||
}
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
struct MaxOp {
|
struct MaxOp {
|
||||||
__device__ inline T operator()(T const & x, T const & y) { return x > y ? x : y; }
|
__device__ inline T operator()(T const & x, T const & y) { return x > y ? x : y; }
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user