Softmax (#546)
* add test layernorm g-mem version
* Delete include/configure directory
* Delete examples/test_layernorm directory
* Update gemm_with_softmax.h
* Update gemm_softmax.cu
* Update linear_combination.h
* Update fast_math.h
* remove redundant vars

Co-authored-by: yujia.zhai <yujia.zhai@bytedance.com>
Co-authored-by: yuzhai <yuzhai@nvidia.com>
commit 04a9777b87
parent e45e773436
gemm_softmax.cu:

@@ -55,6 +55,7 @@
 #include "cutlass/util/reference/host/error_metrics.h"
 #include "cutlass/util/tensor_view_io.h"
 
+#include "cutlass/epilogue/thread/linear_combination.h"
 /////////////////////////////////////////////////////////////////////////////////////////////////
 
 #include "gemm_with_softmax.h"
@@ -204,14 +205,24 @@ struct Testbed {
   using LayoutA = cutlass::layout::RowMajor;
   using LayoutB = cutlass::layout::ColumnMajor;
 
+  /// Linear scaling operator
+  using EpilogueFunctorOp = cutlass::epilogue::thread::LinearCombination<
+    ElementC,
+    128 / cutlass::sizeof_bits<ElementC>::value,
+    ElementCompute,
+    ElementCompute
+  >;
+
   using GemmSoftmax = cutlass::GemmSoftmax<
     ElementA, LayoutA,
     ElementB, LayoutB,
     ElementC,
-    ElementCompute
+    ElementCompute,
+    EpilogueFunctorOp
   >;
 
-  using ElementN = typename GemmSoftmax::ElementN;
+  using ElementNorm = typename GemmSoftmax::ElementNorm;
+  using ElementSum = typename GemmSoftmax::ElementSum;
   using LayoutC = typename GemmSoftmax::LayoutC;
 
   //
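For reference, cutlass::epilogue::thread::LinearCombination is the standard CUTLASS epilogue functor computing D = alpha * accumulator + beta * C; this hunk makes it an explicit template argument of GemmSoftmax rather than an internal default. A minimal instantiation sketch, assuming illustrative element types (half_t output with float compute, which are not taken from this diff):

    #include "cutlass/epilogue/thread/linear_combination.h"
    #include "cutlass/numeric_types.h"

    // Illustrative element types; the example's actual choices are made
    // elsewhere in gemm_softmax.cu.
    using ElementC       = cutlass::half_t;
    using ElementCompute = float;

    // 128-bit vectorized access: 128 / 16 bits per half_t = 8 elements.
    using EpilogueFunctorOp = cutlass::epilogue::thread::LinearCombination<
        ElementC,
        128 / cutlass::sizeof_bits<ElementC>::value,
        ElementCompute,
        ElementCompute>;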
@@ -224,13 +235,16 @@ struct Testbed {
   cutlass::HostTensor<ElementB, LayoutB> tensor_B;
   cutlass::HostTensor<ElementC, LayoutC> tensor_C;
   cutlass::HostTensor<ElementD, LayoutC> tensor_D;
-  cutlass::HostTensor<ElementN, LayoutC> tensor_N;
+  cutlass::HostTensor<ElementNorm, LayoutC> tensor_N;
+  cutlass::HostTensor<ElementSum, LayoutC> tensor_S;
   cutlass::HostTensor<ElementSoftmax, LayoutC> tensor_Softmax;
 
   cutlass::HostTensor<ElementD, LayoutC> reference_D;
-  cutlass::HostTensor<ElementN, LayoutC> reference_N;
+  cutlass::HostTensor<ElementNorm, LayoutC> reference_N;
   cutlass::HostTensor<ElementSoftmax, LayoutC> reference_Softmax;
 
+  int block_num = (options.problem_size.n() + GemmSoftmax::ThreadblockShape::kN - 1) / GemmSoftmax::ThreadblockShape::kN;
+
   //
   // Methods
   //
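Note on the added block_num: the kernel reduces the row maximum and exponential sum once per threadblock-wide column tile, so every row of D yields block_num partial values, and tensor_N / tensor_S are resized to {block_num, M} in the next hunk. A small arithmetic sketch with illustrative numbers (the real tile width comes from GemmSoftmax::ThreadblockShape::kN):

    // Ceiling division over the N dimension. With an illustrative tile
    // width kN = 128 and problem-size N = 1000:
    //   block_num = (1000 + 128 - 1) / 128 = 1127 / 128 = 8
    // i.e. eight column tiles (the last one partial) cover each row, so the
    // partial-reduction tensors hold eight entries per row of D.
    int const kN = 128;   // stand-in for GemmSoftmax::ThreadblockShape::kN
    int const n  = 1000;  // illustrative problem-size N
    int block_num = (n + kN - 1) / kN;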
@@ -247,7 +261,8 @@ struct Testbed {
     tensor_C.reset({options.problem_size.m(), options.problem_size.n()});
     tensor_D.reset({options.problem_size.m(), options.problem_size.n()});
 
-    tensor_N.reset({options.problem_size.m(), 1});
+    tensor_N.reset({block_num, options.problem_size.m()});
+    tensor_S.reset({block_num, options.problem_size.m()});
     tensor_Softmax.reset({options.problem_size.m(), options.problem_size.n()});
 
     reference_D.reset({options.problem_size.m(), options.problem_size.n()}, false);
@@ -342,7 +357,7 @@ struct Testbed {
 
     cutlass::reference::host::TensorFill(
       reference_N.host_view(),
-      ElementN()
+      ElementNorm()
     );
 
     cutlass::reference::host::TensorFill(
@@ -354,6 +369,7 @@ struct Testbed {
     tensor_B.sync_device();
     tensor_D.sync_device();
     tensor_N.sync_device();
+    tensor_S.sync_device();
     tensor_Softmax.sync_device();
   }
 
@@ -377,6 +393,7 @@ struct Testbed {
         ElementCompute(options.beta)
       },
       tensor_N.device_ref(),
+      tensor_S.device_ref(),
       tensor_Softmax.device_ref()
     );
 
@@ -420,7 +437,7 @@ struct Testbed {
     for (int m = 0; m < options.problem_size.m(); ++m) {
       reference_N.at({m, 0}) = reference_D.at({m, 0});
       for (int n = 1; n < options.problem_size.n(); ++n) {
-        reference_N.at({m, 0}) = std::max(reference_N.at({m, 0}), ElementN(reference_D.at({m, n})));
+        reference_N.at({m, 0}) = std::max(reference_N.at({m, 0}), ElementNorm(reference_D.at({m, n})));
       }
     }
 
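The loop above computes the per-row maximum that anchors the numerically stable softmax reference, softmax(x)_j = exp(x_j - max_i x_i) / sum_k exp(x_k - max_i x_i). A self-contained host sketch of the full reference computation, using plain float storage and illustrative names instead of the HostTensor types used in the testbed:

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // D is an M x N row-major matrix; S receives its row-wise softmax.
    void reference_softmax(std::vector<float> const &D, std::vector<float> &S,
                           int M, int N) {
      for (int m = 0; m < M; ++m) {
        // 1) row maximum, as in the loop above
        float row_max = D[m * N];
        for (int n = 1; n < N; ++n) {
          row_max = std::max(row_max, D[m * N + n]);
        }
        // 2) sum of shifted exponentials
        float row_sum = 0.f;
        for (int n = 0; n < N; ++n) {
          row_sum += std::exp(D[m * N + n] - row_max);
        }
        // 3) normalize
        for (int n = 0; n < N; ++n) {
          S[m * N + n] = std::exp(D[m * N + n] - row_max) / row_sum;
        }
      }
    }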
@@ -454,6 +471,20 @@ struct Testbed {
       std::cout << "Reference Softmax = \n" << reference_Softmax.host_view() << "\n\n";
     }
 
+  bool verify_tensor_N(cutlass::HostTensor<ElementNorm, LayoutC> tensor_N, \
+                       cutlass::HostTensor<ElementNorm, LayoutC> reference_N) {
+
+    for (int m = 0; m < options.problem_size.m(); ++m) {
+      float diff = (float)(tensor_N.at({0, m}) - reference_N.at({m, 0}));
+      if (fabs(diff) > options.tolerance) {
+        return false;
+      }
+
+    }
+
+    return true;
+  }
+
   /// Verifies the reference matches
   bool verify() {
 
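Note the index order inside verify_tensor_N: tensor_N was reset to {block_num, M}, so the fully reduced row maximum written by the kernel is read at {0, m}, while reference_N keeps its {M, 1} layout and is read at {m, 0}. The acceptance test itself is a plain absolute-tolerance comparison, roughly this scalar predicate (names are illustrative):

    #include <cmath>

    // Pass iff |kernel_value - reference_value| <= tolerance, mirroring
    // the fabs(diff) > options.tolerance early-out above.
    bool within_tolerance(float kernel_value, float reference_value,
                          float tolerance) {
      return std::fabs(kernel_value - reference_value) <= tolerance;
    }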
@@ -489,22 +520,7 @@ struct Testbed {
     }
 
     if (!verified_N) {
-
-      double norm_diff = cutlass::reference::host::TensorNormDiff(
-        tensor_N.host_view(),
-        reference_N.host_view());
-
-      double norm_reference = cutlass::reference::host::TensorNorm(
-        reference_N.host_view());
-
-      double rel_error = norm_diff / norm_reference;
-
-      if (rel_error > kThreshold) {
-        std::cerr << "\n\nTensor N Relative error: " << rel_error << std::endl;
-      }
-      else {
-        verified_N = true;
-      }
+      verified_N = verify_tensor_N(tensor_N, reference_N);
     }
 
     if (!verified_Softmax) {
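This hunk replaces the norm-based acceptance test, which compared the relative L2 error ||t - r|| / ||r|| of the whole tensor against kThreshold, with the element-wise walk in verify_tensor_N added above. A sketch of the element-wise criterion with illustrative names; its advantage is that a single bad entry fails the check instead of being averaged away:

    #include <cmath>

    // Element-wise criterion: every entry must match within tolerance.
    bool verify_elementwise(float const *t, float const *r, int count,
                            float tolerance) {
      for (int i = 0; i < count; ++i) {
        if (std::fabs(t[i] - r[i]) > tolerance) {
          return false;  // localizes the failure to entry i
        }
      }
      return true;
    }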
gemm_with_softmax.h: (diff suppressed because it is too large)
linear_combination.h:

@@ -71,7 +71,7 @@ public:
   using ElementCompute = ElementCompute_;
 
   static int const kCount = Count;
-
+  static const ScaleType::Kind kScale = Scale;
   using FragmentOutput = Array<ElementOutput, kCount>;
   using FragmentAccumulator = Array<ElementAccumulator, kCount>;
   using ComputeFragment = Array<ElementCompute, kCount>;
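Exposing kScale lets code that receives the functor only as a template parameter inspect its scaling mode at compile time. A hedged sketch, assuming ScaleType lives in cutlass::epilogue::thread alongside LinearCombination and that Default is the default Scale argument:

    #include "cutlass/epilogue/thread/linear_combination.h"
    #include "cutlass/numeric_types.h"

    // Illustrative instantiation: half_t output, 8-wide access, float compute.
    using Epilogue = cutlass::epilogue::thread::LinearCombination<
        cutlass::half_t, 8, float, float>;

    // Compile-time dispatch on the newly exposed member.
    static_assert(Epilogue::kScale ==
                      cutlass::epilogue::thread::ScaleType::Default,
                  "expected the default alpha/beta scaling mode");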
fast_math.h:

@@ -718,11 +718,11 @@ double fast_exp(double x) {
 }
 
 CUTLASS_HOST_DEVICE
-float fast_exp(half_t x) {
+half_t fast_exp(half_t x) {
 #if defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 10) && (__CUDA_ARCH__ >= 750)
-  return ::hexp(x.to_half());
+  return (half_t)(::hexp(x.to_half()));
 #else
-  return fast_exp(float(x));
+  return (half_t)(fast_exp(float(x)));
 #endif
 }
 
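The fast_exp(half_t) overload previously declared a float return type, so the SM75 branch implicitly widened the ::hexp result and callers could not stay in half precision end to end; returning half_t closes the function over its argument type. A hedged usage sketch (the function name and softmax framing are illustrative, not from this diff):

    #include "cutlass/fast_math.h"
    #include "cutlass/half.h"

    // Stays in half_t throughout: on SM75+ device code this lowers to the
    // ::hexp intrinsic; elsewhere it round-trips through the float fallback.
    CUTLASS_HOST_DEVICE
    cutlass::half_t softmax_term(cutlass::half_t x, cutlass::half_t row_max) {
      return cutlass::fast_exp(x - row_max);
    }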
@@ -908,4 +908,3 @@ T absolute_value(T x) {
 } // namespace cutlass
 
 /////////////////////////////////////////////////////////////////////////////////////////////////
-