parent c008b4aea8
commit 557be3ab0e
@@ -1677,7 +1677,7 @@ template<typename Element , typename Layout > </div>
 </tr>
 </table>
 </div><div class="memdoc">
-<p>Returns a pair containing a boolean of whether a value exists in a tensor and the location of of the first occurrence. If the value is not contained in the tensor, the second element of the pair is undefined. </p>
+<p>Returns a pair containing a boolean of whether a value exists in a tensor and the location of the first occurrence. If the value is not contained in the tensor, the second element of the pair is undefined. </p>
 
 </div>
 </div>
@@ -49,7 +49,7 @@ class gen_device:
         self.arg_member = []
         self.gen_class_name = gen_class_name
         self.gen_kernel_name = gen_class_name + "Kernel"
-        self.tempalte_args = []
+        self.template_args = []
         self.__tempalate_arg_list = {'Stages': int, 'SplitKSerial': bool, 'IsBetaZero': bool, 'AlignmentA': int, 'AlignmentB': int}
 
         self.file_name = output_dir + "/device/" +gen_class_name +".h"
@@ -63,7 +63,7 @@ class gen_device:
         self.first_use_1stage = False
 
         ## gen kernel
-        self.gen_kernel = gen_ker.gen_kernel(self.tempalte_args, self.gen_class_name, self.b2b_num, output_dir, cutlass_deps_root, project_root)
+        self.gen_kernel = gen_ker.gen_kernel(self.template_args, self.gen_class_name, self.b2b_num, output_dir, cutlass_deps_root, project_root)
 
 
     def __check_arg_type(self, temp_arg):
@@ -126,7 +126,7 @@ class gen_device:
         func_code = self.gen_all_func()
         member_var_code = "private:\n typename B2bGemmKernel::Params params_;\n"
 
-        gen_code = gen_ir.gen_template_class(self.gen_class_name, self.tempalte_args, func_code + member_var_code)
+        gen_code = gen_ir.gen_template_class(self.gen_class_name, self.template_args, func_code + member_var_code)
         code = self.gen_include_header() + gen_ir.gen_namespace("cutlass", gen_ir.gen_namespace("gemm", gen_ir.gen_namespace("device", gen_code)))
 
         if ifprint:
@@ -142,7 +142,7 @@ class gen_device:
 
     def update_b2b_class_template_args(self):
         for arg in self.args.keys():
-            self.tempalte_args.append([self.__check_arg_type(arg), arg, self.args[arg]])
+            self.template_args.append([self.__check_arg_type(arg), arg, self.args[arg]])
 
     def update_b2b_args(self):
 
@@ -444,7 +444,7 @@ class gen_kernel:
 
         self.gen_class_name = "B2bGemm"
         self.gen_kernel_name = gen_class_name + "Kernel"
-        self.tempalte_args = []
+        self.template_args = []
 
         self.cutlass_deps_root = cutlass_deps_root
         self.project_root = project_root
@@ -957,13 +957,13 @@ public:\n\
 
     def gen_code(self):
 
-        tempalte_arg = []
+        template_arg = []
         for i in range(self.b2b_num):
-            tempalte_arg.append(("typename", helper.var_idx("Shape", i)))
+            template_arg.append(("typename", helper.var_idx("Shape", i)))
         for i in range(self.b2b_num):
-            tempalte_arg.append(("typename", helper.var_idx("Policy", i)))
+            template_arg.append(("typename", helper.var_idx("Policy", i)))
         for i in range(self.b2b_num):
-            tempalte_arg.append((int, helper.var_idx("Stage", i)))
+            template_arg.append((int, helper.var_idx("Stage", i)))
 
 
 
@@ -971,7 +971,7 @@ public:\n\
         code_body += self.gen_protected()
         code_body += self.gen_public_member()
 
-        class_code = gen_ir.gen_template_class("B2bMmaBase", tempalte_arg, code_body)
+        class_code = gen_ir.gen_template_class("B2bMmaBase", template_arg, code_body)
 
         code = self.gen_include_header() + gen_ir.gen_namespace("cutlass", gen_ir.gen_namespace("gemm", gen_ir.gen_namespace("threadblock", class_code)))
 
@@ -68,7 +68,7 @@
 "source": [
 "## Define the epilogue visitor functor\n",
 "The epilogue functor can be defined as a simple Python function and a set of example tensors for inputs and outputs. The example below illustrates a complex epilogue under the directed acyclic graph structure (`F` is used twice). The epilogue takes source tensors in different ranks: `alpha`, `beta` are scalars, `bias` is a column vector to broadcast, and `C`, `aux` are matrices. It contains various math operations from basic arithmatic operations and built-in callable functions like `relu`. It also accomodates multiple outputs `D` and `F`. Note that there are some restrictions on syntax.\n",
-"* Each named variable must be assigned exactly once and defined before it it used.\n",
+"* Each named variable must be assigned exactly once and defined before it used.\n",
 "* Reserved names: `accum`, `C`, and `D` are reserved for accumulator, tensor_C, and tensor_D.\n",
 "* Return values must be a named variable.\n",
 "\n",
@@ -123,7 +123,7 @@ transform(Tensor<EngineIn,LayoutIn>&& tensor_in, Tensor<EngineOut,LayoutOut>&& t
 
 // Similar to std::transform with a binary operation
 // Takes two tensors as input and one tensor as output.
-// Applies the binary_op to tensor_in1 and and tensor_in2 and
+// Applies the binary_op to tensor_in1 and tensor_in2 and
 // assigns it to tensor_out
 template <class EngineIn1, class LayoutIn1,
           class EngineIn2, class LayoutIn2,
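
Editor's note: for orientation, here is a minimal usage sketch of the binary `transform` overload this hunk documents. The argument order (in1, in2, out, op) and the static tensor sizes are assumptions taken from the comment, not part of the commit.

```cpp
#include <cute/tensor.hpp>

void binary_transform_example() {
  using namespace cute;
  // Owning stack tensors with static shapes (sizes are illustrative).
  auto a = make_tensor<float>(make_shape(Int<4>{}, Int<8>{}));
  auto b = make_tensor<float>(make_shape(Int<4>{}, Int<8>{}));
  auto c = make_tensor<float>(make_shape(Int<4>{}, Int<8>{}));
  fill(a, 1.0f);
  fill(b, 2.0f);
  // Applies the binary op elementwise: c(i) = a(i) + b(i), like std::transform.
  transform(a, b, c, [](float x, float y) { return x + y; });
}
```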
@@ -576,7 +576,7 @@ depth(Layout<Shape,Stride> const& layout)
 
 // Return the codomain shape of a mode
 // @post size(coshape(@a a)) == cosize(@a a)
-// @return C Coordinate with smallest elements such that that
+// @return C Coordinate with smallest elements such that
 //           @a elem_less(sub_layout(c), C) for all c < size(@a sub_layout)
 //           where sub_layout = get<Is...>(layout).
 template <int... Is, class Shape, class Stride>
@@ -527,7 +527,7 @@ public:
 auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
 auto work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape);
 
-// Allocate the the accumulators for the (M,N) blk_shape
+// Allocate the accumulators for the (M,N) blk_shape
 //
 // MSVC CTAD breaks if we say "Tensor" here, so we use "auto" instead.
 auto accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N)
@@ -540,7 +540,7 @@ public:
 auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
 auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
 
-// Allocate the the accumulators for the (M,N) blk_shape
+// Allocate the accumulators for the (M,N) blk_shape
 Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N)
 
 // Order two Math WG's MMA one after the other, helps hide Epilogue
@@ -347,7 +347,7 @@ public:
 // The number of tiles for which reduction is required is either:
 //   (a) the total number of output tiles (in the case of split-K)
 //   (b) the number of stream-K tiles
-// To calculate the the total number of output tiles in the split-K case, we
+// To calculate the total number of output tiles in the split-K case, we
 // note that, in the split-K case, the units_per_problem_ member of Params will be
 // the total number of output tiles.
 auto reduction_tiles = params.splits_ > 1 ? params.units_per_problem_ : params.sk_tiles_;
@@ -556,7 +556,7 @@ public:
 constexpr auto WarpThreadLayout = make_layout(make_shape(Int<WarpThreadShapeN>{}, Int<WarpThreadShapeK>{}));
 //////////////////////////////////////////////////////////////////////////////////////////////////////////////
 /// A warp group uses 8 steps to transpose the whole WarpgroupTileSize x WarpgroupTileSize.
-/// Divide a warp_group_tile into 8x8 warp_tiles to futher reduce the reg usage.
+/// Divide a warp_group_tile into 8x8 warp_tiles to further reduce the reg usage.
 /// Step 0: Step 1: Step 2: Step 3:
 /// W0 W1 W2 W3 -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- --
 /// W1 W0 -- -- -- -- -- -- -- -- W3 W2 -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- --
@@ -47,7 +47,7 @@ namespace cutlass {
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-/// Wmma array type (WmmaFragmentArray holds elements of of type nvcuda::wmma::fragment)
+/// Wmma array type (WmmaFragmentArray holds elements of type nvcuda::wmma::fragment)
 template <
   /// Element type
   typename T,
@@ -116,7 +116,7 @@ would include the following.
    access instructions (like `cp.async`), then dispatch to the
    custom instruction.
 
-2. The the two `Tensor`s have static layouts and it can be proven
+2. The two `Tensor`s have static layouts and it can be proven
    that element vectorization is valid -- for example, four `LDS.32`s
    can be combined into a single `LDS.128` -- then vectorize the source
    and destinations tensors.
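
Editor's note: case 2 above is what static layouts buy you. A hedged sketch of the caller's side (the function name, pointer parameters, and sizes are illustrative assumptions, not from the commit):

```cpp
#include <cute/tensor.hpp>

__device__ void vectorized_copy_sketch(float const* gmem_src, float* smem_dst) {
  using namespace cute;
  // Fully static layouts: cute::copy can prove the four 32-bit elements
  // are contiguous and may combine them into a single 128-bit access.
  Tensor src = make_tensor(make_gmem_ptr(gmem_src), make_layout(make_shape(Int<4>{})));
  Tensor dst = make_tensor(make_smem_ptr(smem_dst), make_layout(make_shape(Int<4>{})));
  copy(src, dst);  // may lower to one 128-bit copy instead of four 32-bit copies
}
```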
@@ -37,7 +37,7 @@ and the `Layout`s of threads and values within the operation.
 The `MMA_Traits` struct takes the Operation as a template parameter.
 CuTe specializes `MMA_Traits` for each Operation type that it supports.
 
-Together, these two types comprise an "Atom" that decouples the complexity of thread and data layouts from the call site of of the PTX instruction. The Atom's Traits struct exposes information that is relevant to a single MMA operation, no matter the granularity at which it operates.
+Together, these two types comprise an "Atom" that decouples the complexity of thread and data layouts from the call site of the PTX instruction. The Atom's Traits struct exposes information that is relevant to a single MMA operation, no matter the granularity at which it operates.
 
 CuTe MMA atoms expose the semantics of a single MMA operation.
 This is true regardless of the hardware level at which the MMA operates.
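
Editor's note: to ground the Atom terminology, a small sketch of forming an atom from one of CuTe's provided Operation tags (typical usage assumed; not part of the commit):

```cpp
#include <cute/atom/mma_atom.hpp>         // MMA_Atom
#include <cute/atom/mma_traits_sm80.hpp>  // MMA_Traits specializations for SM80 ops

void mma_atom_sketch() {
  using namespace cute;
  // One PTX-level mma instruction paired with its traits (thread/value
  // layouts and the logical M-N-K shape of the single operation).
  using Atom = MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>;
  // The traits expose per-operation metadata, e.g. the 16x8x16 shape.
  static_assert(size<0>(typename Atom::Shape_MNK{}) == 16, "M of one MMA");
}
```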
@@ -255,7 +255,7 @@ int bar()
 }
 ```
 
-"Static" is an unfortunately overloaded term in C++. Sometimes it means "the opposite of instance," like a "static function" or "static member" of a class. (Some programming languages, like Java, say "class method" to refer to a "static function of a class.") That's not what we mean here. Instead, we mean "part of a compile-time type." For example, `Int<1>` encodes the value 1 at compile time, as part of the type of a templated class `Int<Value>`. `Int<3>` and `Int<4>` have different types. You can get the value of of the type like this: `Int<3>::value`. (The `value` is a `static constexpr` member of the class, where "static" means "opposite of instance.") As soon as you go from `Int<3>` to `Int<3>::value`, you've gone from (3) above (a compile-time value) to (2) above (a `constexpr` value). In some situations, this may mean that the compiler treats it as a run-time value.
+"Static" is an unfortunately overloaded term in C++. Sometimes it means "the opposite of instance," like a "static function" or "static member" of a class. (Some programming languages, like Java, say "class method" to refer to a "static function of a class.") That's not what we mean here. Instead, we mean "part of a compile-time type." For example, `Int<1>` encodes the value 1 at compile time, as part of the type of a templated class `Int<Value>`. `Int<3>` and `Int<4>` have different types. You can get the value of the type like this: `Int<3>::value`. (The `value` is a `static constexpr` member of the class, where "static" means "opposite of instance.") As soon as you go from `Int<3>` to `Int<3>::value`, you've gone from (3) above (a compile-time value) to (2) above (a `constexpr` value). In some situations, this may mean that the compiler treats it as a run-time value.
 
 #### Strides
 
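
Editor's note: a compact illustration of the three categories that paragraph distinguishes (an illustrative sketch, not part of the commit):

```cpp
#include <cute/numeric/integral_constant.hpp>

void static_categories_sketch() {
  using namespace cute;
  Int<3> three{};                     // (3) compile-time value: the 3 lives in the type
  static_assert(Int<3>::value == 3);  // still usable in compile-time contexts
  constexpr int v = Int<3>::value;    // (2) now a constexpr value
  int w = v + 1;                      // (1) here the compiler may treat it as run-time
  (void)three; (void)w;
}
```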
@@ -56,7 +56,7 @@ You may explicitly exclude cuBLAS and cuDNN as dependencies with the following C
 
 ## Build and run the CUTLASS Profiler
 
-From the `build/` directory created above, compile the the CUTLASS Profiler.
+From the `build/` directory created above, compile the CUTLASS Profiler.
 ```bash
 $ make cutlass_profiler -j12
 ```
@@ -696,7 +696,7 @@ bool TestAllConv2d(
     return false;
   }
 
-  // If CUTLASS_UNIT_TEST_PROBLEM_COUNT is set reduce the the number of tested problem counts
+  // If CUTLASS_UNIT_TEST_PROBLEM_COUNT is set reduce the number of tested problem counts
   if (CutlassUnitTestProblemCount() &&
       testbed.tested_problem_count > CutlassUnitTestProblemCount()) {
     return true;
@@ -742,7 +742,7 @@ bool TestAllConv2d(
   }
 
   // Sweep split-k-slice using serial and prallel reduction with non-unity alpha and non-zero beta for
   // a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters
-  // which are abolutely neccessary to catch functional bugs. The below code does provide option to sweep
+  // which are abolutely necessary to catch functional bugs. The below code does provide option to sweep
   // alpha and beta for local testing, but only runs one value for alpha and beta.
   cutlass::conv::Conv2dProblemSize conv2d_split_k_test_size (
     {1, 17, 11, 288},   // input size (NHWC)
@@ -784,7 +784,7 @@ bool TestAllConv2d(
     return false;
   }
 
-  // If CUTLASS_UNIT_TEST_PROBLEM_COUNT is set reduce the the number of tested problem counts
+  // If CUTLASS_UNIT_TEST_PROBLEM_COUNT is set reduce the number of tested problem counts
   if (CutlassUnitTestProblemCount() &&
      testbed.tested_problem_count > CutlassUnitTestProblemCount()) {
     return true;
@@ -609,7 +609,7 @@ bool TestAllInterleavedConv2d(
 #if 0
   // Sweep split-k-slice using serial and prallel reduction with non-unity alpha and non-zero beta for
   // a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters
-  // which are abolutely neccessary to catch functional bugs. The below code does provide option to sweep
+  // which are abolutely necessary to catch functional bugs. The below code does provide option to sweep
   // alpha and beta for local testing, but only runs one value for alpha and beta.
   cutlass::conv::Conv2dProblemSize conv2d_split_k_test_size (
     {1, 17, 11, 288},   // input size (NHWC)
@@ -632,7 +632,7 @@ bool TestAllConv2dWithBroadcast(
 
   // Sweep split-k-slice using serial and prallel reduction with non-unity alpha and non-zero beta for
   // a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters
-  // which are abolutely neccessary to catch functional bugs. The below code does provide option to sweep
+  // which are abolutely necessary to catch functional bugs. The below code does provide option to sweep
   // alpha and beta for local testing, but only runs one value for alpha and beta.
   cutlass::conv::Conv2dProblemSize conv2d_split_k_test_size (
     {1, 17, 11, 288},   // input size (NHWC)
@@ -587,7 +587,7 @@ bool TestAllConv2dWithReduction(
 
   // Sweep split-k-slice using serial and prallel reduction with non-unity alpha and non-zero beta for
   // a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters
-  // which are abolutely neccessary to catch functional bugs. The below code does provide option to sweep
+  // which are abolutely necessary to catch functional bugs. The below code does provide option to sweep
   // alpha and beta for local testing, but only runs one value for alpha and beta.
   cutlass::conv::Conv2dProblemSize conv2d_split_k_test_size (
     {1, 17, 11, 288},   // input size (NHWC)
@@ -613,7 +613,7 @@ bool TestAllConv3d(
 
   // Sweep split-k-slice using serial reduction with non-unity alpha and non-zero beta for
   // a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters
-  // which are abolutely neccessary to catch functional bugs. The below code does provide option to sweep
+  // which are abolutely necessary to catch functional bugs. The below code does provide option to sweep
   // alpha and beta for local testing, but only runs one value for alpha and beta.
   cutlass::conv::Conv3dProblemSize conv3d_split_k_test_size (
     {1, 8, 8, 8, 32},   // input size (NDHWC)