Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
gemm_global_stream.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
30 #pragma once
31 
32 #include <cutlass/convert.h>
35 
36 namespace cutlass {
37 namespace gemm {
38 
40 
/// Base class for a GEMM global load stream. copy() loads a fragment through
/// LoadIterator into fetched_fragment; commit() ends by calling
/// store_iterator.inc_stage().
///
/// NOTE(review): this is a Doxygen source listing with comment lines stripped.
/// The gaps in the listing numbers also hide code lines. In particular, the
/// `struct GlobalLoadStreamBase {` line (listing line 49) is not shown here;
/// the name comes from the page's cross-reference entries.
41 template <
/// The load iterator.
43  typename LoadIterator_,
/// The store iterator to write to shared memory.
45  typename StoreIterator_,
/// The transformer.
47  typename Transformer_>
48 
/// The load iterator.
51  typedef LoadIterator_ LoadIterator;
/// The transformer.
53  typedef Transformer_ Transformer;
/// The store iterator to write to shared memory.
55  typedef StoreIterator_ StoreIterator;
56 
/// The fragment filled by the load iterator (see copy()).
58  typedef typename LoadIterator::Fragment FetchedFragment;
/// The fragment type produced by the transformer (Transformer::OutputFragment).
60  typedef typename Transformer::OutputFragment TransformedFragment;
/// NOTE(review): listing lines 61-62 and 64-67 are omitted. The page's
/// cross-references describe three items in this region:
///   - a "make sure the fragments match" check;
///   - a `typedef TransformedFragment Fragment`;
///   - a check that the transformed fragment is the same as the store fragment.
/// The two `"");` lines below are presumably the closing message arguments of
/// those assertions. Confirm against the original source.
63  "");
68  "");
69 
/// The matrix layout, taken from the load iterator.
71  static MatrixLayout::Kind const kLayout = LoadIterator::kLayout;
/// The scalar type of the iterator.
73  typedef typename LoadIterator::Scalar Scalar;
/// The pointer type of the load iterator.
75  typedef typename LoadIterator::Pointer Pointer;
/// The index type of the load iterator.
77  typedef typename LoadIterator::Index Index;
78 
/// The params: one Params object for each of the two iterators.
80  struct Params {
81  // The load iterator.
82  typename LoadIterator::Params load_iterator;
83  // The store iterator.
84  typename StoreIterator::Params store_iterator;
85 
/// Setup the params. Initializes load_iterator from (pointer, ld) and returns
/// its error code if that code is nonzero. Otherwise returns the result of
/// store_iterator.initialize().
/// NOTE(review): the signature line falls in the omitted listing lines 86-87.
/// Per the page's cross-references it is
/// `CUTLASS_HOST_DEVICE int initialize(Pointer pointer, Index ld) {`.
88  int error_code = load_iterator.initialize(pointer, ld);
89  if (error_code) {
// Stop at the first failure; the store iterator is then never initialized.
90  return error_code;
91  }
92 
93  return store_iterator.initialize();
94  }
95  };
96 
/// The amount of storage in shared memory needed to store the tile.
98  typedef typename StoreIterator::SharedStorage SharedStoreStorage;
99 
/// The storage in shared memory needed by the stream.
/// NOTE(review): two parts of this declaration are omitted: its opening line
/// (listing lines 100-101) and the `SharedStoreStorage store_iterator;` member
/// (listing line 105).
102  // The load iterator.
103  typename LoadIterator::SharedStorage load_iterator;
104  // The store iterator.
106  };
107 
/// Ctor. Performs these steps in order:
///   - constructs load_iterator from params.load_iterator, bounds and block;
///   - value-initializes transformer;
///   - constructs store_iterator from params.store_iterator and
///     shared_storage.store_iterator;
///   - clears fetched_fragment.
/// The initializer order matches the member declaration order given by the
/// page's cross-references (listing lines 140, 144, 148).
/// NOTE(review): `bounds` is taken by const value, whereas `block` (and
/// `bounds` in the GlobalLoadStream ctor) are taken by const reference.
109  CUTLASS_DEVICE GlobalLoadStreamBase(Params const& params,
110  SharedStorage& shared_storage,
111  Coord<3> const bounds,
112  Coord<3> const& block)
113  : load_iterator(params.load_iterator, bounds, block),
114  transformer(),
115  store_iterator(params.store_iterator, shared_storage.store_iterator)
116 
117  {
118  fetched_fragment.clear();
119  }
120 
/// Loads the next fragment through load_iterator into fetched_fragment
/// (via iterator_load).
122  CUTLASS_DEVICE void copy() { iterator_load(load_iterator, fetched_fragment); }
123 
/// Commit the data. The visible code ends by calling store_iterator.inc_stage().
/// NOTE(review): listing lines 126-127 are omitted. Given the cross-references
/// to transformer, transformed_fragment and iterator_store, they presumably
/// transform fetched_fragment and store the result through store_iterator.
/// Confirm against the original source.
125  CUTLASS_DEVICE void commit() {
128  store_iterator.inc_stage();
129  }
130 
/// Execute the residue code. Forwards k to load_iterator.residue(); unless
/// skip_clear is true, also clears fetched_fragment.
132  CUTLASS_DEVICE void residue(Index k, bool skip_clear = false) {
133  load_iterator.residue(k);
134  if (!skip_clear) {
135  fetched_fragment.clear();
136  }
137  }
138 
/// NOTE(review): the data members (listing lines 139-148) are omitted. Per the
/// page's cross-references they are, in declaration order:
///   - load_iterator (140)
///   - fetched_fragment (142)
///   - transformer (144)
///   - transformed_fragment (146)
///   - store_iterator (148)
149 };
150 
152 
/// GlobalLoadStream: the concrete stream type built on GlobalLoadStreamBase.
/// In this listing it only declares a constructor, which forwards its
/// arguments to the base class.
153 template <
/// The load iterator.
155  typename LoadIterator_,
/// The store iterator.
157  typename StoreIterator_,
/// The transformer. Defaults to Copy<> over the load iterator's fragment type.
/// Copy is presumably a pass-through conversion from cutlass/convert.h; confirm.
159  typename Transformer_ = Copy<typename LoadIterator_::Fragment> >
160 
161 struct GlobalLoadStream : public GlobalLoadStreamBase<LoadIterator_, StoreIterator_, Transformer_> {
/// NOTE(review): listing lines 162-163 are omitted. Per the page's
/// cross-references they declare
/// `typedef GlobalLoadStreamBase<LoadIterator_, StoreIterator_, Transformer_> Base;`,
/// which the constructor below relies on.
164 
/// Ctor. Forwards params, shared_storage, bounds and block unchanged to Base.
166  CUTLASS_DEVICE GlobalLoadStream(typename Base::Params const& params,
167  typename Base::SharedStorage& shared_storage,
168  Coord<3> const& bounds,
169  Coord<3> const& block)
170  : Base(params, shared_storage, bounds, block) {}
171 };
172 
174 } // namespace gemm
175 } // namespace cutlass
static MatrixLayout::Kind const kLayout
Make sure the transformed fragment is the same as the store fragment.
Definition: gemm_global_stream.h:71
StoreIterator::Params store_iterator
Definition: gemm_global_stream.h:84
Definition: convert.h:33
Defines iterators for efficiently loading and storing to global memory.
Transformer_ Transformer
The transformer.
Definition: gemm_global_stream.h:53
StoreIterator_ StoreIterator
The store iterator to write to shared memory.
Definition: gemm_global_stream.h:55
std::is_same (false specialization)
Definition: platform.h:412
StoreIterator::SharedStorage SharedStoreStorage
The amount of storage in shared memory needed to store the tile.
Definition: gemm_global_stream.h:98
TransformedFragment Fragment
Make sure the fragments match.
Definition: gemm_global_stream.h:63
TransformedFragment transformed_fragment
The fragment that holds the converted data after it has been fetched from global memory.
Definition: gemm_global_stream.h:146
CUTLASS_DEVICE void residue(Index k, bool skip_clear=false)
Execute the residue code.
Definition: gemm_global_stream.h:132
Definition: convert.h:69
CUTLASS_HOST_DEVICE int initialize(Pointer pointer, Index ld)
Setup the params.
Definition: gemm_global_stream.h:87
LoadIterator load_iterator
The iterator.
Definition: gemm_global_stream.h:140
LoadIterator::Params load_iterator
Definition: gemm_global_stream.h:82
Definition: gemm_global_stream.h:161
Free functions for loading and storing to implementations of tile iterator concepts.
LoadIterator::SharedStorage load_iterator
Definition: gemm_global_stream.h:103
CUTLASS_DEVICE GlobalLoadStream(typename Base::Params const &params, typename Base::SharedStorage &shared_storage, Coord< 3 > const &bounds, Coord< 3 > const &block)
Ctor.
Definition: gemm_global_stream.h:166
Definition: gemm_global_stream.h:49
StoreIterator store_iterator
The store iterator.
Definition: gemm_global_stream.h:148
LoadIterator::Pointer Pointer
The pointer.
Definition: gemm_global_stream.h:75
SharedStoreStorage store_iterator
Definition: gemm_global_stream.h:105
Transformer::OutputFragment TransformedFragment
The fragment that is obtained after the transformation by the transformer.
Definition: gemm_global_stream.h:60
LoadIterator::Scalar Scalar
The scalar type of the iterator.
Definition: gemm_global_stream.h:73
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
#define static_assert(__e, __m)
Definition: platform.h:145
LoadIterator::Index Index
The index.
Definition: gemm_global_stream.h:77
Transformer transformer
The transformer.
Definition: gemm_global_stream.h:144
GlobalLoadStreamBase< LoadIterator_, StoreIterator_, Transformer_ > Base
The base class.
Definition: gemm_global_stream.h:163
Statically-sized array specifying Coords within a tensor.
Definition: coord.h:48
LoadIterator::Fragment FetchedFragment
The fragment that is copied from global memory.
Definition: gemm_global_stream.h:58
The storage in shared memory needed by that stream.
Definition: gemm_global_stream.h:101
CUTLASS_HOST_DEVICE void iterator_store(OutputIterator &iterator, Fragment &fragment)
Stores a fragment to an output iterator.
Definition: iterator_access.h:193
FetchedFragment fetched_fragment
The fragment fetched from global memory.
Definition: gemm_global_stream.h:142
Kind
Definition: matrix_traits.h:36
LoadIterator_ LoadIterator
The load iterator.
Definition: gemm_global_stream.h:51
CUTLASS_HOST_DEVICE void iterator_load(InputIterator &iterator, Fragment &fragment)
Loads a fragment from an input iterator.
Definition: iterator_access.h:41
CUTLASS_DEVICE void commit()
Commit the data.
Definition: gemm_global_stream.h:125
CUTLASS_DEVICE void copy()
Load the data from global memory into the fetched fragment.
Definition: gemm_global_stream.h:122
CUTLASS_DEVICE GlobalLoadStreamBase(Params const &params, SharedStorage &shared_storage, Coord< 3 > const bounds, Coord< 3 > const &block)
Ctor.
Definition: gemm_global_stream.h:109
Defines conversion operations among Fragments of different base type.
The params.
Definition: gemm_global_stream.h:80