2022-05-21 05:21:58 +08:00
|
|
|
/******************************************************************************
|
|
|
|
|
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
|
|
|
|
|
*
|
|
|
|
|
* Redistribution and use in source and binary forms, with or without
|
|
|
|
|
* modification, are permitted provided that the following conditions are met:
|
|
|
|
|
* * Redistributions of source code must retain the above copyright
|
|
|
|
|
* notice, this list of conditions and the following disclaimer.
|
|
|
|
|
* * Redistributions in binary form must reproduce the above copyright
|
|
|
|
|
* notice, this list of conditions and the following disclaimer in the
|
|
|
|
|
* documentation and/or other materials provided with the distribution.
|
|
|
|
|
* * Neither the name of the NVIDIA CORPORATION nor the
|
|
|
|
|
* names of its contributors may be used to endorse or promote products
|
|
|
|
|
* derived from this software without specific prior written permission.
|
|
|
|
|
*
|
|
|
|
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
|
|
|
|
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
|
|
|
|
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
|
|
|
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
|
|
|
|
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
|
|
|
|
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
|
|
|
|
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
|
|
|
|
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
|
|
|
|
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
|
|
|
|
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
|
|
|
*
|
|
|
|
|
******************************************************************************/
|
|
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
|
|
|
|
|
|
#include <assert.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <cuda_runtime_api.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
|
2022-05-21 05:21:58 +08:00
|
|
|
|
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
|
|
// FMHA_CHECK_CUDA(call)
//   Evaluates `call` (expected to yield a cudaError_t) exactly once. On any
//   result other than cudaSuccess it prints "CUDA error (<file>:<line>): <msg>"
//   to stderr, using cudaGetErrorString() for <msg>, and terminates the process
//   with exit(1). The do/while(0) wrapper lets the macro be used like a single
//   statement, e.g. inside an un-braced if/else.
#define FMHA_CHECK_CUDA( call )                                                                    \
    do {                                                                                           \
        const cudaError_t fmha_status_ = ( call );                                                 \
        if( fmha_status_ != cudaSuccess ) {                                                        \
            fprintf( stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__,                       \
                     cudaGetErrorString( fmha_status_ ) );                                         \
            exit( 1 );                                                                             \
        }                                                                                          \
    } while( 0 )
|
|
|
|
|
|
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
2022-07-10 09:39:02 +08:00
|
|
|
// Element data types handled by the FMHA host helpers in this header.
// The enumerators use implicit values in declaration order
// (DATA_TYPE_FP16 == 0 ... DATA_TYPE_INT8 == 4), so never reorder them.
enum Data_type {
    DATA_TYPE_FP16,
    DATA_TYPE_BF16,
    DATA_TYPE_FP32,
    DATA_TYPE_INT32,
    DATA_TYPE_INT8
};
|
2022-05-21 05:21:58 +08:00
|
|
|
|
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
|
|
// Encodes the scalar `norm` into the 32-bit word `alpha` using the bit layout for `dtype`:
//   DATA_TYPE_FP16  : norm rounded to half (round-to-nearest), and the 16-bit pattern
//                     replicated into both the low and high halves of `alpha`.
//   DATA_TYPE_BF16  : norm converted with __float2bfloat16, and the 16-bit pattern
//                     replicated into both halves.
//   DATA_TYPE_FP32  : the raw IEEE-754 bit pattern of `norm`.
//   DATA_TYPE_INT32 : norm truncated toward zero to int32_t, with the resulting bits stored as-is.
//                     NOTE(review): the float->int conversion is undefined if `norm` is NaN or
//                     outside the int32_t range; callers presumably pass small values — confirm.
// Any other dtype (e.g. DATA_TYPE_INT8) is unsupported. It trips assert() in debug builds.
// With NDEBUG defined, it leaves `alpha` unmodified.
//
// All bit reinterpretation goes through memcpy. The previous reinterpret_cast-of-reference
// punning (half -> uint16_t, ushort2 -> uint32_t, float -> uint32_t) violated strict
// aliasing, which is undefined behavior.
static inline void set_alpha( uint32_t &alpha, float norm, Data_type dtype ) {
    if( dtype == DATA_TYPE_FP16 ) {
        half x = __float2half_rn( norm );
        static_assert( sizeof( x ) == sizeof( uint16_t ), "half must be 16 bits" );
        uint16_t h;
        memcpy( &h, &x, sizeof( h ) );
        // Both halves hold the same value, so the packing is endianness-independent.
        alpha = ( static_cast<uint32_t>( h ) << 16 ) | static_cast<uint32_t>( h );
    } else if( dtype == DATA_TYPE_BF16 ) {
        __nv_bfloat16 x = __float2bfloat16( norm );
        static_assert( sizeof( x ) == sizeof( uint16_t ), "bfloat16 must be 16 bits" );
        uint16_t h;
        memcpy( &h, &x, sizeof( h ) );
        alpha = ( static_cast<uint32_t>( h ) << 16 ) | static_cast<uint32_t>( h );
    } else if( dtype == DATA_TYPE_FP32 ) {
        static_assert( sizeof( norm ) == sizeof( alpha ), "float must be 32 bits" );
        memcpy( &alpha, &norm, sizeof( alpha ) );
    } else if( dtype == DATA_TYPE_INT32 ) {
        int32_t inorm = static_cast<int32_t>( norm );
        // Signed -> unsigned conversion preserves the two's-complement bit pattern.
        alpha = static_cast<uint32_t>( inorm );
    } else {
        assert( false );
    }
}
|
|
|
|
|
|
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
|
|
// Returns the storage size, in bytes, of `n` elements of type `dtype`:
// 4 bytes per element for FP32/INT32, 2 for FP16/BF16, and 1 for INT8.
// An unrecognized dtype trips assert() in debug builds and yields 0 otherwise.
// The multiplication is not overflow-checked.
static inline size_t get_size_in_bytes( size_t n, Data_type dtype ) {
    size_t bytes_per_element = 0;
    switch( dtype ) {
    case DATA_TYPE_FP32:
    case DATA_TYPE_INT32:
        bytes_per_element = 4;
        break;
    case DATA_TYPE_FP16:
    case DATA_TYPE_BF16:
        bytes_per_element = 2;
        break;
    case DATA_TYPE_INT8:
        bytes_per_element = 1;
        break;
    default:
        assert( false );
        return 0;
    }
    return n * bytes_per_element;
}
|
|
|
|
|
|
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|