322 lines
		
	
	
		
			15 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
			
		
		
	
	
			322 lines
		
	
	
		
			15 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
| /***************************************************************************************************
 | |
|  * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 | |
|  * SPDX-License-Identifier: BSD-3-Clause
 | |
|  *
 | |
|  * Redistribution and use in source and binary forms, with or without
 | |
|  * modification, are permitted provided that the following conditions are met:
 | |
|  *
 | |
|  * 1. Redistributions of source code must retain the above copyright notice, this
 | |
|  * list of conditions and the following disclaimer.
 | |
|  *
 | |
|  * 2. Redistributions in binary form must reproduce the above copyright notice,
 | |
|  * this list of conditions and the following disclaimer in the documentation
 | |
|  * and/or other materials provided with the distribution.
 | |
|  *
 | |
|  * 3. Neither the name of the copyright holder nor the names of its
 | |
|  * contributors may be used to endorse or promote products derived from
 | |
|  * this software without specific prior written permission.
 | |
|  *
 | |
|  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 | |
|  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 | |
|  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 | |
|  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 | |
|  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 | |
|  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 | |
|  * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 | |
|  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 | |
|  * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 | |
|  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 | |
|  *
 | |
|  **************************************************************************************************/
 | |
/*! \file
    \brief Unit tests for thread-level epilogue activation functors (GELU_taylor)
*/
 | |
| 
 | |
| #include "../../common/cutlass_unit_test.h"
 | |
| 
 | |
| #include "cutlass/layout/layout.h"
 | |
| #include "cutlass/epilogue/thread/activation.h"
 | |
| 
 | |
| #include "cutlass/util/host_tensor.h"
 | |
| 
 | |
| /////////////////////////////////////////////////////////////////////////////////////////////////
 | |
| 
 | |
/// Applies an epilogue activation functor to a contiguous buffer of elements.
///
/// Both buffers are reinterpreted as arrays of cutlass::Array<T, N>. Each thread
/// loads one vector from `in`, applies a default-constructed `Func`, and stores
/// the result at the same vector position in `out`.
///
/// Launch contract:
///   - gridDim.x * blockDim.x must equal (number of elements / N). There is no
///     bounds check, so a larger launch reads and writes out of range.
///   - `out` and `in` must be suitably aligned for cutlass::Array<T, N>, since
///     they are accessed through reinterpret_cast'ed vector pointers.
template <typename T, int N, typename Func>
__global__ void test_Epilogue_thread_activation(T *out, T const *in) {

  using Vector = cutlass::Array<T, N>;

  Vector *vec_out = reinterpret_cast<Vector *>(out);
  Vector const *vec_in = reinterpret_cast<Vector const *>(in);

  // Global vector index. For single-block launches this equals threadIdx.x.
  int idx = blockIdx.x * blockDim.x + threadIdx.x;

  // Load into a local copy so the functor sees a mutable lvalue, exactly as it
  // did when the input pointer was non-const.
  Vector value = vec_in[idx];

  Func func;
  vec_out[idx] = func(value);
}
 | |
| 
 | |
| /////////////////////////////////////////////////////////////////////////////////////////////////
 | |
| 
 | |
//
// Reference data: GELU_golden_output[i] is the expected GELU of GELU_golden_input[i].
//
 | |
| 
 | |
| static double GELU_golden_input[] = {
 | |
|     1.587425827980,  1.157652974129,  0.750432848930, -0.965980410576,
 | |
|     -0.388184845448,  0.014422321692,  0.353164494038,  1.354383468628,
 | |
|      0.167588576674,  0.272798538208, -0.377032428980,  1.923444747925,
 | |
|      0.308164477348, -0.341318070889,  0.278338819742, -0.292668998241,
 | |
|     -1.051743745804, -0.814175724983,  0.112737402320,  1.262938618660,
 | |
|     -1.582363605499,  0.722016870975,  1.053453564644, -0.659764587879,
 | |
|      0.734917521477,  0.091274201870,  0.604461073875, -0.219043627381,
 | |
|     -0.136795744300,  0.960650205612, -1.805408835411,  0.091029644012,
 | |
|     -1.023343324661,  0.147713735700, -0.499895423651,  1.351878166199,
 | |
|     -1.631091356277, -0.336171895266, -1.612408638000,  0.090832948685,
 | |
|     -0.658132910728, -0.326727777719, -1.986387014389,  0.787685871124,
 | |
|     -1.015677452087, -0.225094825029,  0.876752018929,  0.744826257229,
 | |
|      0.870290279388, -0.757595360279,  1.510331749916,  0.750012576580,
 | |
|      0.906444966793, -0.915759027004,  1.260277032852, -0.158465340734,
 | |
|     -0.109191477299, -0.817102134228,  0.391305118799, -0.524910449982,
 | |
|      0.351349592209,  0.801979541779,  0.446691334248, -0.741077482700,
 | |
|      1.205966711044, -0.910210072994,  0.945986449718,  0.784096539021,
 | |
|      1.670521497726,  0.344931513071, -0.301411420107,  0.309870749712,
 | |
|     -0.879704594612, -1.951189517975, -0.805817663670, -0.661812782288,
 | |
|     -0.505914270878, -1.836273789406, -0.381845980883, -0.554707705975,
 | |
|     -0.375447630882, -0.516645610332,  0.509586095810,  1.087131023407,
 | |
|      2.664817094803, -1.558295488358, -0.076461032033, -0.504621028900,
 | |
|      1.327111959457, -1.819981694221,  1.350415468216, -2.074112653732,
 | |
|      1.501431345940, -1.339013576508,  0.162817999721, -1.473457217216,
 | |
|      0.357770472765,  0.188413277268,  1.601302266121, -0.653882205486,
 | |
|      0.856162548065,  0.763102591038, -0.526283502579,  0.581961452961,
 | |
|      0.089969776571,  1.968745589256,  0.545802056789, -1.168786048889,
 | |
|      1.206663012505, -0.109096683562, -1.223938226700,  0.744599223137,
 | |
|     -1.779406785965,  0.766436159611, -0.579044401646, -1.002057313919,
 | |
|     -0.715845823288, -0.562508940697,  0.886768460274,  2.327786445618,
 | |
|     -0.148763969541, -0.918884515762, -0.367678701878, -1.105021238327,
 | |
|     -0.461237311363,  0.158228352666, -0.254040330648,  1.427477598190,
 | |
|      0.277530491352,  0.046293262392, -0.535557329655, -1.486695051193,
 | |
|     -0.953706681728, -1.040495038033, -0.314667612314,  0.348172843456,
 | |
|      0.522773325443,  0.025960063562, -0.482472360134,  1.993084549904,
 | |
|     -0.253064930439, -0.012146313675, -2.166327714920,  0.398040622473,
 | |
|     -0.022238900885, -0.443580865860, -0.898376941681, -0.571689844131,
 | |
|      1.666979670525, -0.831176340580, -0.671057403088,  0.481970995665,
 | |
|     -1.096243023872, -1.493894338608,  0.596651911736, -0.229505166411,
 | |
|      1.165976166725,  0.905094027519,  0.049716457725, -1.362933635712,
 | |
|     -0.366948783398,  1.461613893509, -0.718411505222,  0.895385026932,
 | |
|     -0.763122260571,  1.329716682434,  1.366570711136, -0.086544901133,
 | |
|      0.059739742428,  0.940766513348, -0.272854357958, -1.738811373711,
 | |
|     -0.361239165068,  0.696977972984,  1.288442254066,  1.264815807343,
 | |
|     -0.573566436768, -1.141678214073,  0.081865988672, -0.886228799820,
 | |
|     -0.236933603883,  1.050115466118, -0.538952171803,  0.651773929596,
 | |
|     -0.220034509897, -1.198960781097,  1.247478365898, -0.053529661149,
 | |
|      0.639809548855,  1.672434806824,  0.511088073254, -1.179364681244,
 | |
|     -0.730427742004,  0.157630980015,  0.389369845390, -0.925578773022,
 | |
|     -0.093250080943, -0.391062080860,  0.852983593941,  1.868778109550,
 | |
|     -1.198786258698,  0.604997038841, -1.482687234879, -2.469333171844,
 | |
|      0.718807697296, -0.559609353542,  2.187228441238, -2.927527904510,
 | |
|      0.148535788059, -0.097280368209,  0.674131810665, -1.137645959854,
 | |
|      0.792729616165, -1.166317462921, -0.498791724443,  1.675866723061,
 | |
|     -0.137909621000, -0.653263568878, -2.281216144562,  0.296096831560,
 | |
|      2.002410173416,  1.083609819412,  0.933580815792, -1.504760265350,
 | |
|      2.185185909271,  0.286121010780, -1.035485863686, -0.216372340918,
 | |
|     -0.274334043264, -0.849510788918, -1.397169828415, -0.407644748688,
 | |
|      0.159476816654, -0.170650705695,  0.335193097591, -0.156852483749,
 | |
|      0.036168430001,  0.858105242252, -1.086121797562,  0.404813349247,
 | |
|     -0.481496721506, -0.389882832766,  0.020690204576, -0.772020936012,
 | |
|     -0.758921504021,  0.323482036591,  0.115715265274, -0.811228036880,
 | |
|     -0.882436633110,  0.176811277866,  1.678015947342,  0.379081040621,
 | |
|     -0.842976212502,  0.346952259541, -0.545828759670,  1.632800459862
 | |
| };
 | |
| 
 | |
| static double GELU_golden_output[] = {
 | |
|     1.498199582100,  1.014679551125,  0.580462038517, -0.161344811320,
 | |
|     -0.135453075171,  0.007294139825,  0.225325092673,  1.235459089279,
 | |
|      0.094946734607,  0.165724009275, -0.133120641112,  1.871103763580,
 | |
|      0.191376730800, -0.125069886446,  0.169681981206, -0.112644664943,
 | |
|     -0.154036879539, -0.169163048267,  0.061428427696,  1.132469892502,
 | |
|     -0.089851818979,  0.552240371704,  0.899579226971, -0.168043658137,
 | |
|      0.565008401871,  0.048956073821,  0.439583092928, -0.090532489121,
 | |
|     -0.060955654830,  0.798911273479, -0.064101703465,  0.048816055059,
 | |
|     -0.156645998359,  0.082529976964, -0.154254898429,  1.232632875443,
 | |
|     -0.083896033466, -0.123835846782, -0.086161509156,  0.048703473061,
 | |
|     -0.167972877622, -0.121522113681, -0.046670529991,  0.617986679077,
 | |
|     -0.157319813967, -0.092503339052,  0.709896743298,  0.574865520000,
 | |
|      0.703132867813, -0.169963955879,  1.411436080933,  0.580042064190,
 | |
|      0.741154611111, -0.164741978049,  1.129479527473, -0.069256491959,
 | |
|     -0.049848672003, -0.169087052345,  0.255214750767, -0.157380074263,
 | |
|      0.223928079009,  0.632535398006,  0.300378054380, -0.169946283102,
 | |
|      1.068588852882, -0.165071934462,  0.783203184605,  0.614346146584,
 | |
|      1.591325283051,  0.219006344676, -0.115003645420,  0.192637458444,
 | |
|     -0.166712537408, -0.049788996577, -0.169361919165, -0.168130636215,
 | |
|     -0.155041679740, -0.060888241976, -0.134137839079, -0.160614117980,
 | |
|     -0.132782235742, -0.156389534473,  0.354075312614,  0.936574816704,
 | |
|      2.654553413391, -0.092845752835, -0.035900454968, -0.154874503613,
 | |
|      1.204704761505, -0.062572605908,  1.230982899666, -0.039479542524,
 | |
|      1.401402950287, -0.120890334249,  0.091938301921, -0.103604510427,
 | |
|      0.228880971670,  0.108285568655,  1.513783097267, -0.167782157660,
 | |
|      0.688394129276,  0.593158841133, -0.157540664077,  0.418839782476,
 | |
|      0.048209801316,  1.920528769493,  0.386099845171, -0.141709372401,
 | |
|      1.069367766380, -0.049809500575, -0.135230198503,  0.574639260769,
 | |
|     -0.066881760955,  0.596510827541, -0.162873372436, -0.158483341336,
 | |
|     -0.169686436653, -0.161375194788,  0.720409095287,  2.304597616196,
 | |
|     -0.065585561097, -0.164551988244, -0.131098195910, -0.148708447814,
 | |
|     -0.148663327098,  0.089060656726, -0.101548098028,  1.317959904671,
 | |
|      0.169103100896,  0.024001283571, -0.158595800400, -0.101909510791,
 | |
|     -0.162240833044, -0.155090972781, -0.118474565446,  0.221488356590,
 | |
|      0.365645468235,  0.013248858973, -0.151851043105,  1.946992278099,
 | |
|     -0.101253561676, -0.006014300976, -0.032804865390,  0.260597169399,
 | |
|     -0.010922161862, -0.145792976022, -0.165743649006, -0.162226170301,
 | |
|      1.587365984917, -0.168676435947, -0.168497130275,  0.330191940069,
 | |
|     -0.149622067809, -0.100989677012,  0.432351946831, -0.093922272325,
 | |
|      1.023946166039,  0.739726305008,  0.025843897834, -0.117827951908,
 | |
|     -0.130937814713,  1.356489539146, -0.169726014137,  0.729478538036,
 | |
|     -0.169943705201,  1.207641005516,  1.249209761620, -0.040288090706,
 | |
|      0.031292784959,  0.777626037598, -0.107090584934, -0.071350336075,
 | |
|     -0.129670530558,  0.527676224709,  1.161149263382,  1.134579420090,
 | |
|     -0.162394225597, -0.144757837057,  0.043603736907, -0.166386902332,
 | |
|     -0.096278958023,  0.895924389362, -0.158969298005,  0.484089732170,
 | |
|     -0.090857118368, -0.138206124306,  1.115107178688, -0.025622237474,
 | |
|      0.472724437714,  1.593463659286,  0.355387806892, -0.140493586659,
 | |
|     -0.169871479273,  0.088687323034,  0.253673940897, -0.164135158062,
 | |
|     -0.043161027133, -0.136040985584,  0.685087263584,  1.811169505119,
 | |
|     -0.138226687908,  0.440080583096, -0.102422207594, -0.016713079065,
 | |
|      0.549075841904, -0.161096408963,  2.155813455582, -0.005001218989,
 | |
|      0.083037458360, -0.044870752841,  0.505522191525, -0.145202502608,
 | |
|      0.623111069202, -0.141991063952, -0.154108211398,  1.597298502922,
 | |
|     -0.061391282827, -0.167753636837, -0.025704355910,  0.182520583272,
 | |
|      1.957115054131,  0.932696640491,  0.769961357117, -0.099604383111,
 | |
|      2.153636932373,  0.175279796124, -0.155551761389, -0.089653611183,
 | |
|     -0.107515335083, -0.168032020330, -0.113423995674, -0.139319628477,
 | |
|      0.089841812849, -0.073763631284,  0.211594089866, -0.068651281297,
 | |
|      0.018605981022,  0.690416753292, -0.150658726692,  0.266040354967,
 | |
|     -0.151710823178, -0.135800719261,  0.010515870526, -0.169883996248,
 | |
|     -0.169960290194,  0.202769815922,  0.063187584281, -0.169236257672,
 | |
|     -0.166577890515,  0.100812792778,  1.599699616432,  0.245525524020,
 | |
|     -0.168275654316,  0.220552831888, -0.159705042839,  1.549110531807
 | |
| };
 | |
| 
 | |
| /////////////////////////////////////////////////////////////////////////////////////////////////
 | |
| 
 | |
namespace {

/// Per-element tolerance override: element `index` is checked against
/// `tolerance` instead of the test-wide default.
struct ToleranceOverride {
    int index;
    double tolerance;
};

/// Runs GELU_taylor over GELU_golden_input on the device and checks every output
/// against GELU_golden_output.
///
/// Element   - element type stored in the device buffers (e.g. float, cutlass::half_t)
/// kV        - vector width; each thread processes one cutlass::Array<Element, kV>
/// tolerance - default bound on |got - expected| / |expected|
/// overrides - per-index bounds for elements where the device result is known
///             to deviate more from the reference
///
/// Uses fatal assertions, so on a launch or execution error it returns early
/// instead of comparing stale data.
template <typename Element, int kV, int kOverrides>
void run_gelu_taylor_device_test(
    double tolerance,
    ToleranceOverride const (&overrides)[kOverrides]) {

    int const kN = int(sizeof(GELU_golden_input) / sizeof(GELU_golden_input[0]));

    static_assert(sizeof(GELU_golden_input) == sizeof(GELU_golden_output),
                  "golden input and output tables must have the same length");
    static_assert(kN % kV == 0, "element count must be a multiple of the vector width");
    static_assert(kN / kV <= 1024, "the single-block launch must fit in one thread block");

    using Func = cutlass::epilogue::thread::GELU_taylor<cutlass::Array<Element, kV>>;

    //
    // Construct workspace
    //
    cutlass::HostTensor<Element, cutlass::layout::RowMajor> tensor_Destination({1, kN});
    cutlass::HostTensor<Element, cutlass::layout::RowMajor> tensor_Source({1, kN});

    for (int i = 0; i < kN; ++i) {
        tensor_Source.host_data(i) = Element(GELU_golden_input[i]);
    }

    tensor_Destination.sync_device();
    tensor_Source.sync_device();

    //
    // Launch the kernel: one block, one thread per kV-wide vector.
    //
    dim3 grid(1, 1, 1);
    dim3 block(kN / kV, 1, 1);

    test_Epilogue_thread_activation<Element, kV, Func><<< grid, block >>>(
        tensor_Destination.device_data(),
        tensor_Source.device_data());

    // Launch-configuration errors are only reported via cudaGetLastError();
    // faults during execution surface at the next synchronizing call.
    cudaError_t launch_result = cudaGetLastError();
    ASSERT_EQ(launch_result, cudaSuccess)
        << "Kernel launch failed: " << cudaGetErrorString(launch_result);

    cudaError_t sync_result = cudaDeviceSynchronize();
    ASSERT_EQ(sync_result, cudaSuccess)
        << "Kernel execution failed: " << cudaGetErrorString(sync_result);

    tensor_Destination.sync_host();

    //
    // Verify
    //
    for (int i = 0; i < kN; ++i) {
        Element input = Element(GELU_golden_input[i]);
        Element got = tensor_Destination.host_data(i);
        Element expected = Element(GELU_golden_output[i]);

        // Relative error; the golden outputs contain no exact zeros, so the
        // division is well defined for this table.
        double rel_error = (double(got) - double(expected)) / double(expected);

        double element_tolerance = tolerance;
        for (ToleranceOverride const &entry : overrides) {
            if (entry.index == i) {
                element_tolerance = entry.tolerance;
                break;
            }
        }

        EXPECT_LT(std::abs(rel_error), element_tolerance) 
            << "Input[" << i << "]: " << input << ", Got: " << got << ", expected: " << expected;
    }
}

} // namespace

TEST(Epilogue_thread_gelu_taylor, device_f32) {

    // Every overridden index has a large-magnitude negative input (about -2.2 to
    // -2.9) whose expected output is tiny, so a small absolute deviation becomes
    // a large relative error.
    ToleranceOverride const overrides[] = {
        {142, 0.008},
        {203, 0.03},
        {207, 0.09},
        {218, 0.013},
    };

    run_gelu_taylor_device_test<float, 4>(0.005, overrides);
}

/////////////////////////////////////////////////////////////////////////////////////////////////

TEST(Epilogue_thread_gelu_taylor, device_f16) {

    // All overridden indices have negative inputs with small expected outputs.
    // In half precision, more of them exceed the default relative bound than in
    // the f32 test.
    ToleranceOverride const overrides[] = {
        {36, 0.006},
        {77, 0.009},
        {95, 0.008},
        {112, 0.007},
        {171, 0.006},
        {203, 0.03},
        {207, 0.15},
    };

    run_gelu_taylor_device_test<cutlass::half_t, 8>(0.005, overrides);
}
 | |
| 
 | |
| /////////////////////////////////////////////////////////////////////////////////////////////////
 | 
