Reduce smem usage for Q and dO in the backward pass

Each buffer shrinks from 4KB to 2KB. This saves 8KB of smem in total, since Q and
dO each have 2 buffers and every buffer shrinks by 2KB.
This commit is contained in:
Tri Dao 2022-06-03 16:59:11 -07:00
parent 2712aa4c8d
commit b17c6fe235

View File

@ -92,8 +92,14 @@ struct Smem_tile_without_skews {
// The number of STS in total. // The number of STS in total.
enum { STS = STS_PER_COL * STS_PER_ROW }; enum { STS = STS_PER_COL * STS_PER_ROW };
// TD [2022-06-02] In the case of Q (16 x 64) in the backward pass with 256 threads,
// we only need to store 16 * 64 * 2 bytes = 2KB instead of 4KB.
static constexpr bool PARTIAL_STORE = ROWS_PER_STS > ROWS;
static constexpr int STORING_THREADS = PARTIAL_STORE ? ROWS * THREADS_PER_ROW : Cta_tile::THREADS_PER_CTA;
// The size of one buffer in bytes in shared memory. // The size of one buffer in bytes in shared memory.
enum { BYTES_PER_BUFFER = STS * BYTES_PER_STS * Cta_tile::THREADS_PER_CTA }; // enum { BYTES_PER_BUFFER = STS * BYTES_PER_STS * Cta_tile::THREADS_PER_CTA };
enum { BYTES_PER_BUFFER = STS * BYTES_PER_STS * STORING_THREADS };
// The number of buffers. // The number of buffers.
enum { BUFFERS_PER_TILE = BUFFERS_PER_TILE_ }; enum { BUFFERS_PER_TILE = BUFFERS_PER_TILE_ };
// The size in bytes of total buffers. // The size in bytes of total buffers.
@ -116,7 +122,7 @@ struct Smem_tile_without_skews {
// Ctor. // Ctor.
inline __device__ Smem_tile_without_skews(void *smem, int tidx) inline __device__ Smem_tile_without_skews(void *smem, int tidx)
: smem_(__nvvm_get_smem_pointer(smem)) { : smem_(__nvvm_get_smem_pointer(smem)), tidx_(tidx) {
// The row written by a thread. See doc/mma_smem_layout.xlsx. // The row written by a thread. See doc/mma_smem_layout.xlsx.
int smem_write_row = tidx / THREADS_PER_ROW; int smem_write_row = tidx / THREADS_PER_ROW;
@ -268,8 +274,11 @@ struct Smem_tile_without_skews {
inline __device__ void store(const Store_type (&data)[N], uint64_t = 0) { inline __device__ void store(const Store_type (&data)[N], uint64_t = 0) {
uint32_t smem_ptrs[N]; uint32_t smem_ptrs[N];
this->compute_store_pointers(smem_ptrs); this->compute_store_pointers(smem_ptrs);
// To reduce the shared mem for Q from 4KB per buffer to 2KB per buffer, only threads whose row (tidx_ / THREADS_PER_ROW) is below ROWS perform the store.
if (!PARTIAL_STORE || (tidx_ / THREADS_PER_ROW < ROWS)) {
sts(smem_ptrs, data); sts(smem_ptrs, data);
} }
}
// Store to the tile in shared memory. // Store to the tile in shared memory.
template< int N, int M > template< int N, int M >
@ -302,6 +311,7 @@ struct Smem_tile_without_skews {
// int smem_read_buffer_; // int smem_read_buffer_;
// The buffer base offset for write. // The buffer base offset for write.
// int smem_write_buffer_; // int smem_write_buffer_;
const int tidx_;
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////