Softmax (#546)
* add test layernorm g-mem version
* Delete include/configure directory
* Delete examples/test_layernorm directory
* Update gemm_with_softmax.h
* Update gemm_softmax.cu
* Update linear_combination.h
* Update fast_math.h
* remove redundant vars

Co-authored-by: yujia.zhai <yujia.zhai@bytedance.com>
Co-authored-by: yuzhai <yuzhai@nvidia.com>
parent e45e773436
commit 04a9777b87
@@ -55,6 +55,7 @@
 #include "cutlass/util/reference/host/error_metrics.h"
 #include "cutlass/util/tensor_view_io.h"
 
+#include "cutlass/epilogue/thread/linear_combination.h"
 /////////////////////////////////////////////////////////////////////////////////////////////////
 
 #include "gemm_with_softmax.h"
@@ -204,14 +205,24 @@ struct Testbed {
   using LayoutA = cutlass::layout::RowMajor;
   using LayoutB = cutlass::layout::ColumnMajor;
 
+  /// Linear scaling operator
+  using EpilogueFunctorOp = cutlass::epilogue::thread::LinearCombination<
+    ElementC,
+    128 / cutlass::sizeof_bits<ElementC>::value,
+    ElementCompute,
+    ElementCompute
+  >;
+
   using GemmSoftmax = cutlass::GemmSoftmax<
     ElementA, LayoutA,
     ElementB, LayoutB,
     ElementC,
-    ElementCompute
+    ElementCompute,
+    EpilogueFunctorOp
   >;
 
-  using ElementN = typename GemmSoftmax::ElementN;
+  using ElementNorm = typename GemmSoftmax::ElementNorm;
+  using ElementSum = typename GemmSoftmax::ElementSum;
   using LayoutC = typename GemmSoftmax::LayoutC;
 
   //
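To see the new wiring in one place, here is a minimal sketch (not part of the diff) of how the example's type definitions fit together after this change. The concrete element types (half_t operands, float accumulation) are assumptions chosen for illustration, not taken from the hunk above; the layouts and template-argument order mirror the diff.

#include "cutlass/numeric_types.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "gemm_with_softmax.h"

// Assumed element types, for illustration only.
using ElementA       = cutlass::half_t;
using ElementB       = cutlass::half_t;
using ElementC       = cutlass::half_t;
using ElementCompute = float;

// Epilogue functor made explicit by this commit: alpha * accumulator + beta * C,
// vectorized to one 128-bit access per thread.
using EpilogueFunctorOp = cutlass::epilogue::thread::LinearCombination<
  ElementC,
  128 / cutlass::sizeof_bits<ElementC>::value,
  ElementCompute,
  ElementCompute
>;

// GemmSoftmax now takes the epilogue functor as an additional template argument.
using GemmSoftmax = cutlass::GemmSoftmax<
  ElementA, cutlass::layout::RowMajor,
  ElementB, cutlass::layout::ColumnMajor,
  ElementC,
  ElementCompute,
  EpilogueFunctorOp
>;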
@@ -224,13 +235,16 @@ struct Testbed {
   cutlass::HostTensor<ElementB, LayoutB> tensor_B;
   cutlass::HostTensor<ElementC, LayoutC> tensor_C;
   cutlass::HostTensor<ElementD, LayoutC> tensor_D;
-  cutlass::HostTensor<ElementN, LayoutC> tensor_N;
+  cutlass::HostTensor<ElementNorm, LayoutC> tensor_N;
+  cutlass::HostTensor<ElementSum, LayoutC> tensor_S;
   cutlass::HostTensor<ElementSoftmax, LayoutC> tensor_Softmax;
 
   cutlass::HostTensor<ElementD, LayoutC> reference_D;
-  cutlass::HostTensor<ElementN, LayoutC> reference_N;
+  cutlass::HostTensor<ElementNorm, LayoutC> reference_N;
   cutlass::HostTensor<ElementSoftmax, LayoutC> reference_Softmax;
 
+  int block_num = (options.problem_size.n() + GemmSoftmax::ThreadblockShape::kN - 1) / GemmSoftmax::ThreadblockShape::kN;
+
   //
   // Methods
   //
@@ -247,7 +261,8 @@ struct Testbed {
     tensor_C.reset({options.problem_size.m(), options.problem_size.n()});
     tensor_D.reset({options.problem_size.m(), options.problem_size.n()});
 
-    tensor_N.reset({options.problem_size.m(), 1});
+    tensor_N.reset({block_num, options.problem_size.m()});
+    tensor_S.reset({block_num, options.problem_size.m()});
     tensor_Softmax.reset({options.problem_size.m(), options.problem_size.n()});
 
     reference_D.reset({options.problem_size.m(), options.problem_size.n()}, false);
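The reshaped buffers follow from the blocked reduction: the {block_num, m} sizing suggests one partial maximum (tensor_N) and one partial sum (tensor_S) per threadblock column of N for each row, rather than a single value per row. A rough worked example of the sizing, with an assumed problem size and threadblock width, is sketched below; the real values come from Options::problem_size and GemmSoftmax::ThreadblockShape.

#include <iostream>

int main() {
  // Assumed values for illustration only.
  int n = 1024;             // columns of the GEMM output
  int m = 96;               // rows of the GEMM output
  int threadblock_n = 128;  // stand-in for GemmSoftmax::ThreadblockShape::kN

  // Same ceiling division as the block_num member added above.
  int block_num = (n + threadblock_n - 1) / threadblock_n;  // 8 partials per row here

  // tensor_N / tensor_S are now {block_num, m}: one partial per (block, row) pair.
  std::cout << "partials per row: " << block_num
            << ", total partial entries: " << block_num * m << "\n";
  return 0;
}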
@@ -342,7 +357,7 @@ struct Testbed {
 
     cutlass::reference::host::TensorFill(
       reference_N.host_view(),
-      ElementN()
+      ElementNorm()
     );
 
     cutlass::reference::host::TensorFill(
@@ -354,6 +369,7 @@ struct Testbed {
     tensor_B.sync_device();
     tensor_D.sync_device();
     tensor_N.sync_device();
+    tensor_S.sync_device();
     tensor_Softmax.sync_device();
   }
 
@@ -377,6 +393,7 @@ struct Testbed {
         ElementCompute(options.beta)
       },
       tensor_N.device_ref(),
+      tensor_S.device_ref(),
       tensor_Softmax.device_ref()
     );
 
@@ -420,7 +437,7 @@ struct Testbed {
     for (int m = 0; m < options.problem_size.m(); ++m) {
       reference_N.at({m, 0}) = reference_D.at({m, 0});
       for (int n = 1; n < options.problem_size.n(); ++n) {
-        reference_N.at({m, 0}) = std::max(reference_N.at({m, 0}), ElementN(reference_D.at({m, n})));
+        reference_N.at({m, 0}) = std::max(reference_N.at({m, 0}), ElementNorm(reference_D.at({m, n})));
       }
     }
 
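The hunk above only touches the row-maximum pass of the host reference. For context, a complete host-side softmax reference typically follows a three-pass pattern: row maximum, exponentiate-and-sum, normalize. Below is a minimal standalone sketch of that pattern; it uses plain std::vector rather than the Testbed's HostTensor members and is not the example's exact code.

#include <algorithm>
#include <cmath>
#include <vector>

// Row-wise softmax over an M x N matrix D stored row-major, written into S.
void reference_softmax(std::vector<float> const &D, std::vector<float> &S, int M, int N) {
  S.resize(D.size());
  for (int m = 0; m < M; ++m) {
    // Pass 1: row maximum (the step shown in the hunk above).
    float row_max = D[m * N];
    for (int n = 1; n < N; ++n) {
      row_max = std::max(row_max, D[m * N + n]);
    }
    // Pass 2: exponentiate with the maximum subtracted for numerical stability, accumulating the sum.
    float row_sum = 0.f;
    for (int n = 0; n < N; ++n) {
      S[m * N + n] = std::exp(D[m * N + n] - row_max);
      row_sum += S[m * N + n];
    }
    // Pass 3: normalize.
    for (int n = 0; n < N; ++n) {
      S[m * N + n] /= row_sum;
    }
  }
}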
@@ -454,6 +471,20 @@ struct Testbed {
       std::cout << "Reference Softmax = \n" << reference_Softmax.host_view() << "\n\n";
     }
 
+  bool verify_tensor_N(cutlass::HostTensor<ElementNorm, LayoutC> tensor_N, \
+                       cutlass::HostTensor<ElementNorm, LayoutC> reference_N) {
+
+    for (int m = 0; m < options.problem_size.m(); ++m) {
+      float diff = (float)(tensor_N.at({0, m}) - reference_N.at({m, 0}));
+      if (fabs(diff) > options.tolerance) {
+        return false;
+      }
+
+    }
+
+    return true;
+  }
+
   /// Verifies the reference matches
   bool verify() {
 
@@ -489,22 +520,7 @@ struct Testbed {
     }
 
     if (!verified_N) {
-
-      double norm_diff = cutlass::reference::host::TensorNormDiff(
-        tensor_N.host_view(),
-        reference_N.host_view());
-
-      double norm_reference = cutlass::reference::host::TensorNorm(
-        reference_N.host_view());
-
-      double rel_error = norm_diff / norm_reference;
-
-      if (rel_error > kThreshold) {
-        std::cerr << "\n\nTensor N Relative error: " << rel_error << std::endl;
-      }
-      else {
-        verified_N = true;
-      }
+      verified_N = verify_tensor_N(tensor_N, reference_N);
     }
 
     if (!verified_Softmax) {
(File diff suppressed because it is too large.)
@@ -71,7 +71,7 @@ public:
   using ElementCompute = ElementCompute_;
 
   static int const kCount = Count;
-
+  static const ScaleType::Kind kScale = Scale;
   using FragmentOutput = Array<ElementOutput, kCount>;
   using FragmentAccumulator = Array<ElementAccumulator, kCount>;
   using ComputeFragment = Array<ElementCompute, kCount>;
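A small sketch of why exposing kScale is convenient: code that wraps an epilogue functor can inspect its scaling mode at compile time instead of re-deriving it from template arguments. This is an assumption about intended use, not something stated in the diff, and it assumes the functor is instantiated with the default ScaleType argument.

#include "cutlass/numeric_types.h"
#include "cutlass/epilogue/thread/linear_combination.h"

// Illustrative instantiation; element types are assumptions.
using Epilogue = cutlass::epilogue::thread::LinearCombination<
  cutlass::half_t,                                      // output element
  128 / cutlass::sizeof_bits<cutlass::half_t>::value,   // vector width per access
  float,                                                // accumulator element
  float                                                 // compute element
>;

// kScale mirrors the ScaleType template argument, so wrappers can branch on it
// at compile time.
static_assert(Epilogue::kScale == cutlass::epilogue::thread::ScaleType::Default,
              "LinearCombination defaults to ScaleType::Default");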
@@ -718,11 +718,11 @@ double fast_exp(double x) {
 }
 
 CUTLASS_HOST_DEVICE
-float fast_exp(half_t x) {
+half_t fast_exp(half_t x) {
   #if defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 10) && (__CUDA_ARCH__ >= 750)
-  return ::hexp(x.to_half());
+  return (half_t)(::hexp(x.to_half()));
   #else
-  return fast_exp(float(x));
+  return (half_t)(fast_exp(float(x)));
   #endif
 }
 
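The return-type change keeps the half-precision overload's result in half_t rather than widening it to float. A hypothetical caller (not from the diff) illustrating the value staying in half_t end to end:

#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/fast_math.h"

// Hypothetical helper: one softmax numerator term computed in half_t.
// With the old signature the result of fast_exp came back as float;
// after this change it stays in half_t.
CUTLASS_HOST_DEVICE
cutlass::half_t softmax_term(cutlass::half_t logit, cutlass::half_t row_max) {
  return cutlass::fast_exp(cutlass::half_t(logit - row_max));
}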
@@ -908,4 +908,3 @@ T absolute_value(T x) {
 } // namespace cutlass
 
 /////////////////////////////////////////////////////////////////////////////////////////////////
-