2019-11-20 08:55:34 +08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								/***************************************************************************************************
  
						 
					
						
							
								
									
										
										
										
											2021-02-26 22:58:26 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								 *  Copyright  ( c )  2017 - 2021 ,  NVIDIA  CORPORATION .   All  rights  reserved . 
							 
						 
					
						
							
								
									
										
										
										
											2019-11-20 08:55:34 +08:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								 * 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 *  Redistribution  and  use  in  source  and  binary  forms ,  with  or  without  modification ,  are  permitted 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 *  provided  that  the  following  conditions  are  met : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 *      *  Redistributions  of  source  code  must  retain  the  above  copyright  notice ,  this  list  of 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 *        conditions  and  the  following  disclaimer . 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 *      *  Redistributions  in  binary  form  must  reproduce  the  above  copyright  notice ,  this  list  of 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 *        conditions  and  the  following  disclaimer  in  the  documentation  and / or  other  materials 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 *        provided  with  the  distribution . 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 *      *  Neither  the  name  of  the  NVIDIA  CORPORATION  nor  the  names  of  its  contributors  may  be  used 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 *        to  endorse  or  promote  products  derived  from  this  software  without  specific  prior  written 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 *        permission . 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 * 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 *  THIS  SOFTWARE  IS  PROVIDED  BY  THE  COPYRIGHT  HOLDERS  AND  CONTRIBUTORS  " AS IS "  AND  ANY  EXPRESS  OR 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 *  IMPLIED  WARRANTIES ,  INCLUDING ,  BUT  NOT  LIMITED  TO ,  THE  IMPLIED  WARRANTIES  OF  MERCHANTABILITY  AND 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 *  FITNESS  FOR  A  PARTICULAR  PURPOSE  ARE  DISCLAIMED .  IN  NO  EVENT  SHALL  NVIDIA  CORPORATION  BE  LIABLE 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 *  FOR  ANY  DIRECT ,  INDIRECT ,  INCIDENTAL ,  SPECIAL ,  EXEMPLARY ,  OR  CONSEQUENTIAL  DAMAGES  ( INCLUDING , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 *  BUT  NOT  LIMITED  TO ,  PROCUREMENT  OF  SUBSTITUTE  GOODS  OR  SERVICES ;  LOSS  OF  USE ,  DATA ,  OR  PROFITS ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 *  OR  BUSINESS  INTERRUPTION )  HOWEVER  CAUSED  AND  ON  ANY  THEORY  OF  LIABILITY ,  WHETHER  IN  CONTRACT , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 *  STRICT  LIABILITY ,  OR  TOR  ( INCLUDING  NEGLIGENCE  OR  OTHERWISE )  ARISING  IN  ANY  WAY  OUT  OF  THE  USE 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 *  OF  THIS  SOFTWARE ,  EVEN  IF  ADVISED  OF  THE  POSSIBILITY  OF  SUCH  DAMAGE . 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 * 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								/*! \file
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    \ brief  Unit  tests  for  thread - level  GEMM 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								*/  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								# pragma once 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								# include  <iostream> 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								# include  "cutlass/gemm/thread/mma.h" 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								# include  "../kernel/thread/testbed_kernel.h" 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								# include  "cutlass/util/host_tensor.h" 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								# include  "cutlass/util/tensor_view_io.h" 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								# include  "cutlass/util/reference/host/tensor_copy.h" 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								# include  "cutlass/util/reference/host/tensor_fill.h" 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								# include  "cutlass/util/reference/host/tensor_compare.h" 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								# include  "cutlass/util/reference/host/gemm.h" 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								# include  <cuda.h> 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								# include  <nvrtc.h> 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								# include  "../cutlass/nvrtc/environment.h" 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								# include  <assert.h> 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								/////////////////////////////////////////////////////////////////////////////////////////////////
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								namespace  test  {  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								namespace  nvrtc  {  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								namespace  thread  {  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								/// Structure to compute the matrix product
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								template  <  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  /// Size of the Gemm problem - concept: gemm::GemmShape<>
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  typename  Shape , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  /// Data type of A elements
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  typename  ElementA , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  /// Layout of A matrix (concept: MatrixLayout)
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  typename  LayoutA , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  /// Data type of B elements
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  typename  ElementB , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  /// Layout of B matrix (concept: MatrixLayout)
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  typename  LayoutB , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  /// Element type of C matrix
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  typename  ElementC , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  /// Layout of C matrix (concept: MatrixLayout)
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  typename  LayoutC 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								>  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								struct  Testbed  {  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  /// Thread-level matrix multiply-accumulate operator
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  using  Mma  =  cutlass : : gemm : : thread : : Mma < 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    Shape , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    ElementA , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    LayoutA , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    ElementB , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    LayoutB , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    ElementC , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    LayoutC 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  > ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  //
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  // Data members
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  //
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  cutlass : : HostTensor < ElementA ,  LayoutA >  tensor_A ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  cutlass : : HostTensor < ElementB ,  LayoutB >  tensor_B ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  cutlass : : HostTensor < ElementC ,  LayoutC >  tensor_C ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  cutlass : : HostTensor < ElementC ,  LayoutC >  tensor_D_computed ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  cutlass : : HostTensor < ElementC ,  LayoutC >  tensor_D_reference ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  //
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  // Methods
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  //
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  /// Allocates workspace in device memory
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  Testbed ( )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    tensor_A . reset ( cutlass : : make_Coord ( Shape : : kM ,  Shape : : kK ) ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    tensor_B . reset ( cutlass : : make_Coord ( Shape : : kK ,  Shape : : kN ) ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    tensor_C . reset ( cutlass : : make_Coord ( Shape : : kM ,  Shape : : kN ) ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    tensor_D_computed . reset ( cutlass : : make_Coord ( Shape : : kM ,  Shape : : kN ) ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    tensor_D_reference . reset ( cutlass : : make_Coord ( Shape : : kM ,  Shape : : kN ) ,  false ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  static  inline  bool  check_nvrtc_error ( nvrtcResult  error )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    if  ( error  ! =  NVRTC_SUCCESS )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      std : : cerr  < <  " failed to compile  " ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      return  false ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    return  true ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  /// Runs the test
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  bool  run ( std : : string  const  & gemm_traits )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    //
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    // initialize device memory
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    //
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    cutlass : : reference : : host : : BlockFillSequential ( 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      tensor_A . host_data ( ) , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      tensor_A . capacity ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    cutlass : : reference : : host : : BlockFillSequential ( 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      tensor_B . host_data ( ) , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      tensor_B . capacity ( ) , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ElementB ( 1 ) , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ElementB ( 2 ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    cutlass : : reference : : host : : TensorFill ( 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      tensor_C . host_view ( ) , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ElementC ( 0 ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    cutlass : : reference : : host : : TensorFill ( 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      tensor_D_computed . host_view ( ) , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ElementC ( 0 ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    cutlass : : reference : : host : : TensorFill ( 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      tensor_D_reference . host_view ( ) , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ElementC ( 0 ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    tensor_A . sync_device ( ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    tensor_B . sync_device ( ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    tensor_C . sync_device ( ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    tensor_D_computed . sync_device ( ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								#if 0 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    // launch kernel
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    cutlass : : gemm : : kernel : : testbed_kernel < Mma > < < <  dim3 ( 1 ,  1 ) ,  dim3 ( 1 ,  1 ,  1 )  > > > ( 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      tensor_D_computed . device_data ( ) , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      tensor_A . device_data ( ) , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      tensor_B . device_data ( ) , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      tensor_C . device_data ( ) ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								# else 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    // Instantiate gemm_kernel
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    nvrtcResult  result_nvrtc ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    nvrtcProgram  program ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    static  char  const  * src  = 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        " #include  \" cutlass/gemm/thread/mma.h \" \n " 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        " #include  \" cutlass/gemm/gemm.h \" \n " 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        " #include  \" cutlass/layout/matrix.h \" \n " 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        " #include  \" unit/nvrtc/kernel/thread/testbed_kernel.h \" \n " 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    std : : string  type_name ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								#if 0 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    // TODO Ideally we'd use nvrtcGetTypeName to determine the type, but it cannot resolve enum symbol names
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    // As altername solution we might want to implement to_string<GemmTraits>() to get the traits string.
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    nvrtcGetTypeName < typename  GemmTraits_ > ( & type_name ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								# else 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    type_name  =  gemm_traits ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								# endif 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    result_nvrtc  =  nvrtcCreateProgram ( & program , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                                    src , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                                    NULL , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                                    ( int ) cutlass : : nvrtc : : kCutlassHeaderCount , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                                    cutlass : : nvrtc : : kCutlassHeaders , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                                    cutlass : : nvrtc : : kCutlassHeaderNames ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    check_nvrtc_error ( result_nvrtc ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    std : : string  gemm_kernel_instantiation  = 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      " test::nvrtc::kernel::thread::testbed_kernel<  "  +  type_name  +  "  > " ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    nvrtcAddNameExpression ( program ,  gemm_kernel_instantiation . c_str ( ) ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    const  char  * opts [ ]  =  { " --gpu-architecture=compute_75 " , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                          " --std=c++11 " , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                          " --include-path=/usr/local/cuda-10.1/include " } ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    result_nvrtc  =  nvrtcCompileProgram ( program ,  3 ,  opts ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    if  ( result_nvrtc  ! =  NVRTC_SUCCESS )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      size_t  logSize ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      nvrtcGetProgramLogSize ( program ,  & logSize ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      std : : vector < char >  log ( logSize ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      nvrtcGetProgramLog ( program ,  log . data ( ) ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      std : : cout  < <  " Compile log: "  < <  std : : endl  < <  log . data ( )  < <  std : : endl ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    if  ( ! check_nvrtc_error ( result_nvrtc ) )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      assert ( 0 ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    // The lowered name is the name of the template instantiation in the generated PTX code.
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    char  const  * gemm_kernel_lowered_name ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    nvrtcGetLoweredName ( program ,  gemm_kernel_instantiation . c_str ( ) ,  & gemm_kernel_lowered_name ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    if  ( ! check_nvrtc_error ( result_nvrtc ) )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      assert ( 0 ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    // Query the size of the genereated PTX so that we can allocate storage and retrieve it afterwards
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    size_t  ptx_size ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    result_nvrtc  =  nvrtcGetPTXSize ( program ,  & ptx_size ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    if  ( ! check_nvrtc_error ( result_nvrtc ) )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      assert ( 0 ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    std : : vector < char >  ptx ( ptx_size ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    result_nvrtc  =  nvrtcGetPTX ( program ,  ptx . data ( ) ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    if  ( ! check_nvrtc_error ( result_nvrtc ) )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      assert ( 0 ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    // we do not need the nvrtc program anymore
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    //nvrtcDestroyProgram(&program);
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    CUmodule  module ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    CUresult  result_cuda ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    result_cuda  =  cuModuleLoadDataEx ( & module ,  ptx . data ( ) ,  0 ,  0 ,  0 ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    if  ( result_cuda  ! =  CUDA_SUCCESS )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      assert ( 0 ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    CUfunction  kernel ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    result_cuda  =  cuModuleGetFunction ( & kernel ,  module ,  gemm_kernel_lowered_name ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    if  ( result_cuda  ! =  CUDA_SUCCESS )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      assert ( 0 ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    void *  d_a  =  ( void * ) tensor_A . device_data ( ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    void *  d_b  =  ( void * ) tensor_B . device_data ( ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    void *  d_c  =  ( void * ) tensor_C . device_data ( ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    void *  d_d  =  ( void * ) tensor_D_computed . device_data ( ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    void *  args [ ]  =  {  & d_d ,  & d_a ,  & d_b ,  & d_c  } ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    // CUfunction f, unsigned int  gridDimX, unsigned int  gridDimY, unsigned int  gridDimZ, unsigned int  blockDimX, unsigned int  blockDimY, unsigned int  blockDimZ, unsigned int  sharedMemBytes, CUstream hStream, void** kernelParams, void** extra
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    result_cuda  =  cuLaunchKernel ( kernel ,  1 ,  1 ,  1 ,  1 ,  1 ,  1 ,  0 ,  0  /*cudaStreamDefault*/ ,  args ,  0 ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    if  ( result_cuda  ! =  CUDA_SUCCESS )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      assert ( 0 ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }  else  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								}  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								# endif 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    // verify no errors
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    cudaError_t  result  =  cudaDeviceSynchronize ( ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    if  ( result  ! =  cudaSuccess )  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      std : : cout  < <  " CUDA ERROR:  "  < <  cudaGetErrorString ( result ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      return  false ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    tensor_D_computed . sync_host ( ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    //
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    // Reference implementation
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    //
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    //tensor_D_reference.fill(tensor_C.host_view());
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    cutlass : : reference : : host : : Gemm < ElementA ,  LayoutA ,  ElementB ,  LayoutB , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                                   ElementC ,  LayoutC ,  ElementC ,  ElementC >  reference_gemm ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    reference_gemm ( 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      { Shape : : kM ,  Shape : : kN ,  Shape : : kK } , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ElementC ( 1 ) , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      tensor_A . host_ref ( ) , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      tensor_B . host_ref ( ) , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ElementC ( 0 ) , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      tensor_D_reference . host_ref ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    //
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    // Verify equivalence
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    //
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    // compare
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    bool  passed  =  cutlass : : reference : : host : : TensorEquals ( 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      tensor_D_computed . host_view ( ) , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      tensor_D_reference . host_view ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    if ( ! passed )  std : : cout 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      < <  " A: \n "  < <  tensor_A . host_view ( )  < <  " \n \n " 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      < <  " B: \n "  < <  tensor_B . host_view ( )  < <  " \n \n " 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      < <  " C: \n "  < <  tensor_C . host_view ( )  < <  " \n \n " 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      < <  " Reference: \n "  < <  tensor_D_reference . host_view ( )  < <  " \n \n " 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      < <  " Computed: \n "  < <  tensor_D_computed . host_view ( )  < <  std : : endl ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    std : : cout  < <  " passed  "  < <  passed  < <  std : : endl ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    return  passed ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								} ;  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								/////////////////////////////////////////////////////////////////////////////////////////////////
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								}  // namespace thread
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								}  // namespace nvrtc
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								}  // namespace test