Basics of PyCUTLASS
PyCUTLASS handles the following steps when launching CUTLASS kernels:
- Memory management
- Operation Description
- Code emission and compilation
- Argument preprocessing
- Kernel launching
- Result Synchronization
Memory management
PyCUTLASS uses RMM (the RAPIDS Memory Manager) to manage device memory. At the beginning of the program, initialize the memory pool with
```python
pycutlass.get_memory_pool({init_pool_size_in_bytes}, {max_pool_size_in_bytes})
```
We also provide a function to query the allocated size:
```python
allocated_bytes = get_allocated_size()
```
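Both pool sizes are plain byte counts. As a rough guide, the sketch below estimates the bytes needed to hold the A, B, C, and D operands of a single GEMM and rounds that up to pick an initial pool size. The helper names, the data-type-to-byte mapping, the 256 MiB rounding granularity, and the 4x headroom for the maximum size are illustrative assumptions, not PyCUTLASS APIs; the actual get_memory_pool call is left commented out.

```python
# Hypothetical sizing helper (not part of PyCUTLASS): estimate the device memory
# needed by one GEMM so the RMM pool can be sized up front.

# Bytes per element for a few common data types (illustrative mapping).
BYTES_PER_ELEMENT = {
    "float64": 8, "float32": 4, "float16": 2,
    "bfloat16": 2, "int8": 1, "int32": 4,
}

MiB = 2 ** 20


def gemm_operand_bytes(m: int, n: int, k: int, dtype_ab: str, dtype_cd: str) -> int:
    """Bytes for A (M x K), B (K x N), and C and D (M x N each)."""
    ab = BYTES_PER_ELEMENT[dtype_ab]
    cd = BYTES_PER_ELEMENT[dtype_cd]
    return (m * k + k * n) * ab + 2 * m * n * cd


def round_up(value: int, multiple: int) -> int:
    """Round value up to the next multiple of `multiple`."""
    return (value + multiple - 1) // multiple * multiple


# Example: a 4096 x 4096 x 4096 GEMM with float16 inputs and float32 outputs.
init_pool_size = round_up(gemm_operand_bytes(4096, 4096, 4096, "float16", "float32"), 256 * MiB)
max_pool_size = 4 * init_pool_size  # assumed headroom for other allocations

# pycutlass.get_memory_pool(init_pool_size, max_pool_size)
```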
Operation Description
PyCUTLASS provides operation descriptions for GEMM, GEMM Grouped, and Conv2d operations. These operation descriptions are assembled from four fundamental concepts:
- Math Instruction: math instruction executed in GPU cores
- Tile Description: tiling sizes and pipeline stages
- Operand Description: data type, layout, memory alignment
- Epilogue Functor: epilogue function
Math Instruction
The math instruction is defined as follows:
```python
math_inst = MathInstruction(
    {instruction_shape}, {element_a}, {element_b},
    {element_acc}, {opclass}, {math_operation}
)
```
{instruction_shape} and {opclass} define the instruction size and type; the table below lists the valid combinations. {element_a} and {element_b} define the source operand data types of the instruction, and {element_acc} defines the accumulator type. {math_operation} defines the math operation applied.
| Opclass | element_a/element_b | element_acc | instruction_shape | math_operation |
|---|---|---|---|---|
| cutlass.OpClass.TensorOp | cutlass.float64 | cutlass.float64 | [8, 8, 4] | MathOperation.multiply_add |
| cutlass.OpClass.TensorOp | cutlass.float32, cutlass.tfloat32, cutlass.float16, cutlass.bfloat16 | cutlass.float32 | [16, 8, 8] | MathOperation.multiply_add, MathOperation.multiply_add_fast_f32, MathOperation.multiply_add_fast_f16, MathOperation.multiply_add_fast_bf16 |
| cutlass.OpClass.TensorOp | cutlass.float16 | cutlass.float16, cutlass.float32 | [16, 8, 16] | MathOperation.multiply_add |
| cutlass.OpClass.TensorOp | cutlass.bfloat16 | cutlass.float32 | [16, 8, 16] | MathOperation.multiply_add |
| cutlass.OpClass.TensorOp | cutlass.int8 | cutlass.int32 | [16, 8, 32] | MathOperation.multiply_add_saturate |
| cutlass.OpClass.Simt | cutlass.float64 | cutlass.float64 | [1, 1, 1] | MathOperation.multiply_add |
| cutlass.OpClass.Simt | cutlass.float32 | cutlass.float32 | [1, 1, 1] | MathOperation.multiply_add |
cutlass.OpClass.TensorOp indicates that Tensor Cores are used, while cutlass.OpClass.Simt uses the SIMT cores.
MathOperation.multiply_add_fast_f32 emulates a fast, accurate SGEMM kernel that is accelerated using Ampere Tensor Cores. More details can be found in examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm.
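The idea behind the 3xTF32 technique used by that example can be checked on the CPU. The NumPy sketch below splits each FP32 operand into a "big" TF32 part and a "small" TF32 remainder and sums the three significant cross products, dropping the small times small term. It is only a numerical model under stated assumptions: TF32 conversion is emulated by truncating the low 13 mantissa bits (hardware conversion rounds), and the function names are hypothetical. It is not how CUTLASS implements the kernel.

```python
import numpy as np


def to_tf32(x):
    """Emulate TF32 by zeroing the low 13 of the 23 FP32 mantissa bits.

    Truncation is an assumption that keeps the sketch simple; hardware rounds."""
    bits = np.ascontiguousarray(x, dtype=np.float32).view(np.uint32)
    return (bits & np.uint32(0xFFFFE000)).view(np.float32)


def gemm_1xtf32(a, b):
    """Plain TF32 GEMM: one product of truncated operands, FP32 accumulation."""
    return to_tf32(a) @ to_tf32(b)


def gemm_3xtf32(a, b):
    """3xTF32: big/small split of each operand, three TF32 products."""
    a_big = to_tf32(a)
    a_small = to_tf32(a.astype(np.float32) - a_big)
    b_big = to_tf32(b)
    b_small = to_tf32(b.astype(np.float32) - b_big)
    # Add the small correction terms first, then the dominant big x big term.
    return (a_big @ b_small + a_small @ b_big) + a_big @ b_big


rng = np.random.default_rng(0)
a = rng.standard_normal((64, 64)).astype(np.float32)
b = rng.standard_normal((64, 64)).astype(np.float32)
reference = a.astype(np.float64) @ b.astype(np.float64)

err_1x = np.abs(gemm_1xtf32(a, b) - reference).max()
err_3x = np.abs(gemm_3xtf32(a, b) - reference).max()
print(f"1xTF32 max error: {err_1x:.2e}, 3xTF32 max error: {err_3x:.2e}")
```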
Tile Description
The tile description describes the threadblock and warp tiling sizes, as well as the pipeline stages.
```python
tile_description = TileDescription(
    {threadblock_shape}, {stages}, {warp_count},
    math_inst
)
```
{threadblock_shape} is a list of 3 integers [Tile_M, Tile_N, Tile_K] that defines the threadblock tiling size. {stages} defines the number of software pipeline stages. {warp_count} defines the number of warps along the M, N, and K dimensions. That is, with {threadblock_shape}=[Tile_M, Tile_N, Tile_K] and {warp_count}=[W_M, W_N, W_K], the warp tile size is [Tile_M / W_M, Tile_N / W_N, Tile_K / W_K].
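The warp tile arithmetic above can be written out directly. The sketch below is plain Python with hypothetical helper names; the divisibility check and the 32-threads-per-warp block size follow the usual CUDA conventions, while the requirement that each warp tile dimension be a multiple of the instruction shape is stated here as an assumption rather than taken from the PyCUTLASS source.

```python
def warp_tile_shape(threadblock_shape, warp_count):
    """[Tile_M / W_M, Tile_N / W_N, Tile_K / W_K], checking exact division."""
    shape = []
    for tile, warps in zip(threadblock_shape, warp_count):
        if tile % warps:
            raise ValueError(f"tile size {tile} is not divisible by warp count {warps}")
        shape.append(tile // warps)
    return shape


def threads_per_block(warp_count, warp_size=32):
    """Each warp has 32 threads, so the block size is W_M * W_N * W_K * 32."""
    w_m, w_n, w_k = warp_count
    return w_m * w_n * w_k * warp_size


def fits_instruction(warp_shape, instruction_shape):
    """Assumed constraint: each warp tile dimension is a multiple of the MMA shape."""
    return all(w % i == 0 for w, i in zip(warp_shape, instruction_shape))


# Example: a 128x128x32 threadblock tile split across a 2x2x1 arrangement of warps.
warp_shape = warp_tile_shape([128, 128, 32], [2, 2, 1])
print(warp_shape, threads_per_block([2, 2, 1]), fits_instruction(warp_shape, [16, 8, 16]))
```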
Operand Description
The Operand Description defines the data type, layout, and memory alignment of the input tensors A, B, and C. The output D shares the same attributes as C. The description is as follows:
```python
A = TensorDescription(
    {element_a}, {layout_a}, {alignment_a}
)
B = TensorDescription(
    {element_b}, {layout_b}, {alignment_b}
)
C = TensorDescription(
    {element_c}, {layout_c}, {alignment_c}
)
```
The table below lists the supported layouts and data types for each operation:
| Operation | data type | layout |
|---|---|---|
| GEMM, GEMM Grouped | cutlass.float64, cutlass.float32, cutlass.float16, cutlass.bfloat16 | cutlass.RowMajor, cutlass.ColumnMajor |
| GEMM, GEMM Grouped | cutlass.int8 | cutlass.RowMajor, cutlass.ColumnMajor, cutlass.RowMajorInterleaved32, cutlass.ColumnMajorInterleaved32 |
| Conv2d Fprop, Dgrad, Wgrad | cutlass.float64, cutlass.float32, cutlass.float16, cutlass.bfloat16 | cutlass.TensorNHWC |
| Conv2d Fprop | cutlass.int8 | cutlass.TensorNHWC, cutlass.TensorNC32HW32, cutlass.TensorC32RSK32 |
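Alignment is expressed in elements. A common rule of thumb, assumed here rather than taken from PyCUTLASS, is that one vectorized access moves at most 128 bits, so the largest useful alignment is 128 divided by the bits per element, reduced until it divides the tensor's contiguous dimension (K for a row-major A, M for a column-major A). The helper below is a hypothetical sketch of that rule.

```python
# Bits per element for the data types listed in the table above.
ELEMENT_BITS = {"float64": 64, "float32": 32, "float16": 16, "bfloat16": 16, "int8": 8}


def max_alignment(dtype: str, contiguous_extent: int, access_bits: int = 128) -> int:
    """Largest power-of-two alignment (in elements) that fits in one access
    and evenly divides the contiguous dimension of the tensor."""
    alignment = access_bits // ELEMENT_BITS[dtype]
    while alignment > 1 and contiguous_extent % alignment != 0:
        alignment //= 2
    return alignment


# Row-major float16 A with K = 4096: 8 elements (128 bits) per access.
print(max_alignment("float16", 4096))
# K = 4100 is divisible by 4 but not 8, so the alignment drops to 4.
print(max_alignment("float16", 4100))
```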
Epilogue Functor
The epilogue functor defines the epilogue executed after the mainloop. We expose the following epilogue functors:
| Epilogue Functor | Remark |
|---|---|
| LinearCombination | $D = \alpha \times Accm + \beta \times C$ |
| LinearCombinationClamp | $D = \alpha \times Accm + \beta \times C$; the output is clamped to the range of the output data type |
| FastLinearCombinationClamp | $D = \alpha \times Accm + \beta \times C$; only used for problem size $K \le 256$ for cutlass.int8, with accumulator data type cutlass.int32 and epilogue compute data type cutlass.float32 |
| LinearCombinationGeneric | $D = activation(\alpha \times Accm + \beta \times C)$; available activations include relu, leaky_relu, tanh, sigmoid, silu, hardswish, and gelu |
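The formulas in this table can be mirrored on the host to check kernel outputs. The NumPy sketch below is a reference model, not the PyCUTLASS implementation: the function names are hypothetical, Accm is the accumulator passed in as acc, the clamp variant assumes round-to-nearest conversion to an integer output type, and leaky_relu takes its negative slope as an extra activation argument.

```python
import math

import numpy as np

_erf = np.vectorize(math.erf)

# Host-side versions of the activations listed for LinearCombinationGeneric.
ACTIVATIONS = {
    "relu": lambda x: np.maximum(x, 0),
    "leaky_relu": lambda x, slope: np.where(x > 0, x, slope * x),
    "tanh": np.tanh,
    "sigmoid": lambda x: 1.0 / (1.0 + np.exp(-x)),
    "silu": lambda x: x / (1.0 + np.exp(-x)),
    "hardswish": lambda x: x * np.clip(x + 3.0, 0.0, 6.0) / 6.0,
    "gelu": lambda x: 0.5 * x * (1.0 + _erf(x / math.sqrt(2.0))),
}


def linear_combination(acc, c, alpha, beta):
    """D = alpha * Accm + beta * C."""
    return alpha * acc + beta * c


def linear_combination_clamp(acc, c, alpha, beta, out_dtype=np.int8):
    """Linear combination, then saturate to the range of the integer output type."""
    info = np.iinfo(out_dtype)
    d = np.rint(linear_combination(acc, c, alpha, beta))
    return np.clip(d, info.min, info.max).astype(out_dtype)


def linear_combination_generic(acc, c, alpha, beta, activation, *activation_args):
    """D = activation(alpha * Accm + beta * C)."""
    return ACTIVATIONS[activation](linear_combination(acc, c, alpha, beta), *activation_args)


acc = np.array([[-2.0, 1.0], [3.0, -4.0]])
c = np.ones_like(acc)
print(linear_combination_generic(acc, c, 1.0, 0.5, "relu"))
```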
The epilogue functors can be created as follows:
```python
# LinearCombination
epilogue_functor = LinearCombination(
    element_c, alignment_c, element_acc, element_epilogue_compute
)
# LinearCombinationClamp
epilogue_functor = LinearCombinationClamp(
    element_c, alignment_c, element_acc, element_epilogue_compute
)
# FastLinearCombinationClamp
epilogue_functor = FastLinearCombinationClamp(
    element_c, alignment_c
)
# LinearCombinationGeneric
epilogue_functor = LinearCombinationGeneric(
    relu(element_epilogue_compute), element_c, alignment_c,
    element_acc, element_epilogue_compute
)
```
We also provide an experimental feature, "Epilogue Visitor Tree", for GEMM operations. The details can be found in EpilogueVisitorTree.
GEMM Operation
The GEMM Operation description can be created with
```python
operation = GemmOperationUniversal(
    {compute_capability}, tile_description,
    A, B, C, epilogue_functor,
    {swizzling_functor}, {visitor}
)
```
- {compute_capability} is an integer that indicates the compute capability of the GPU. For A100, it is 80.
- {swizzling_functor} describes how threadblocks are scheduled on the GPU. It is used to improve L2 locality. Currently we support cutlass.{IdentitySwizzle1|IdentitySwizzle2|IdentitySwizzle4|IdentitySwizzle8|BatchedIdentitySwizzle}. The last one is used for batched or array GEMM.
- {visitor}: a bool that indicates whether the epilogue visitor tree is used.
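To illustrate what an identity swizzle does, the sketch below models the threadblock-to-tile mapping of CUTLASS's GemmIdentityThreadblockSwizzle<N> as we understand it: the grid is widened by a factor N along x and narrowed along y, and each block index is remapped so that consecutively launched threadblocks stay within an N-tile-wide band of output columns, which tends to improve L2 reuse. It is a simplified Python model with hypothetical names; it assumes x-fastest block dispatch and ignores details such as CUTLASS picking a smaller effective swizzle for small problems. It is not the device code.

```python
def swizzle_grid(tiles_m: int, tiles_n: int, swizzle: int):
    """Grid (x, y) launched for an IdentitySwizzle<swizzle> (swizzle is 1, 2, 4, or 8)."""
    return tiles_m * swizzle, (tiles_n + swizzle - 1) // swizzle


def swizzle_tile(block_x: int, block_y: int, swizzle: int):
    """Map (blockIdx.x, blockIdx.y) to the (tile_m, tile_n) output tile it computes."""
    log_tile = swizzle.bit_length() - 1  # log2 of a power-of-two swizzle
    tile_m = block_x >> log_tile
    tile_n = (block_y << log_tile) + (block_x & ((1 << log_tile) - 1))
    return tile_m, tile_n


def launch_order(tiles_m: int, tiles_n: int, swizzle: int):
    """Tiles in assumed launch order (x fastest), skipping out-of-range tiles."""
    grid_x, grid_y = swizzle_grid(tiles_m, tiles_n, swizzle)
    order = []
    for block_y in range(grid_y):
        for block_x in range(grid_x):
            tile_m, tile_n = swizzle_tile(block_x, block_y, swizzle)
            if tile_n < tiles_n:  # the kernel exits early for these blocks
                order.append((tile_m, tile_n))
    return order


# With swizzle=2, the first blocks sweep down M while staying in tile columns 0 and 1.
print(launch_order(3, 4, 2)[:6])
```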
GEMM Grouped Operation
The GEMM Grouped Operation description can be created with
```python
operation = GemmOperationGrouped(
    compute_capability, tile_description,
    A, B, C, epilogue_functor,
    swizzling_functor, {precompute_mode}
)
```
- {precompute_mode}: either SchedulerMode.Host or SchedulerMode.Device. See examples/24_gemm_grouped for more details.
Conv2d Operation
The Conv2d Operation description can be created with
```python
operation = Conv2dOperation(
    {conv_kind}, {iterator_algorithm},
    compute_capability, tile_description,
    A, B, C, {stride_support},
    epilogue_functor, swizzling_functor
)
```
- {conv_kind} defines which convolution is executed. Available options are fprop, dgrad, and wgrad.
- {iterator_algorithm} specifies the iterator algorithm used by the implicit GEMM in the convolution. The options are:
  - analytic: functionally correct in all cases, but lower performance
  - optimized: optimized for R <= 32, S <= 32, and unity-stride dgrad
  - fixed_channels: analytic algorithm optimized for a fixed channel count (C == AccessSize)
  - few_channels: analytic algorithm optimized for few channels (C divisible by AccessSize)
- {stride_support} distinguishes among partial specializations that accelerate certain problems where the convolution stride is unity:
  - strided: arbitrary convolution stride
  - unity: unit convolution stride
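Each conv_kind is executed as an implicit GEMM. The sketch below, plain Python with hypothetical helper names, shows the GEMM extents that correspond to fprop, dgrad, and wgrad for an NHWC activation of size N x H x W x C, a K x R x S x C filter, and an N x P x Q x K output. It assumes the standard convolution output-size formula with symmetric padding.

```python
def conv2d_output_size(h, w, r, s, pad=(0, 0), stride=(1, 1), dilation=(1, 1)):
    """Standard output size (P, Q), assuming symmetric padding (pad_h, pad_w)."""
    p = (h + 2 * pad[0] - dilation[0] * (r - 1) - 1) // stride[0] + 1
    q = (w + 2 * pad[1] - dilation[1] * (s - 1) - 1) // stride[1] + 1
    return p, q


def implicit_gemm_shape(conv_kind, n, h, w, c, k, r, s, p, q):
    """(GEMM_M, GEMM_N, GEMM_K) of the implicit GEMM for each convolution kind."""
    if conv_kind == "fprop":  # output activations from activations and filters
        return n * p * q, k, r * s * c
    if conv_kind == "dgrad":  # input gradient from output gradient and filters
        return n * h * w, c, r * s * k
    if conv_kind == "wgrad":  # filter gradient from output gradient and activations
        return k, r * s * c, n * p * q
    raise ValueError(f"unknown conv_kind: {conv_kind}")


# A 3x3, stride-1, pad-1 convolution on a 1x56x56x64 NHWC tensor with 64 filters.
p, q = conv2d_output_size(56, 56, 3, 3, pad=(1, 1))
for kind in ("fprop", "dgrad", "wgrad"):
    print(kind, implicit_gemm_shape(kind, 1, 56, 56, 64, 64, 3, 3, p, q))
```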
 
Code Emission and Compilation
After creating the operation description, the related host and device code can be compiled with
```python
import pycutlass

pycutlass.compiler.add_module([operation,])
```
Several operations can be compiled together by passing them in the same list. By default, the nvcc at $CUDA_INSTALL_PATH/bin is used as the compiler backend, but you can also switch to CUDA Python's NVRTC with
```python
pycutlass.compiler.nvrtc()
```
We also have an internal compiled-artifact manager that caches compiled kernels both in memory and on disk. The compiled_cache.db file in your workspace is the database that stores the binaries; delete it if you want the kernels to be recompiled.
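Clearing the on-disk cache can be scripted. The helper below is a hypothetical convenience function, not part of PyCUTLASS; it assumes compiled_cache.db sits in the directory the program runs from (the workspace) and simply deletes it. Since the cache also lives in memory, deleting the file presumably only takes effect for a fresh process.

```python
from pathlib import Path


def clear_compiled_cache(workspace: str = ".") -> bool:
    """Delete <workspace>/compiled_cache.db; return True if a file was removed."""
    cache_file = Path(workspace) / "compiled_cache.db"
    if cache_file.is_file():
        cache_file.unlink()
        return True
    return False
```

Calling clear_compiled_cache() at the start of a script, before any add_module call, forces the kernels to be rebuilt on that run.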
Argument Processing
We provide argument wrappers that convert Python tensors into kernel parameters. Currently they support torch.Tensor, numpy.ndarray, and cupy.ndarray.
GEMM Arguments
The GEMM arguments can be created with
```python
arguments = GemmArguments(
    operation=operation, problem_size={problem_size},
    A={tensor_A}, B={tensor_B}, C={tensor_C}, D={tensor_D},
    output_op={output_op},
    gemm_mode={gemm_mode},
    split_k_slices={split_k_slices}, batch={batch}
)
```
- problem_size is a cutlass.gemm.GemmCoord(M, N, K) object that defines an $M \times N \times K$ matrix multiplication.
- tensor_X: user-provided tensors.
- output_op: the parameters of the epilogue functor.
- gemm_mode, split_k_slices, and batch: set according to the table below.
| gemm_mode | split_k_slices | batch | remark | 
|---|---|---|---|
| cutlass.gemm.Mode.Gemm | number of split-K slices | - | the ordinary GEMM or GEMM with serial split-K | 
| cutlass.gemm.Mode.GemmSplitKParallel | number of split-K slices | - | GEMM Split-K Parallel | 
| cutlass.gemm.Mode.Batched | - | batch size | Batched GEMM | 
| cutlass.gemm.Mode.Array | - | batch size | Array GEMM | 
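The modes in the table can be related to a plain NumPy reference, which is also handy for checking results. The sketch below is illustrative only: the function names are hypothetical, K is split into equal contiguous chunks (CUTLASS may align chunk boundaries to the tile size), and serial split-K is modeled as slices accumulating into D one after another, with beta applied only by the first slice.

```python
import numpy as np


def gemm_reference(a, b, c, alpha, beta):
    """Mode.Gemm with split_k_slices = 1: D = alpha * (A @ B) + beta * C."""
    return alpha * (a @ b) + beta * c


def split_k_bounds(k, split_k_slices):
    """At most split_k_slices contiguous chunks of K (the last may be shorter)."""
    chunk = -(-k // split_k_slices)  # ceiling division
    return [(start, min(start + chunk, k)) for start in range(0, k, chunk)]


def serial_split_k(a, b, c, alpha, beta, split_k_slices):
    """Serial split-K: each slice adds its partial product into D in turn."""
    d = None
    for start, end in split_k_bounds(a.shape[1], split_k_slices):
        partial = alpha * (a[:, start:end] @ b[start:end, :])
        # The first slice reads C scaled by beta; later slices read D back unscaled.
        d = partial + beta * c if d is None else partial + d
    return d


def batched_gemm(a, b, c, alpha, beta):
    """Mode.Batched / Mode.Array: an independent GEMM per index of the leading axis."""
    return alpha * np.matmul(a, b) + beta * c
```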
GEMM Grouped Arguments
The GEMM grouped arguments can be created with
```python
arguments = GemmGroupedArguments(
    operation, {problem_sizes_coord}, {tensor_As}, {tensor_Bs}, {tensor_Cs}, {tensor_Ds},
    output_op={output_op}
)
```
- problem_sizes_coord is a list of cutlass.gemm.GemmCoord(M, N, K), one per problem.
- tensor_Xs is a list of user-provided tensors, one per problem.
- output_op: the parameters of the epilogue functor.
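Unlike batched GEMM, each problem in a grouped GEMM may have its own M, N, and K, so a host-side reference for checking grouped results is simply a loop over the problems. In the sketch below, plain (M, N, K) tuples stand in for cutlass.gemm.GemmCoord, the function name is hypothetical, and all problems share one alpha and beta, mirroring the single output_op argument.

```python
import numpy as np


def grouped_gemm_reference(problem_sizes, tensor_As, tensor_Bs, tensor_Cs, alpha, beta):
    """D_i = alpha * (A_i @ B_i) + beta * C_i for every problem i in the group."""
    tensor_Ds = []
    for (m, n, k), a, b, c in zip(problem_sizes, tensor_As, tensor_Bs, tensor_Cs):
        # Check the operand shapes against this problem's (M, N, K).
        if a.shape != (m, k) or b.shape != (k, n) or c.shape != (m, n):
            raise ValueError(f"operand shapes do not match problem size {(m, n, k)}")
        tensor_Ds.append(alpha * (a @ b) + beta * c)
    return tensor_Ds


rng = np.random.default_rng(0)
problem_sizes = [(128, 64, 32), (16, 256, 8), (7, 9, 13)]
tensor_As = [rng.standard_normal((m, k)) for m, n, k in problem_sizes]
tensor_Bs = [rng.standard_normal((k, n)) for m, n, k in problem_sizes]
tensor_Cs = [rng.standard_normal((m, n)) for m, n, k in problem_sizes]
tensor_Ds = grouped_gemm_reference(problem_sizes, tensor_As, tensor_Bs, tensor_Cs, 1.0, 0.0)
print([d.shape for d in tensor_Ds])
```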
Conv2d Arguments
The Conv2d arguments can be created with
```python
arguments = Conv2dArguments(
    operation, {problem_size}, {tensor_A},
    {tensor_B}, {tensor_C}, {tensor_D},
    {output_op},
    {split_k_mode},
    {split_k_slices}
)
```
- problem_size: it can be constructed with
  ```python
  problem_size = cutlass.conv.Conv2dProblemSize(
      cutlass.Tensor4DCoord(N, H, W, C),
      cutlass.Tensor4DCoord(K, R, S, C),
      cutlass.Tensor4DCoord(pad[0], pad[1], pad[2], pad[3]),
      cutlass.MatrixCoord(stride[0], stride[1]),
      cutlass.MatrixCoord(dilation[0], dilation[1]),
      cutlass.conv.Mode.cross_correlation,
      split_k_slices, 1
  )
  ```
- tensor_X: user-provided tensors.
- output_op: the parameters of the epilogue functor.
- split_k_mode: currently we support cutlass.conv.SplitKMode.Serial and cutlass.conv.SplitKMode.Parallel.
- split_k_slices: the number of split-K slices.

For an ordinary Conv2d, use cutlass.conv.SplitKMode.Serial with split_k_slices=1.
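Before comparing against tensor_D, it can help to have a host reference for the forward convolution. The NumPy sketch below (a hypothetical function, not a PyCUTLASS API) computes an NHWC fprop in cross-correlation mode, matching cutlass.conv.Mode.cross_correlation in that the filter is not flipped. It assumes a KRSC filter layout (as in the Tensor4DCoord(K, R, S, C) above), symmetric padding, and float64 accumulation.

```python
import numpy as np


def conv2d_fprop_reference(x, w, pad=(0, 0), stride=(1, 1), dilation=(1, 1)):
    """NHWC activations x (N, H, W, C), KRSC filters w (K, R, S, C) -> NPQK output."""
    n, h, wd, c = x.shape
    k, r, s, c_filter = w.shape
    if c != c_filter:
        raise ValueError("channel counts of activations and filters differ")
    p = (h + 2 * pad[0] - dilation[0] * (r - 1) - 1) // stride[0] + 1
    q = (wd + 2 * pad[1] - dilation[1] * (s - 1) - 1) // stride[1] + 1
    # Zero-pad H and W only.
    xp = np.pad(x.astype(np.float64), ((0, 0), (pad[0], pad[0]), (pad[1], pad[1]), (0, 0)))
    y = np.zeros((n, p, q, k), dtype=np.float64)
    for ri in range(r):
        for si in range(s):
            h0, w0 = ri * dilation[0], si * dilation[1]
            # Input pixels touched by filter tap (ri, si) for every output position.
            patch = xp[:, h0:h0 + stride[0] * (p - 1) + 1:stride[0],
                       w0:w0 + stride[1] * (q - 1) + 1:stride[1], :]
            y += np.einsum("npqc,kc->npqk", patch, w[:, ri, si, :].astype(np.float64))
    return y


x = np.ones((1, 4, 4, 2))
w = np.ones((3, 3, 3, 2))
y = conv2d_fprop_reference(x, w, pad=(1, 1))
print(y.shape, y[0, 0, 0, 0], y[0, 1, 1, 0])
```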
Getting output_op
The output_op can be created as follows:
```python
output_op = operation.epilogue_type(*([alpha, beta] + activation_args))
```
Its arguments are the scaling factors alpha and beta, followed by activation_args, the (possibly empty) list of extra arguments required by the activation function.
The output_op of EpilogueVisitorTree is slightly different. Please check EpilogueVisitorTree for details.
Kernel Launching
With the operation and its arguments, the kernel can be launched with
```python
operation.run(arguments)
```
Sync results
We also provide a function to synchronize the kernel execution. If you use NumPy arrays, it also copies the result back to the host. To do that, run
```python
arguments.sync()
```
If you use EpilogueVisitorTree, call
```python
output_op.sync()
```
Reduction Kernel behind Parallel Split-K
If you use parallel split-K in GEMM or Conv2d, an additional reduction kernel is required. Please check examples/40_cutlass_py for details.
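The data flow of parallel split-K can be modeled in a few lines of NumPy. In this sketch (hypothetical names, assumed behavior), the GEMM stage writes one partial accumulator per K slice into a workspace, and a separate reduction stage sums the slices and then applies the epilogue. Placing alpha and beta in the reduction stage is an assumption about where the epilogue runs, chosen so the result matches an ordinary GEMM.

```python
import numpy as np


def parallel_split_k(a, b, c, alpha, beta, split_k_slices):
    """Model of parallel split-K: partial GEMMs into a workspace, then a reduction."""
    k = a.shape[1]
    chunk = -(-k // split_k_slices)  # ceiling division of K across the slices
    # GEMM stage: slice i writes its own partial accumulator of shape (M, N).
    workspace = np.stack([
        a[:, i * chunk:(i + 1) * chunk] @ b[i * chunk:(i + 1) * chunk, :]
        for i in range(split_k_slices)
    ])
    # Reduction stage: sum the partials over the slice axis, then apply the epilogue.
    d = alpha * workspace.sum(axis=0) + beta * c
    return d, workspace
```

The workspace holds split_k_slices partial results of size M x N, which is the extra memory that parallel split-K trades for more parallelism along K.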
