Reduce smem usage for Q and dO in the backward pass
Each buffer shrinks from 4KB to 2KB. This saves 8KB of smem in total, since Q and dO each have 2 buffers.
This commit is contained in:
parent
2712aa4c8d
commit
b17c6fe235
@ -92,8 +92,14 @@ struct Smem_tile_without_skews {
|
|||||||
// The number of STS in total.
|
// The number of STS in total.
|
||||||
enum { STS = STS_PER_COL * STS_PER_ROW };
|
enum { STS = STS_PER_COL * STS_PER_ROW };
|
||||||
|
|
||||||
|
// TD [2022-06-02] In the case of Q (16 x 64) in the backward pass with 256 threads,
|
||||||
|
// we only need to store 16 * 64 * 2 = 2KB instead of 4KB.
|
||||||
|
static constexpr bool PARTIAL_STORE = ROWS_PER_STS > ROWS;
|
||||||
|
static constexpr int STORING_THREADS = PARTIAL_STORE ? ROWS * THREADS_PER_ROW : Cta_tile::THREADS_PER_CTA;
|
||||||
|
|
||||||
// The size of one buffer in bytes in shared memory.
|
// The size of one buffer in bytes in shared memory.
|
||||||
enum { BYTES_PER_BUFFER = STS * BYTES_PER_STS * Cta_tile::THREADS_PER_CTA };
|
// enum { BYTES_PER_BUFFER = STS * BYTES_PER_STS * Cta_tile::THREADS_PER_CTA };
|
||||||
|
enum { BYTES_PER_BUFFER = STS * BYTES_PER_STS * STORING_THREADS };
|
||||||
// The number of buffers.
|
// The number of buffers.
|
||||||
enum { BUFFERS_PER_TILE = BUFFERS_PER_TILE_ };
|
enum { BUFFERS_PER_TILE = BUFFERS_PER_TILE_ };
|
||||||
// The size in bytes of total buffers.
|
// The size in bytes of total buffers.
|
||||||
@ -116,7 +122,7 @@ struct Smem_tile_without_skews {
|
|||||||
|
|
||||||
// Ctor.
|
// Ctor.
|
||||||
inline __device__ Smem_tile_without_skews(void *smem, int tidx)
|
inline __device__ Smem_tile_without_skews(void *smem, int tidx)
|
||||||
: smem_(__nvvm_get_smem_pointer(smem)) {
|
: smem_(__nvvm_get_smem_pointer(smem)), tidx_(tidx) {
|
||||||
|
|
||||||
// The row written by a thread. See doc/mma_smem_layout.xlsx.
|
// The row written by a thread. See doc/mma_smem_layout.xlsx.
|
||||||
int smem_write_row = tidx / THREADS_PER_ROW;
|
int smem_write_row = tidx / THREADS_PER_ROW;
|
||||||
@ -268,7 +274,10 @@ struct Smem_tile_without_skews {
|
|||||||
inline __device__ void store(const Store_type (&data)[N], uint64_t = 0) {
|
inline __device__ void store(const Store_type (&data)[N], uint64_t = 0) {
|
||||||
uint32_t smem_ptrs[N];
|
uint32_t smem_ptrs[N];
|
||||||
this->compute_store_pointers(smem_ptrs);
|
this->compute_store_pointers(smem_ptrs);
|
||||||
sts(smem_ptrs, data);
|
// Reduce the shared mem for Q from 4KB to 2KB per buffer: with PARTIAL_STORE, only threads whose row (tidx_ / THREADS_PER_ROW) is below ROWS store.
|
||||||
|
if (!PARTIAL_STORE || (tidx_ / THREADS_PER_ROW < ROWS)) {
|
||||||
|
sts(smem_ptrs, data);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Store to the tile in shared memory.
|
// Store to the tile in shared memory.
|
||||||
@ -302,6 +311,7 @@ struct Smem_tile_without_skews {
|
|||||||
// int smem_read_buffer_;
|
// int smem_read_buffer_;
|
||||||
// The buffer base offset for write.
|
// The buffer base offset for write.
|
||||||
// int smem_write_buffer_;
|
// int smem_write_buffer_;
|
||||||
|
const int tidx_;
|
||||||
};
|
};
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user