 cc3c29a81a
			
		
	
	
		cc3c29a81a
		
			
		
	
	
	
	
		
			
			* v3.6 * update changelog * update readme * fix typo * fixing typos * hopper gemm with weight prefetch --------- Co-authored-by: yuzhai <yuzhai@nvidia.com> Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
		
			
				
	
	
		
			423 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			423 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
| /***************************************************************************************************
 | |
|  * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 | |
|  * SPDX-License-Identifier: BSD-3-Clause
 | |
|  *
 | |
|  * Redistribution and use in source and binary forms, with or without
 | |
|  * modification, are permitted provided that the following conditions are met:
 | |
|  *
 | |
|  * 1. Redistributions of source code must retain the above copyright notice, this
 | |
|  * list of conditions and the following disclaimer.
 | |
|  *
 | |
|  * 2. Redistributions in binary form must reproduce the above copyright notice,
 | |
|  * this list of conditions and the following disclaimer in the documentation
 | |
|  * and/or other materials provided with the distribution.
 | |
|  *
 | |
|  * 3. Neither the name of the copyright holder nor the names of its
 | |
|  * contributors may be used to endorse or promote products derived from
 | |
|  * this software without specific prior written permission.
 | |
|  *
 | |
|  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 | |
|  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 | |
|  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 | |
|  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 | |
|  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 | |
|  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 | |
|  * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 | |
|  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 | |
|  * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 | |
|  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 | |
|  *
 | |
|  **************************************************************************************************/
 | |
| #pragma once
 | |
| 
 | |
| #include "cutlass/numeric_types.h"
 | |
| 
 | |
| #if !defined(__CUDACC_RTC__)
 | |
| #include <cuda.h>
 | |
| #include <cinttypes>
 | |
| #endif
 | |
| 
 | |
| #include <cute/config.hpp>
 | |
| 
 | |
| #include <cute/arch/util.hpp>   // cute::cast_smem_ptr_to_uint
 | |
| #include <cute/arch/config.hpp> // CUTE_ARCH_TMA_SMxx_ENABLED
 | |
| #include <cute/arch/copy.hpp>
 | |
| #include <cute/arch/copy_sm90.hpp>
 | |
| 
 | |
| #include <cute/container/alignment.hpp>
 | |
| #include <cute/container/bit_field.hpp>
 | |
| #include <cute/container/array.hpp>
 | |
| #include <cute/numeric/numeric_types.hpp>
 | |
| 
 | |
| namespace cute
 | |
| {
 | |
| 
 | |
| //////////////////////////////////////////////////////////////////////////////////////////////////////
 | |
| /// Barriers are 64-bit, user-managed objects used in two broad types of synchronization patterns
 | |
| /// 1) arrive/wait on threads (usage: cp.async and warp-specialized kernels)
 | |
| /// 2) transaction-based (usage: TMA transaction where a CTA issues one transaction)
 | |
| //////////////////////////////////////////////////////////////////////////////////////////////////////
 | |
| 
 | |
| // Initialize barrier present in shared memory
 | |
| CUTE_HOST_DEVICE
 | |
| void
 | |
| initialize_barrier(uint64_t& smem_barrier,                 // 64 bits user-manged barrier in smem
 | |
|                    int thread_count = 1)                   // Thread count expected to arrive/wait on this barrier
 | |
| {
 | |
| #if defined(CUTE_ARCH_TMA_SM90_ENABLED)
 | |
|   uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_barrier);
 | |
|   asm volatile ("mbarrier.init.shared::cta.b64 [%0], %1;\n"
 | |
|     :: "r"(smem_int_ptr),
 | |
|        "r"(thread_count));
 | |
| #endif
 | |
| }
 | |
| 
 | |
| // Set the number of bytes transfered per transaction and perform an arrive operation as well
 | |
| CUTE_HOST_DEVICE
 | |
| void
 | |
| set_barrier_transaction_bytes(uint64_t& smem_barrier,      // 64 bits user-manged barrier in smem
 | |
|                               uint32_t bytes)              // Number of bytes transfered by per TMA transaction
 | |
| {
 | |
| #if defined(CUTE_ARCH_TMA_SM90_ENABLED)
 | |
|   uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_barrier);
 | |
|   asm volatile ("mbarrier.arrive.expect_tx.shared::cta.b64 _, [%0], %1;\n"
 | |
|     :: "r"(smem_int_ptr),
 | |
|        "r"(bytes));
 | |
| #endif
 | |
| }
 | |
| 
 | |
| // Barrier wait
 | |
| CUTE_HOST_DEVICE
 | |
| void
 | |
| wait_barrier(uint64_t& smem_barrier,                       // 64 bits user-manged barrier in smem
 | |
|              int phase_bit)                                // Current phase bit the barrier waiting to flip
 | |
| {
 | |
| #if defined(CUTE_ARCH_TMA_SM90_ENABLED)
 | |
|   uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_barrier);
 | |
|   asm volatile(
 | |
|     "{\n"
 | |
|     ".reg .pred                P1;\n"
 | |
|     "LAB_WAIT:\n"
 | |
|     "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1;\n"
 | |
|     "@P1                       bra DONE;\n"
 | |
|     "bra                   LAB_WAIT;\n"
 | |
|     "DONE:\n"
 | |
|     "}\n"
 | |
|     :: "r"(smem_int_ptr),
 | |
|        "r"(phase_bit));
 | |
| 
 | |
| #endif
 | |
| }
 | |
| 
 | |
| // Barrier arrive
 | |
| CUTE_HOST_DEVICE
 | |
| void
 | |
| arrive_barrier(uint64_t& smem_barrier)                      // 64 bits user-manged barrier in smem
 | |
| {
 | |
| #if defined(CUTE_ARCH_TMA_SM90_ENABLED)
 | |
|   uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_barrier);
 | |
|   asm volatile(
 | |
|     "{\n"
 | |
|     ".reg .b64 state; \n"
 | |
|     "mbarrier.arrive.shared::cta.b64   state, [%0];\n"
 | |
|     "}\n"
 | |
|     :: "r"(smem_int_ptr));
 | |
| #endif
 | |
| }
 | |
| 
 | |
| ////////////////////////////////////////////////////////////////////////////////////////////////////
 | |
| // TMA Descriptor and utilities
 | |
| ////////////////////////////////////////////////////////////////////////////////////////////////////
 | |
| 
 | |
| namespace TMA {
 | |
| 
 | |
/// Shared-memory swizzle width used by a TMA descriptor.
enum class SmemSwizzleBits : uint8_t {
  DISABLE = 0x0,  // no swizzle
  B32     = 0x1,  // 32-byte swizzle
  B64     = 0x2,  // 64-byte swizzle
  B128    = 0x3   // 128-byte swizzle
};
 | |
| 
 | |
/// Base granularity of the swizzle pattern; only a 16-byte base is defined.
enum class SmemSwizzleBase : uint8_t {
  SWIZZLE_BASE_16B = 0x0
};
 | |
| 
 | |
/// Fill mode for elements outside the tensor bounds.
enum class OOBFill : uint8_t {
  ZERO     = 0x0,
  CONSTANT = 0x1
};
 | |
| 
 | |
| CUTE_HOST_DEVICE char const* to_string(OOBFill const& t) {
 | |
|   switch (t) {
 | |
|     case OOBFill::ZERO:     return "ZERO";
 | |
|     case OOBFill::CONSTANT: return "CONSTANT";
 | |
|   }
 | |
|   return nullptr;
 | |
| }
 | |
| 
 | |
/// L2 promotion size applied to TMA loads.
enum class L2Promotion : uint8_t {
  DISABLE = 0x0,  // no promotion
  B64     = 0x1,  // 64-byte promotion
  B128    = 0x2,  // 128-byte promotion
  B256    = 0x3   // 256-byte promotion
};
 | |
| 
 | |
| CUTE_HOST_DEVICE char const* to_string(L2Promotion const& t) {
 | |
|   switch (t) {
 | |
|     case L2Promotion::DISABLE: return "DISABLE";
 | |
|     case L2Promotion::B64:     return "B64";
 | |
|     case L2Promotion::B128:    return "B128";
 | |
|     case L2Promotion::B256:    return "B256";
 | |
|   }
 | |
|   return nullptr;
 | |
| }
 | |
| 
 | |
| // Aux parameters which are independent with the problem size
 | |
| struct DescriptorAuxParams {
 | |
|   OOBFill     oobfill_     = OOBFill::ZERO;
 | |
|   L2Promotion l2promo_     = L2Promotion::DISABLE;
 | |
| };
 | |
| 
 | |
/// 64-bit cache-policy hint values (SM90): normal, evict-first and evict-last.
enum class CacheHintSm90 : uint64_t {
  EVICT_NORMAL = 0x1000'0000'0000'0000,
  EVICT_FIRST  = 0x12F0'0000'0000'0000,
  EVICT_LAST   = 0x14F0'0000'0000'0000
};
 | |
| 
 | |
| #if (__CUDACC_VER_MAJOR__ >= 12)
 | |
| 
 | |
| #if !defined(__CUDACC_RTC__)
 | |
| /// @return The TMA descriptor datatype enum corresponding to T.
 | |
| template <class T>
 | |
| inline CUtensorMapDataType
 | |
| to_CUtensorMapDataType() {
 | |
|   if constexpr (is_same_v<T,       int8_t>) { return CU_TENSOR_MAP_DATA_TYPE_UINT8;    } else
 | |
|   if constexpr (is_same_v<T,      uint8_t>) { return CU_TENSOR_MAP_DATA_TYPE_UINT8;    } else
 | |
|   if constexpr (is_same_v<T, float_e4m3_t>) { return CU_TENSOR_MAP_DATA_TYPE_UINT8;    } else
 | |
|   if constexpr (is_same_v<T, float_e5m2_t>) { return CU_TENSOR_MAP_DATA_TYPE_UINT8;    } else
 | |
|   if constexpr (is_same_v<T,     uint16_t>) { return CU_TENSOR_MAP_DATA_TYPE_UINT16;   } else
 | |
|   if constexpr (is_same_v<T,     uint32_t>) { return CU_TENSOR_MAP_DATA_TYPE_UINT32;   } else
 | |
|   if constexpr (is_same_v<T,     uint64_t>) { return CU_TENSOR_MAP_DATA_TYPE_UINT64;   } else
 | |
|   if constexpr (is_same_v<T,      int32_t>) { return CU_TENSOR_MAP_DATA_TYPE_INT32;    } else
 | |
|   if constexpr (is_same_v<T,      int64_t>) { return CU_TENSOR_MAP_DATA_TYPE_INT64;    } else
 | |
|   if constexpr (is_same_v<T,       half_t>) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT16;  } else
 | |
|   if constexpr (is_same_v<T,        float>) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT32;  } else
 | |
|   if constexpr (is_same_v<T,       double>) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT64;  } else
 | |
|   if constexpr (is_same_v<T,   bfloat16_t>) { return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; } else
 | |
|   if constexpr (is_same_v<T,   tfloat32_t>) { return CU_TENSOR_MAP_DATA_TYPE_TFLOAT32; } else
 | |
|   { static_assert(sizeof(T) < 0, "Unknown TMA Format!"); }
 | |
| }
 | |
| 
 | |
| inline CUtensorMapSwizzle
 | |
| to_CUtensorMapSwizzle(SmemSwizzleBits const& t, SmemSwizzleBase const& b) {
 | |
|   switch (t) {
 | |
|     default: assert(false && "Unsupported pair of SmemSwizzleBits and SmemSwizzleBase!");
 | |
|     case SmemSwizzleBits::DISABLE: 
 | |
|       assert((b == SmemSwizzleBase::SWIZZLE_BASE_16B) && "Expected 16B swizzle base for 0B swizzle bits.");
 | |
|       return CU_TENSOR_MAP_SWIZZLE_NONE;
 | |
|     case SmemSwizzleBits::B32:
 | |
|       assert((b == SmemSwizzleBase::SWIZZLE_BASE_16B) && "Expected 16B swizzle base for 32B swizzle bits.");
 | |
|       return CU_TENSOR_MAP_SWIZZLE_32B;
 | |
|     case SmemSwizzleBits::B64:
 | |
|       assert((b == SmemSwizzleBase::SWIZZLE_BASE_16B) && "Expected 16B swizzle base for 64B swizzle bits.");
 | |
|       return CU_TENSOR_MAP_SWIZZLE_64B;
 | |
|     case SmemSwizzleBits::B128:
 | |
|       assert((b == SmemSwizzleBase::SWIZZLE_BASE_16B) && "Expected 16B swizzle base for 128B swizzle bits.");
 | |
|       return CU_TENSOR_MAP_SWIZZLE_128B;
 | |
|   }
 | |
| }
 | |
| 
 | |
| inline CUtensorMapFloatOOBfill
 | |
| to_CUtensorMapFloatOOBfill(OOBFill const& t) {
 | |
|   switch(t) {
 | |
|     default:                assert(false && "Unknown OOBFill!");
 | |
|     case OOBFill::ZERO:     return CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE;
 | |
|     case OOBFill::CONSTANT: return CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA;
 | |
|   }
 | |
| }
 | |
| 
 | |
| inline CUtensorMapL2promotion
 | |
| to_CUtensorMapL2promotion(L2Promotion const& t) {
 | |
|   switch(t) {
 | |
|     default: assert(false && "Unknown L2Promotion!");
 | |
|     case L2Promotion::DISABLE: return CU_TENSOR_MAP_L2_PROMOTION_NONE;
 | |
|     case L2Promotion::B64:     return CU_TENSOR_MAP_L2_PROMOTION_L2_64B;
 | |
|     case L2Promotion::B128:    return CU_TENSOR_MAP_L2_PROMOTION_L2_128B;
 | |
|     case L2Promotion::B256:    return CU_TENSOR_MAP_L2_PROMOTION_L2_256B;
 | |
|   }
 | |
| }
 | |
| 
 | |
| #endif // !defined(__CUDACC_RTC__)
 | |
| 
 | |
| #endif // (__CUDACC_VER_MAJOR__ >= 12)
 | |
| 
 | |
| } // end namespace TMA
 | |
| 
 | |
// TMA descriptor types.
// - CUDA 12+ outside NVRTC: the driver's CUtensorMap is used directly.
// - Otherwise: a 128-byte, 64-byte-aligned opaque stand-in is declared, so code can still
//   name the type. It is intended to mirror CUtensorMap's size and alignment; verify
//   against cuda.h.
#if (__CUDACC_VER_MAJOR__ < 12) || defined(__CUDACC_RTC__)
  using TmaDescriptor = struct alignas(64) { char bytes[128]; };
  using Im2ColTmaDescriptor = struct alignas(64) { char bytes[128]; };
#else
  using TmaDescriptor = CUtensorMap;
  using Im2ColTmaDescriptor = CUtensorMap;
#endif
 | |
| ////////////////////////////////////////////////////////////////////////////////////////////////////
 | |
| /// Initiates a TensorMap Prefetch
 | |
| ////////////////////////////////////////////////////////////////////////////////////////////////////
 | |
| 
 | |
| CUTE_HOST_DEVICE
 | |
| void
 | |
| prefetch_tma_descriptor(TmaDescriptor const* desc_ptr)
 | |
| {
 | |
| #if defined(CUTE_ARCH_TMA_SM90_ENABLED)
 | |
|   uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
 | |
|   // Prefetch TMA Descriptor using generic addressing (i.e. no specific state space: const or param)
 | |
|   asm volatile (
 | |
|     "prefetch.tensormap [%0];"
 | |
|     :
 | |
|     : "l"(gmem_int_desc)
 | |
|     : "memory");
 | |
| #else
 | |
|   CUTE_INVALID_CONTROL_PATH("Trying to use TMA Descriptor Prefetch without CUTE_ARCH_TMA_SM90_ENABLED.");
 | |
| #endif
 | |
| }
 | |
| 
 | |
| ////////////////////////////////////////////////////////////////////////////////////////////////////
 | |
| /// Perform a TensorMap modification (by each field)
 | |
| ////////////////////////////////////////////////////////////////////////////////////////////////////
 | |
| 
 | |
| // Replace tensor pointer directly in GMEM
 | |
| CUTE_HOST_DEVICE
 | |
| void
 | |
| tma_descriptor_replace_addr_in_global_mem(TmaDescriptor const* desc_ptr,
 | |
|                                           void const* const new_tensor_ptr)
 | |
| {
 | |
| #if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED)
 | |
|   uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
 | |
|   uint64_t const new_desc_addr = reinterpret_cast<uint64_t>(new_tensor_ptr);
 | |
|   asm volatile (
 | |
|     "tensormap.replace.tile.global_address.global.b1024.b64 [%0], %1;"
 | |
|     :: "l"(gmem_int_desc), "l"(new_desc_addr));
 | |
| #else
 | |
|   CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED and CUDA 12.3");
 | |
| #endif
 | |
| }
 | |
| 
 | |
| // Replace tensor pointer by bringing the tensormap from GMEM into the shared memory
 | |
| CUTE_HOST_DEVICE
 | |
| void
 | |
| tma_descriptor_replace_addr_in_shared_mem(TmaDescriptor& smem_desc,
 | |
|                                           void const* const new_tensor_ptr)
 | |
| {
 | |
| #if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED)
 | |
|   uint32_t smem_int_desc = cast_smem_ptr_to_uint(&smem_desc);
 | |
|   uint64_t const new_desc_addr = reinterpret_cast<uint64_t>(new_tensor_ptr);
 | |
|   asm volatile (
 | |
|     "tensormap.replace.tile.global_address.shared::cta.b1024.b64 [%0], %1;"
 | |
|     :: "r"(smem_int_desc), "l"(new_desc_addr));
 | |
| #else
 | |
|   CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED and CUDA 12.3");
 | |
| #endif
 | |
| }
 | |
| 
 | |
| // Replace tensor dims and strides for GEMMs by bringing the tensormap from GMEM into the shared memory
 | |
| CUTE_HOST_DEVICE
 | |
| void
 | |
| tma_descriptor_replace_dims_strides_in_shared_mem(TmaDescriptor                 & smem_desc,
 | |
|                                                   cute::array<uint32_t, 3> const& prob_shape,
 | |
|                                                   cute::array<uint64_t, 3> const& prob_stride)
 | |
| {
 | |
| #if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED)
 | |
|   uint32_t smem_int_desc = cast_smem_ptr_to_uint(&smem_desc);
 | |
|   uint64_t const smem_int64_desc = 0;
 | |
|   asm volatile (
 | |
|     "cvt.u64.u32 %0, %1;"
 | |
|     :: "l"(smem_int64_desc), "r"(smem_int_desc));
 | |
|   asm volatile (
 | |
|     "tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 0, %1;"
 | |
|     :: "l"(smem_int64_desc), "r"(prob_shape[0]));
 | |
|   asm volatile (
 | |
|     "tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 1, %1;"
 | |
|     :: "l"(smem_int64_desc), "r"(prob_shape[1]));
 | |
|   asm volatile (
 | |
|     "tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 2, %1;"
 | |
|     :: "l"(smem_int64_desc), "r"(prob_shape[2]));
 | |
|   // Strides must be a multiple of 16. Also, stride for the intermost dimension is implicitly 1
 | |
|   #if ((__CUDACC_VER_MAJOR__ > 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ >= 5)))
 | |
|   asm volatile (
 | |
|     "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 0, %1;"
 | |
|     :: "l"(smem_int64_desc), "l"(prob_stride[1]));
 | |
|   asm volatile (
 | |
|     "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 1, %1;"
 | |
|     :: "l"(smem_int64_desc), "l"(prob_stride[2]));
 | |
|   #else
 | |
|   // 4 LSBs are not included
 | |
|   asm volatile (
 | |
|     "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 0, %1;"
 | |
|     :: "l"(smem_int64_desc), "l"(prob_stride[1] >> 4));
 | |
|   asm volatile (
 | |
|     "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 1, %1;"
 | |
|     :: "l"(smem_int64_desc), "l"(prob_stride[2] >> 4));
 | |
|   #endif
 | |
| #else
 | |
|   CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED and CUDA 12.3");
 | |
| #endif
 | |
| }
 | |
| 
 | |
| ////////////////////////////////////////////////////////////////////////////////////////////////////
 | |
| /// Perform a fused copy and fence operation (needed when modifying tensormap in shared memory)
 | |
| ////////////////////////////////////////////////////////////////////////////////////////////////////
 | |
| 
 | |
| CUTE_HOST_DEVICE
 | |
| void
 | |
| tma_descriptor_cp_fence_release(TmaDescriptor const* gmem_desc_ptr, TmaDescriptor& smem_desc)
 | |
| {
 | |
| #if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED)
 | |
|   uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(gmem_desc_ptr);
 | |
|   uint32_t smem_int_desc = cast_smem_ptr_to_uint(&smem_desc);
 | |
|   asm volatile (
 | |
|     "tensormap.cp_fenceproxy.global.shared::cta.tensormap::generic.release.gpu.sync.aligned [%0], [%1], 128;"
 | |
|     :: "l"(gmem_int_desc), "r"(smem_int_desc));
 | |
| #else
 | |
|   CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED and CUDA 12.3");
 | |
| #endif
 | |
| }
 | |
| 
 | |
| ////////////////////////////////////////////////////////////////////////////////////////////////////
 | |
| /// Perform a release fence operation (needed when modifying tensormap directly in GMEM)
 | |
| ////////////////////////////////////////////////////////////////////////////////////////////////////
 | |
| 
 | |
| CUTE_HOST_DEVICE
 | |
| void
 | |
| tma_descriptor_fence_release()
 | |
| {
 | |
| #if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED)
 | |
|   asm volatile ("fence.proxy.tensormap::generic.release.gpu;");
 | |
| #else
 | |
|   CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED and CUDA 12.3");
 | |
| #endif
 | |
| }
 | |
| 
 | |
| ////////////////////////////////////////////////////////////////////////////////////////////////////
 | |
| /// Perform an acquire fence operation
 | |
| ////////////////////////////////////////////////////////////////////////////////////////////////////
 | |
| 
 | |
| CUTE_HOST_DEVICE
 | |
| void
 | |
| tma_descriptor_fence_acquire(TmaDescriptor const* desc_ptr)
 | |
| {
 | |
| #if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED)
 | |
|   uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
 | |
|   asm volatile (
 | |
|     "fence.proxy.tensormap::generic.acquire.gpu [%0], 128;"
 | |
|     :
 | |
|     : "l"(gmem_int_desc)
 | |
|     : "memory");
 | |
| #else
 | |
|   CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED and CUDA 12.3");
 | |
| #endif
 | |
| }
 | |
| 
 | |
| ///////////////////////////////////////////////////////////////////////////////
 | |
| 
 | |
| } // end namespace cute
 |