diff --git a/include/cutlass/transform/threadblock/predicated_tile_access_iterator_params.h b/include/cutlass/transform/threadblock/predicated_tile_access_iterator_params.h index f284b0ae..e509ddc6 100755 --- a/include/cutlass/transform/threadblock/predicated_tile_access_iterator_params.h +++ b/include/cutlass/transform/threadblock/predicated_tile_access_iterator_params.h @@ -46,6 +46,7 @@ #include "cutlass/cutlass.h" #include "cutlass/array.h" +#include "cutlass/detail/helper_macros.hpp" #include "cutlass/layout/matrix.h" #include "cutlass/layout/pitch_linear.h" @@ -60,8 +61,8 @@ namespace threadblock { /// Predicated tile access iterator descriptor object containing template dependent state struct PredicatedTileAccessIteratorDesc { - int element_size_bits; - int advance_rank; + int element_size_bits = -1; + int advance_rank = -1; layout::PitchLinearCoord threadblock_shape; layout::PitchLinearCoord threadmap_iterations; layout::PitchLinearCoord threadmap_delta; @@ -71,7 +72,7 @@ struct PredicatedTileAccessIteratorDesc { // CUTLASS_HOST_DEVICE - PredicatedTileAccessIteratorDesc() { } + PredicatedTileAccessIteratorDesc() = default; CUTLASS_HOST_DEVICE PredicatedTileAccessIteratorDesc( @@ -232,16 +233,16 @@ struct PredicatedTileAccessIteratorParams { // Data members // /// stride of pitch-linear layout (units of Element) - LongIndex stride_; + LongIndex stride_ = 0; /// amount (in byte) to increment pointer to move to next access along /// strided dimension - LongIndex inc_strided_; + LongIndex inc_strided_ = 0; /// amount (in byte) to increment pointer from last access to first access /// of next tile - LongIndex inc_next_; + LongIndex inc_next_ = 0; /// amount (in byte) to increment pointer from first access of current tile /// to first access of next tile - LongIndex inc_advance_; + LongIndex inc_advance_ = 0; // // Methods @@ -249,6 +250,8 @@ struct PredicatedTileAccessIteratorParams { CUTLASS_HOST_DEVICE Status initialize(LongIndex stride, PredicatedTileAccessIteratorDesc desc) { + CUTLASS_ASSERT(desc.element_size_bits > 0); + CUTLASS_ASSERT(desc.advance_rank == 0 || desc.advance_rank == 1); stride_ = stride; @@ -277,9 +280,7 @@ struct PredicatedTileAccessIteratorParams { } CUTLASS_HOST_DEVICE - PredicatedTileAccessIteratorParams() { - initialize(LongIndex(0), PredicatedTileAccessIteratorDesc()); - } + PredicatedTileAccessIteratorParams() = default; CUTLASS_HOST_DEVICE PredicatedTileAccessIteratorParams(Index stride, PredicatedTileAccessIteratorDesc desc) {