Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
gemm_global_stream.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
30 #pragma once
31 
32 #include "cutlass/coord.h"
33 #include "cutlass/convert.h"
36 
37 namespace cutlass {
38 namespace gemm {
39 
41 
/// GlobalLoadStream: fetches a GEMM operand fragment through LoadIterator (copy()),
/// converts it with Transformer, and writes the converted fragment out through
/// StoreIterator (commit()).
42 template <
/// Which GEMM operand this stream carries (compared against GemmOperand::kA in rollback()).
44  GemmOperand::Kind Operand,
/// Iterator that fetches the operand fragment.
46  typename LoadIterator_,
/// Iterator that writes the transformed fragment out.
48  typename StoreIterator_,
/// Converts the fetched fragment into Transformer::OutputFragment.
50  typename Transformer_>
51
/// Indicates the type of GEMM operand.
54  static GemmOperand::Kind const kOperand = Operand;
/// The load iterator.
56  typedef LoadIterator_ LoadIterator;
/// The transformer.
58  typedef Transformer_ Transformer;
/// The store iterator.
60  typedef StoreIterator_ StoreIterator;
61
/// The fragment produced by the load iterator.
63  typedef typename LoadIterator::Fragment FetchedFragment;
/// The fragment produced by the transformer.
65  typedef typename Transformer::OutputFragment TransformedFragment;
68  "");
73  "");
74
/// The matrix layout, taken from the load iterator.
76  static MatrixLayout::Kind const kLayout = LoadIterator::kLayout;
/// Scalar, pointer, index and tile types are all taken from the load iterator.
78  typedef typename LoadIterator::Scalar Scalar;
80  typedef typename LoadIterator::Pointer Pointer;
82  typedef typename LoadIterator::Index Index;
84  typedef typename LoadIterator::Tile Tile;
85 
89 
92 
/// Parameters of the stream: load-iterator params, store-iterator params, and the
/// offset (in units of the load iterator's stride_advance()) used by
/// move_to_residue() and rollback().
/// initialize() records the residue offset, then initializes the load-iterator params
/// and the store-iterator params. It returns the first nonzero error code; a nonzero
/// value presumably signals failure and 0 success — confirm with the iterator docs.
94  struct Params {
95  // The load iterator.
96  typename LoadIterator::Params load_iterator;
97  // The store iterator.
98  typename StoreIterator::Params store_iterator;
99  // Offset to residue.
101
104  long long batch_stride,
105  Index ldm,
106  Index _offset_to_residue) {
107
108  offset_to_residue = _offset_to_residue;
// Initialize the load-iterator params first; bail out early on a nonzero code so the
// store-iterator params are not initialized after a failure.
109  int error_code = load_iterator.initialize(pointer, batch_stride, ldm);
110  if (error_code) {
111  return error_code;
112  }
// Otherwise report whatever the store-iterator params initialization returns.
113  return store_iterator.initialize();
114  }
115  };
116 
/// The stream declares no SharedStorage members of its own; the constructor instead
/// receives a ThreadblockTileRef and builds the store iterator over its data().
120  struct SharedStorage {};
121 
122  //
123  // Static member functions
124  //
125 
/// Maps a coordinate in the GEMM's (K, N, M) coordinate system to global memory.
/// The returned coordinate is built from the intermediate tile_coord: d_offset is added
/// to component 0, component 1 is passed through, and component 2 is divided by
/// LoadIterator::Tile::kC. The constructor uses d_offset = 1 when projecting the bounds.
127  CUTLASS_DEVICE static Coord<3> project_coordinate(Coord<3> const& coord, Index d_offset = 0) {
128  bool const kKstrided =
131  return make_Coord(
132  tile_coord[0] + d_offset, tile_coord[1], tile_coord[2] / LoadIterator::Tile::kC);
133  }
134 
/// Ctor.
/// Projects `bounds` (with d_offset = 1) and `_threadblock_offset` through
/// project_coordinate() for the load iterator. It constructs the store iterator over
/// threadblock_tile_ref.data(), then initializes the load iterator's predicates and
/// clears the fetched fragment.
/// NOTE(review): `shared_storage` does not appear in the visible initializers or body.
136  CUTLASS_DEVICE GlobalLoadStream(
137  Params const& _params,
138  SharedStorage& shared_storage,
139  ThreadblockTileRef const& threadblock_tile_ref,
140  Coord<3> const bounds,
141  Coord<3> const& _threadblock_offset)
142  : params(_params),
144  threadblock_offset(project_coordinate(_threadblock_offset)),
146  project_coordinate(bounds, 1), /*multiplicand_bounds*/
147  project_coordinate(_threadblock_offset) /*threadblock_offset*/),
148  transformer(),
// Reads the `params` member rather than `_params`. This is safe because `params` is
// declared before `store_iterator` and so is initialized first.
149  store_iterator(params.store_iterator, threadblock_tile_ref.data())
150  {
151  load_iterator.initialize_predicates(multiplicand_bounds, threadblock_offset);
152  fetched_fragment.clear();
153  }
154 
155 
/// Loads the next fragment through the load iterator into fetched_fragment and
/// post-increments the load iterator.
157  CUTLASS_DEVICE void copy() { load_iterator.load_post_increment(fetched_fragment); }
158 
/// Commits the data: stores transformed_fragment through the store iterator
/// (post-increment), then advances the store iterator to its next stage.
/// NOTE(review): nothing visible here fills transformed_fragment from
/// fetched_fragment. Presumably `transformer` performs that conversion before the
/// store — confirm against the full header.
160  CUTLASS_DEVICE void commit() {
162  store_iterator.store_post_increment(transformed_fragment);
163  store_iterator.inc_stage();
164  }
165 
/// Executes the residue code: forwards k to the load iterator's residue() and, unless
/// skip_clear is true, clears the fetched fragment.
///
/// @param k          Residue amount forwarded to LoadIterator::residue().
/// @param skip_clear When true, fetched_fragment keeps its current contents.
167  CUTLASS_DEVICE void residue(Index k, bool skip_clear = false) {
168  load_iterator.residue(k);
169  if (!skip_clear) {
170  fetched_fragment.clear();
171  }
172  }
173 
/// Moves to the residue portion.
///
/// If k is not a multiple of kTileK, residue(k % kTileK) is applied first, which also
/// clears fetched_fragment. In every case the load iterator's pointer is then advanced
/// by params.offset_to_residue * load_iterator.stride_advance().
///
/// @param k      Extent whose remainder modulo kTileK is the residue (presumably the
///               GEMM K extent — confirm with callers).
/// @param kTileK Modulus for the residue; must be nonzero. Assuming Index is an integer
///               type, k % 0 is undefined behavior.
175  CUTLASS_DEVICE void move_to_residue(Index k, Index kTileK) {
176  Index kResidue = k % kTileK;
// Only a nonzero remainder needs the residue adjustment.
177  if (kResidue) {
178  residue(kResidue);
179  }
180  load_iterator.add_pointer_offset(params.offset_to_residue * load_iterator.stride_advance());
181  }
182 
/// Rolls back to the beginning of the first tile.
///
/// Re-initializes the load iterator's predicates from the stored multiplicand_bounds
/// and threadblock_offset. It then moves the pointer back by
/// (params.offset_to_residue + kBlock) * load_iterator.stride_advance().
184  CUTLASS_DEVICE void rollback(void) {
185  load_iterator.initialize_predicates(multiplicand_bounds, threadblock_offset);
186
// kBlock: for operand A, Tile::kH when column-major, else Tile::kW.
// For the other operand, Tile::kH when row-major, else Tile::kW.
// Presumably this is the tile extent along the strided (K) direction — confirm.
187  int const kBlock = kOperand == GemmOperand::kA
188  ? (kLayout == MatrixLayout::kColumnMajor ? Tile::kH : Tile::kW)
189  : (kLayout == MatrixLayout::kRowMajor ? Tile::kH : Tile::kW);
190
// Subtracts the offset that move_to_residue() adds, plus kBlock further strides.
191  load_iterator.add_pointer_offset(-(params.offset_to_residue + kBlock) *
192  load_iterator.stride_advance());
193  }
194 
/// Adds a Coord<3> to the underlying load iterator and returns *this for chaining.
196  CUTLASS_DEVICE GlobalLoadStream &operator+=(Coord<3> const &offset) {
197  load_iterator += offset;
198  return *this;
199  }
200 
201  //
202  // Data members
203  //
204 
221 };
222 
224 } // namespace gemm
225 } // namespace cutlass
ThreadblockTileStorage::TensorRef ThreadblockTileRef
Tensor reference to threadblock tile.
Definition: gemm_global_stream.h:91
LoadIterator::Pointer Pointer
The pointer.
Definition: gemm_global_stream.h:80
LoadIterator load_iterator
The iterator.
Definition: gemm_global_stream.h:212
Definition: convert.h:33
StoreIterator store_iterator
The store iterator.
Definition: gemm_global_stream.h:220
Params params
Parameters.
Definition: gemm_global_stream.h:206
Defines iterators for efficiently loading and storing to global memory.
std::is_same (false specialization)
Definition: platform.h:420
TensorRef< Scalar, 4 > TensorRef
Defines the tensor reference for this allocation.
Definition: tile_allocation.h:62
static GemmOperand::Kind const kOperand
Indicates the type of GEMM operand.
Definition: gemm_global_stream.h:54
CUTLASS_DEVICE GlobalLoadStream & operator+=(Coord< 3 > const &offset)
Adds a Coord<3> to the underlying global load iterator.
Definition: gemm_global_stream.h:196
CUTLASS_DEVICE void copy()
Load the data from global memory into the fetched fragment.
Definition: gemm_global_stream.h:157
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
Coord< 3 > multiplicand_bounds
Multiplicand bounds.
Definition: gemm_global_stream.h:208
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 1-element coordinate.
Definition: coord.h:318
static MatrixLayout::Kind const kLayout
The matrix layout of the operand, taken from LoadIterator::kLayout.
Definition: gemm_global_stream.h:76
StoreIterator::Params store_iterator
Definition: gemm_global_stream.h:98
FetchedFragment fetched_fragment
The fragment fetched from global memory.
Definition: gemm_global_stream.h:214
CUTLASS_HOST_DEVICE int initialize(Pointer pointer, long long batch_stride, Index ldm, Index _offset_to_residue)
Set up the params.
Definition: gemm_global_stream.h:103
Definition: gemm_global_stream.h:52
Definition: gemm_global_stream.h:120
LoadIterator::Scalar Scalar
The scalar type of the iterator.
Definition: gemm_global_stream.h:78
CUTLASS_DEVICE void residue(Index k, bool skip_clear=false)
Execute the residue code: forward k to the load iterator and, unless skip_clear is true, clear the fetched fragment.
Definition: gemm_global_stream.h:167
TransformedFragment transformed_fragment
The fragment that holds the data after the transformer has converted the fetched fragment.
Definition: gemm_global_stream.h:218
Definition: matrix_traits.h:159
Defines a fragment based on a Shape<> template.
Index offset_to_residue
Definition: gemm_global_stream.h:100
TransformedFragment Fragment
The fragment type exposed by the stream (an alias of TransformedFragment).
Definition: gemm_global_stream.h:68
LoadIterator_ LoadIterator
The load iterator.
Definition: gemm_global_stream.h:56
Definition: gemm_operand.h:67
CUTLASS_DEVICE void commit()
Commit the data.
Definition: gemm_global_stream.h:160
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
Class for storing a tile in memory and accessing it through a tensor ref.
Definition: tile_allocation.h:41
Transformer transformer
The transformer.
Definition: gemm_global_stream.h:216
#define static_assert(__e, __m)
Definition: platform.h:153
Definition: gemm_operand.h:96
StoreIterator_ StoreIterator
The store iterator to write to shared memory.
Definition: gemm_global_stream.h:60
Definition: matrix_traits.h:159
TileAllocation< typename StoreIterator::Scalar, typename StoreIterator::Tile > ThreadblockTileStorage
Shared memory allocation for the tile.
Definition: gemm_global_stream.h:88
LoadIterator::Params load_iterator
Definition: gemm_global_stream.h:96
The params.
Definition: gemm_global_stream.h:94
Transformer_ Transformer
The transformer.
Definition: gemm_global_stream.h:58
Kind
Enumeration defining fundamental contiguous layouts.
Definition: matrix_traits.h:159
Coord< 3 > threadblock_offset
Threadblock offset.
Definition: gemm_global_stream.h:210
LoadIterator::Index Index
The index.
Definition: gemm_global_stream.h:82
static CUTLASS_DEVICE Coord< 3 > project_coordinate(Coord< 3 > const &coord, Index d_offset=0)
Maps a coordinate in the GEMM's (K, N, M) coordinate system to global memory.
Definition: gemm_global_stream.h:127
Kind
Definition: matrix_traits.h:357
Definition: matrix_traits.h:357
CUTLASS_DEVICE void move_to_residue(Index k, Index kTileK)
Move to the residue portion.
Definition: gemm_global_stream.h:175
LoadIterator::Fragment FetchedFragment
The fragment that is loaded from global memory.
Definition: gemm_global_stream.h:63
Transformer::OutputFragment TransformedFragment
The fragment that is obtained after the transformation by the transformer.
Definition: gemm_global_stream.h:65
CUTLASS_DEVICE GlobalLoadStream(Params const &_params, SharedStorage &shared_storage, ThreadblockTileRef const &threadblock_tile_ref, Coord< 3 > const bounds, Coord< 3 > const &_threadblock_offset)
Ctor.
Definition: gemm_global_stream.h:136
Defines conversion operations among Fragments of different base type.
LoadIterator::Tile Tile
The tile.
Definition: gemm_global_stream.h:84
CUTLASS_DEVICE void rollback(void)
Roll back to the beginning of the first tile.
Definition: gemm_global_stream.h:184