diff --git a/examples/41_fused_multi_head_attention/fused_multihead_attention_fixed_seqlen.cu b/examples/41_fused_multi_head_attention/fused_multihead_attention_fixed_seqlen.cu
index 70767393..b649baec 100644
--- a/examples/41_fused_multi_head_attention/fused_multihead_attention_fixed_seqlen.cu
+++ b/examples/41_fused_multi_head_attention/fused_multihead_attention_fixed_seqlen.cu
@@ -660,10 +660,10 @@ private:
       LayoutO layout_O(ldo_host.at(i));
 
       MatrixCoord extent_Q{problem0.m(), problem0.k()};
-      MatrixCoord extent_K{problem0.n(), problem0.k()};
+      MatrixCoord extent_K{problem0.k(), problem0.n()};
       MatrixCoord extent_P{problem0.m(), problem0.n()};
       MatrixCoord extent_V{problem1.k(), problem1.n()};
-      MatrixCoord extent_O{problem1.m(), problem1.k()};
+      MatrixCoord extent_O{problem1.m(), problem1.n()};
 
       cutlass::TensorView view_Q(block_Q.get() + offset_Q.at(i), layout_Q, extent_Q);
       cutlass::TensorView view_K(block_K.get() + offset_K.at(i), layout_K, extent_K);
@@ -707,7 +707,6 @@ private:
       int n_dim = options.use_mask ? options.problem_sizes0_real.at(i).n() : problem0.n();
 
       // Compute softmax for reference matrix
-      // Assumed a row-major storage
       for (int m = 0; m < problem0.m(); m++) {
         int n_dim_row = n_dim;
         if (options.causal) {
@@ -737,7 +736,6 @@ private:
         for (int n = n_dim_row; n < n_dim; ++n) {
           view_Ref_host.ref().at({m, n}) = ElementP(0);
         }
-
       }
 
       // when not using mask, problem_real and problem share the same sizes
@@ -798,7 +796,6 @@ private:
 
         return passed;
       }
-
     }
 
     return passed;
@@ -808,7 +805,7 @@ public:
 
   /// Executes a CUTLASS Attention kernel and measures runtime.
-  Result profile_grouped() {
+  Result profile() {
 
     Result result;
     result.passed = false;
@@ -886,7 +883,7 @@ public:
     }
 
     //
-    // Warm-up run of the grouped GEMM object
+    // Warm-up run
     //
 
     kernel_fn<<>>(p);
@@ -975,8 +972,6 @@ public:
 
     return result;
   }
-
-
 };
 
 ///////////////////////////////////////////////////////////////////////////////////////////////////
@@ -1002,7 +997,7 @@ int run_attention(Options& options) {
   TestbedAttention testbed(options);
 
-  Result result = testbed.profile_grouped();
+  Result result = testbed.profile();
 
   if (!result.passed) {
     std::cout << "Profiling CUTLASS attention has failed.\n";
     std::cout << "\nFailed\n";
diff --git a/examples/41_fused_multi_head_attention/fused_multihead_attention_variable_seqlen.cu b/examples/41_fused_multi_head_attention/fused_multihead_attention_variable_seqlen.cu
index 0738277b..a1d40334 100644
--- a/examples/41_fused_multi_head_attention/fused_multihead_attention_variable_seqlen.cu
+++ b/examples/41_fused_multi_head_attention/fused_multihead_attention_variable_seqlen.cu
@@ -741,7 +741,7 @@ private:
       LayoutO layout_O(ldo_host.at(i));
 
       MatrixCoord extent_Q{problem0.m(), problem0.k()};
-      MatrixCoord extent_K{problem0.n(), problem0.k()};
+      MatrixCoord extent_K{problem0.k(), problem0.n()};
      MatrixCoord extent_P{problem0.m(), problem0.n()};
       MatrixCoord extent_V{problem1.k(), problem1.n()};
       MatrixCoord extent_O{problem1.m(), problem1.n()};
@@ -789,7 +789,6 @@ private:
       int n_dim = options.use_mask ? options.problem_sizes0_real.at(i).n() : problem0.n();
 
       // Compute softmax for reference matrix
-      // Assumed a row-major storage
       for (int m = 0; m < problem0.m(); m++) {
         int n_dim_row = n_dim;
         if (options.causal) {
diff --git a/examples/41_fused_multi_head_attention/kernel_forward.h b/examples/41_fused_multi_head_attention/kernel_forward.h
index e6880d31..3f85953b 100644
--- a/examples/41_fused_multi_head_attention/kernel_forward.h
+++ b/examples/41_fused_multi_head_attention/kernel_forward.h
@@ -126,7 +126,7 @@ struct AttentionKernel {
   struct Params {
     // Input tensors
     scalar_t* query_ptr; // [num_queries, num_heads, head_dim]
-    scalar_t* key_ptr;   // [num_keys, num_heads, head_dim]
+    scalar_t* key_ptr; // [num_keys, num_heads, head_dim]
     scalar_t* value_ptr; // [num_keys, num_heads, head_dim_value]
     int32_t* cu_seqlens_q_ptr = nullptr;
     int32_t* cu_seqlens_k_ptr = nullptr;
@@ -165,6 +165,7 @@ struct AttentionKernel {
     CUTLASS_HOST_DEVICE int32_t o_strideM() const {
       return head_dim_value;
     }
 
+    // Moves pointers to what we should process
     // Returns "false" if there is no work to do
     CUTLASS_DEVICE bool advance_to_block() {