Fix typos of comments about shape. (#837)
This commit is contained in:
parent
0d810cfb73
commit
9486635c92
@ -31,7 +31,7 @@ struct Alibi {
|
||||
const int col_idx_offset_,
|
||||
const int row_idx_offset,
|
||||
const int warp_row_stride) {
|
||||
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
|
||||
// tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
|
||||
static_assert(Layout::rank == 2, "Only support 2D Tensor");
|
||||
const int lane_id = threadIdx.x % 32;
|
||||
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
|
||||
|
||||
@ -471,7 +471,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in
|
||||
flash::gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma_sdp,
|
||||
smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV);
|
||||
|
||||
// Reshape acc_s from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N))
|
||||
// Reshape acc_s from (MMA=4, MMA_N, MMA_N) to (row=(2, MMA_N), col=(2, MMA_N))
|
||||
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
|
||||
// if (cute::thread(32, 0)) { print(scores); }
|
||||
|
||||
@ -565,7 +565,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in
|
||||
smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV
|
||||
);
|
||||
|
||||
// Reshape acc_dp from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N))
|
||||
// Reshape acc_dp from (MMA=4, MMA_N, MMA_N) to (row=(2, MMA_N), col=(2, MMA_N))
|
||||
Tensor dS = make_tensor(acc_dp.data(), scores.layout());
|
||||
auto pointwise_mult = [](float p, float dp, float d) {
|
||||
return p * (!Is_dropout || p >= 0 ? dp - d : d);
|
||||
|
||||
@ -13,7 +13,7 @@ using namespace cute;
|
||||
template <typename Engine, typename Layout>
|
||||
__forceinline__ __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_seqlen_k,
|
||||
const int col_idx_offset_ = 0) {
|
||||
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
|
||||
// tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
|
||||
static_assert(Layout::rank == 2, "Only support 2D Tensor");
|
||||
const int lane_id = threadIdx.x % 32;
|
||||
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
|
||||
@ -39,7 +39,7 @@ __forceinline__ __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor,
|
||||
const int max_seqlen_k, const int row_idx_offset,
|
||||
const int max_seqlen_q, const int warp_row_stride,
|
||||
const int window_size_left, const int window_size_right) {
|
||||
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
|
||||
// tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
|
||||
static_assert(Layout::rank == 2, "Only support 2D Tensor");
|
||||
const int lane_id = threadIdx.x % 32;
|
||||
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
|
||||
@ -85,7 +85,7 @@ __forceinline__ __device__ void apply_mask_causal_w_idx(
|
||||
Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,
|
||||
const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset)
|
||||
{
|
||||
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
|
||||
// tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||
static_assert(Layout1::rank == 2, "Only support 2D Tensor");
|
||||
CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol));
|
||||
|
||||
Loading…
Reference in New Issue
Block a user