Fix the race condition in mixed-input GEMM when writing the registers (#1931)

* move two warpgroup_wait

* merge main

---------

Co-authored-by: Siyuan Fu <siyuanf@nvidia.com>
This commit is contained in:
Lain 2024-11-08 10:15:54 -08:00 committed by GitHub
parent d656afbd2a
commit 8aa95dbb88
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 4 additions and 3 deletions

View File

@@ -724,4 +724,4 @@ int main(int argc, char const **args) {
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -1024,8 +1024,8 @@ public:
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
warpgroup_commit_batch();
warpgroup_wait<K_BLOCK_MAX - 1>(); // We have K_BLOCK_MAX - 1 GMMA instructions pending for this stage, so we can release prior barrier
if (k_block == K_BLOCK_MAX - 1) {
warpgroup_wait<K_BLOCK_MAX - 1>(); // We have K_BLOCK_MAX - 1 GMMA instructions pending for this stage, so we can release prior barrier
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
++smem_pipe_release;
}
@@ -1076,8 +1076,9 @@ public:
cute::gemm(tiled_mma, tCrA_mma(_,_,k_block), tCrB(_,_,k_block,read_stage), accum);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
warpgroup_commit_batch();
warpgroup_wait<K_BLOCK_MAX - 1>();
if (k_block == K_BLOCK_MAX - 1) { // release prior barrier
warpgroup_wait<K_BLOCK_MAX - 1>();
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
++smem_pipe_release;
}