minor changes (#730)
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
commit 3f2bb17722 (parent 38193d76e3)
@@ -660,10 +660,10 @@ private:
       LayoutO layout_O(ldo_host.at(i));
 
       MatrixCoord extent_Q{problem0.m(), problem0.k()};
-      MatrixCoord extent_K{problem0.n(), problem0.k()};
+      MatrixCoord extent_K{problem0.k(), problem0.n()};
       MatrixCoord extent_P{problem0.m(), problem0.n()};
       MatrixCoord extent_V{problem1.k(), problem1.n()};
-      MatrixCoord extent_O{problem1.m(), problem1.k()};
+      MatrixCoord extent_O{problem1.m(), problem1.n()};
 
       cutlass::TensorView<ElementQ, LayoutQ> view_Q(block_Q.get() + offset_Q.at(i), layout_Q, extent_Q);
       cutlass::TensorView<ElementK, LayoutK> view_K(block_K.get() + offset_K.at(i), layout_K, extent_K);
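These two extent fixes are the substance of the commit. In the first GEMM, P = Q * K^T, the K operand is consumed as a k-by-n matrix, so its view extent must be {problem0.k(), problem0.n()}; and the output of the second GEMM, O = P * V, has the m-by-n shape of problem1, not m-by-k. A minimal sketch of that shape bookkeeping, with simplified names (ProblemShape and check_attention_extents are illustrative, not part of the example):

    #include <cassert>

    // Illustrative shape check mirroring the fixed extents above.
    // GEMM0: P(m0 x n0) = Q(m0 x k0) * K^T, with K viewed as (k0 x n0).
    // GEMM1: O(m1 x n1) = P(m1 x k1) * V(k1 x n1), where m1 == m0 and k1 == n0.
    struct ProblemShape { int m, n, k; };

    void check_attention_extents(ProblemShape p0, ProblemShape p1) {
      assert(p1.m == p0.m);  // rows of P flow through to O unchanged
      assert(p1.k == p0.n);  // GEMM1 contracts over the key/sequence dimension
      // Hence extent_K is {p0.k, p0.n} (k x n) and extent_O is {p1.m, p1.n}.
    }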
@@ -707,7 +707,6 @@ private:
       int n_dim = options.use_mask ? options.problem_sizes0_real.at(i).n() : problem0.n();
 
       // Compute softmax for referece matrix
-      // Assumed a row-major storage
       for (int m = 0; m < problem0.m(); m++) {
         int n_dim_row = n_dim;
         if (options.causal) {
@@ -737,7 +736,6 @@ private:
         for (int n = n_dim_row; n < n_dim; ++n) {
           view_Ref_host.ref().at({m, n}) = ElementP(0);
         }
-
       }
 
       // when not using mask, problem_real and problem share the same sizes
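Around these hunks the testbed computes a CPU softmax reference: when options.causal is set, row m may attend only to columns n <= m, so n_dim_row is clamped and the remaining entries of the row are zeroed, as in the loop above. A rough standalone sketch of that reference, assuming a row-major m-by-n float buffer (names simplified, not the example's actual code):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // Sketch of a row-major causal softmax reference over an m x n matrix.
    void softmax_reference(std::vector<float>& P, int m_dim, int n_dim, bool causal) {
      for (int m = 0; m < m_dim; ++m) {
        int n_valid = causal ? std::min(m + 1, n_dim) : n_dim;
        float* row = P.data() + static_cast<size_t>(m) * n_dim;
        // Subtract the row maximum for numerical stability.
        float mx = *std::max_element(row, row + n_valid);
        float sum = 0.f;
        for (int n = 0; n < n_valid; ++n) {
          row[n] = std::exp(row[n] - mx);
          sum += row[n];
        }
        for (int n = 0; n < n_valid; ++n) row[n] /= sum;
        // Masked-out positions contribute nothing downstream.
        for (int n = n_valid; n < n_dim; ++n) row[n] = 0.f;
      }
    }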
@@ -798,7 +796,6 @@ private:
 
         return passed;
       }
-
     }
 
     return passed;
@@ -808,7 +805,7 @@ public:
 
 
   /// Executes a CUTLASS Attention kernel and measures runtime.
-  Result profile_grouped() {
+  Result profile() {
 
     Result result;
     result.passed = false;
@@ -886,7 +883,7 @@ public:
     }
 
     //
-    // Warm-up run of the grouped GEMM object
+    // Warm-up run
     //
 
     kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes>>>(p);
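The launch shown here is only the warm-up; profile() then times repeated launches. A generic sketch of that pattern with CUDA events, assuming the usual CUTLASS-example timing style (time_kernel and its parameters are illustrative, not the testbed's code):

    #include <cuda_runtime.h>

    // Hypothetical timing harness: one warm-up launch, then an averaged
    // measurement over several iterations using CUDA events.
    template <typename Params>
    float time_kernel(void (*kernel)(Params), dim3 grid, dim3 block,
                      int smem_bytes, Params p, int iterations) {
      kernel<<<grid, block, smem_bytes>>>(p);   // warm-up run

      cudaEvent_t start, stop;
      cudaEventCreate(&start);
      cudaEventCreate(&stop);
      cudaEventRecord(start);
      for (int i = 0; i < iterations; ++i) {
        kernel<<<grid, block, smem_bytes>>>(p);
      }
      cudaEventRecord(stop);
      cudaEventSynchronize(stop);

      float ms = 0.f;
      cudaEventElapsedTime(&ms, start, stop);
      cudaEventDestroy(start);
      cudaEventDestroy(stop);
      return ms / float(iterations);            // average ms per launch
    }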
@@ -975,8 +972,6 @@ public:
 
     return result;
   }
-
-
 };
 
 ///////////////////////////////////////////////////////////////////////////////////////////////////
@@ -1002,7 +997,7 @@ int run_attention(Options& options) {
 
   TestbedAttention<Attention> testbed(options);
 
-  Result result = testbed.profile_grouped();
+  Result result = testbed.profile();
   if (!result.passed) {
     std::cout << "Profiling CUTLASS attention has failed.\n";
     std::cout << "\nFailed\n";
@@ -741,7 +741,7 @@ private:
       LayoutO layout_O(ldo_host.at(i));
 
       MatrixCoord extent_Q{problem0.m(), problem0.k()};
-      MatrixCoord extent_K{problem0.n(), problem0.k()};
+      MatrixCoord extent_K{problem0.k(), problem0.n()};
       MatrixCoord extent_P{problem0.m(), problem0.n()};
       MatrixCoord extent_V{problem1.k(), problem1.n()};
       MatrixCoord extent_O{problem1.m(), problem1.n()};
@@ -789,7 +789,6 @@ private:
       int n_dim = options.use_mask ? options.problem_sizes0_real.at(i).n() : problem0.n();
 
       // Compute softmax for reference matrix
-      // Assumed a row-major storage
       for (int m = 0; m < problem0.m(); m++) {
        int n_dim_row = n_dim;
        if (options.causal) {
@@ -126,7 +126,7 @@ struct AttentionKernel {
   struct Params {
     // Input tensors
     scalar_t* query_ptr;  // [num_queries, num_heads, head_dim]
     scalar_t* key_ptr;    // [num_keys, num_heads, head_dim]
     scalar_t* value_ptr;  // [num_keys, num_heads, head_dim_value]
     int32_t* cu_seqlens_q_ptr = nullptr;
     int32_t* cu_seqlens_k_ptr = nullptr;
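The bracketed comments describe head-interleaved storage: for a [rows, num_heads, head_dim] tensor, consecutive rows of the same head are num_heads * head_dim elements apart. A hedged sketch of the resulting addressing (this helper is illustrative, not part of the kernel's API):

    #include <cstdint>

    // Illustrative only: index into a tensor stored as
    // [num_rows, num_heads, head_dim], the layout the Params comments
    // above describe for query/key/value.
    template <typename scalar_t>
    scalar_t& at(scalar_t* ptr, int64_t row, int64_t head, int64_t d,
                 int64_t num_heads, int64_t head_dim) {
      // Stride between consecutive rows of one head is num_heads * head_dim.
      return ptr[(row * num_heads + head) * head_dim + d];
    }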
@@ -165,6 +165,7 @@ struct AttentionKernel {
     CUTLASS_HOST_DEVICE int32_t o_strideM() const {
       return head_dim_value;
     }
+
     // Moves pointers to what we should process
     // Returns "false" if there is no work to do
     CUTLASS_DEVICE bool advance_to_block() {
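advance_to_block() encapsulates the kernel's work assignment: each thread block claims one slice of the problem and reports whether it has anything to process. A speculative sketch of that contract, with a simplified (batch, head, query-tile) mapping that is an assumption, not the kernel's actual logic:

    // Hypothetical sketch of the advance_to_block() contract: map this block
    // to a query tile and report whether any rows remain to process.
    __device__ bool advance_to_block_sketch(int& query_start,
                                            int num_queries,
                                            int queries_per_block) {
      // blockIdx.x indexes query tiles; blockIdx.y/z would index head/batch.
      query_start = blockIdx.x * queries_per_block;
      return query_start < num_queries;  // "false" means no work to do
    }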