diff --git a/csrc/flash_attn/src/fmha/gmem_tile.h b/csrc/flash_attn/src/fmha/gmem_tile.h index 428f0bc..52824ed 100644 --- a/csrc/flash_attn/src/fmha/gmem_tile.h +++ b/csrc/flash_attn/src/fmha/gmem_tile.h @@ -34,6 +34,18 @@ namespace fmha { +// template +// inline __device__ void atomic_add_CAS(half2_t *address, const half2_t val) { +// uint32_t *address_as_ui = (uint32_t *)address; +// uint32_t old = *address_as_ui; +// uint32_t assumed; +// do { +// assumed = old; +// half2_t sum = __hadd2(val, reinterpret_cast(old)); +// old = atomicCAS(address_as_ui, assumed, reinterpret_cast(sum)); +// } while (assumed != old); +// } + //////////////////////////////////////////////////////////////////////////////////////////////////// template< @@ -146,6 +158,7 @@ struct Gmem_tile_qkv { #pragma unroll for (int jj = 0; jj < 4; ++jj) { atomicAdd(ptr_ + jj, reinterpret_cast(data[ii])[jj]); + // atomic_add_CAS(ptr_ + jj, reinterpret_cast(data[ii])[jj]); } } }