From b17c6fe235b29c091f154c00c526dffa7ec4cce8 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Fri, 3 Jun 2022 16:59:11 -0700 Subject: [PATCH] Reduce smem usage for Q and dO in the backward pass From 4KB per buffer to 2KB per buffer. This saves us 8KB of smem (each Q and dO have 2 buffers) --- csrc/flash_attn/src/fmha/smem_tile.h | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/csrc/flash_attn/src/fmha/smem_tile.h b/csrc/flash_attn/src/fmha/smem_tile.h index 7f6dd2f..5e67a34 100644 --- a/csrc/flash_attn/src/fmha/smem_tile.h +++ b/csrc/flash_attn/src/fmha/smem_tile.h @@ -92,8 +92,14 @@ struct Smem_tile_without_skews { // The number of STS in total. enum { STS = STS_PER_COL * STS_PER_ROW }; + // TD [2022-06-02] In the case of Q (16 x 64) in the backward pass with 256 threads, + // we only need to store 16 * 64 * 2 = 2KB instead of 4KB. + static constexpr bool PARTIAL_STORE = ROWS_PER_STS > ROWS; + static constexpr int STORING_THREADS = PARTIAL_STORE ? ROWS * THREADS_PER_ROW : Cta_tile::THREADS_PER_CTA; + // The size of one buffer in bytes in shared memory. - enum { BYTES_PER_BUFFER = STS * BYTES_PER_STS * Cta_tile::THREADS_PER_CTA }; + // enum { BYTES_PER_BUFFER = STS * BYTES_PER_STS * Cta_tile::THREADS_PER_CTA }; + enum { BYTES_PER_BUFFER = STS * BYTES_PER_STS * STORING_THREADS }; // The number of buffers. enum { BUFFERS_PER_TILE = BUFFERS_PER_TILE_ }; // The size in bytes of total buffers. @@ -116,7 +122,7 @@ struct Smem_tile_without_skews { // Ctor. inline __device__ Smem_tile_without_skews(void *smem, int tidx) - : smem_(__nvvm_get_smem_pointer(smem)) { + : smem_(__nvvm_get_smem_pointer(smem)), tidx_(tidx) { // The row written by a thread. See doc/mma_smem_layout.xlsx. 
int smem_write_row = tidx / THREADS_PER_ROW; @@ -268,7 +274,10 @@ struct Smem_tile_without_skews { inline __device__ void store(const Store_type (&data)[N], uint64_t = 0) { uint32_t smem_ptrs[N]; this->compute_store_pointers(smem_ptrs); - sts(smem_ptrs, data); + // When PARTIAL_STORE is set, only threads whose row (tidx_ / THREADS_PER_ROW) is below ROWS store, so Q needs 2KB per buffer instead of 4KB. + if (!PARTIAL_STORE || (tidx_ / THREADS_PER_ROW < ROWS)) { + sts(smem_ptrs, data); + } } // Store to the tile in shared memory. @@ -302,6 +311,7 @@ struct Smem_tile_without_skews { // int smem_read_buffer_; // The buffer base offset for write. // int smem_write_buffer_; + const int tidx_; }; ////////////////////////////////////////////////////////////////////////////////////////////////////