Move masking to a separate file (mask.h)
parent 9448264ddd
commit 6777336a1c
@@ -1,3 +1,7 @@
+/******************************************************************************
+ * Copyright (c) 2024, Tri Dao.
+ ******************************************************************************/
+
 #pragma once
 
 #include "philox.cuh"
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2023, Tri Dao.
+ * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/
 
 #pragma once
@@ -14,6 +14,7 @@
 #include "kernel_traits.h"
 #include "utils.h"
 #include "softmax.h"
+#include "mask.h"
 #include "dropout.h"
 
 #include "alibi.h"
@@ -1,4 +1,6 @@
-// Copyright (c) 2023, Tri Dao.
+/******************************************************************************
+ * Copyright (c) 2024, Tri Dao.
+ ******************************************************************************/
 
 #pragma once
 
@@ -1,5 +1,5 @@
 /******************************************************************************
- * Copyright (c) 2023, Tri Dao.
+ * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/
 
 #pragma once
@@ -14,6 +14,7 @@
 #include "kernel_traits.h"
 #include "utils.h"
 #include "softmax.h"
+#include "mask.h"
 #include "dropout.h"
 
 #include "alibi.h"
@@ -1,5 +1,5 @@
 /******************************************************************************
- * Copyright (c) 2023, Tri Dao.
+ * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/
 
 #pragma once
csrc/flash_attn/src/mask.h (new file, 110 lines)
@@ -0,0 +1,110 @@
/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#pragma once

#include <cute/tensor.hpp>

namespace flash {

using namespace cute;

template <typename Engine, typename Layout>
inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_seqlen_k,
                                  const int col_idx_offset_ = 0) {
    // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
    static_assert(Layout::rank == 2, "Only support 2D Tensor");
    const int lane_id = threadIdx.x % 32;
    const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
    #pragma unroll
    for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
        const int col_idx_base = col_idx_offset + nj * 8;
        #pragma unroll
        for (int j = 0; j < size<1, 0>(tensor); ++j) {
            const int col_idx = col_idx_base + j;
            if (col_idx >= max_seqlen_k) {
                // Without the "make_coord" we get wrong results
                #pragma unroll
                for (int mi = 0; mi < size<0>(tensor); ++mi) {
                    tensor(mi, make_coord(j, nj)) = -INFINITY;
                }
            }
        }
    }
}

template <bool HasWSLeft=true, typename Engine, typename Layout>
inline __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
                                        const int max_seqlen_k, const int row_idx_offset,
                                        const int max_seqlen_q, const int warp_row_stride,
                                        const int window_size_left, const int window_size_right) {
    // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
    static_assert(Layout::rank == 2, "Only support 2D Tensor");
    const int lane_id = threadIdx.x % 32;
    const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
    #pragma unroll
    for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
        const int row_idx_base = row_idx_offset + mi * warp_row_stride;
        #pragma unroll
        for (int i = 0; i < size<0, 0>(tensor); ++i) {
            const int row_idx = row_idx_base + i * 8;
            const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
            const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
            #pragma unroll
            for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
                const int col_idx_base = col_idx_offset + nj * 8;
                #pragma unroll
                for (int j = 0; j < size<1, 0>(tensor); ++j) {
                    const int col_idx = col_idx_base + j;
                    if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) {
                        tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
                    }
                }
            }
            // if (cute::thread0()) {
            //     printf("mi = %d, i = %d, row_idx = %d, max_seqlen_k = %d\n", mi, i, row_idx, max_seqlen_k);
            //     print(tensor(make_coord(i, mi), _));
            //     // print(tensor(_, j + nj * size<1, 0>(tensor)));
            // }
        }
    }
}

template <typename Engine, typename Layout>
inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
                                         const int max_seqlen_k, const int row_idx_offset,
                                         const int max_seqlen_q, const int warp_row_stride) {
    // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0
    apply_mask_local</*HasWSLeft=*/false>(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset,
                                          max_seqlen_q, warp_row_stride, -1, 0);
}

template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
inline __device__ void apply_mask_causal_w_idx(
    Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,
    const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset)
{
    // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
    static_assert(Layout1::rank == 2, "Only support 2D Tensor");
    CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol));
    CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol));
    #pragma unroll
    for (int mi = 0; mi < size<0>(tensor); ++mi) {
        const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset + get<0>(idx_rowcol(mi, 0)));
        #pragma unroll
        for (int ni = 0; ni < size<1, 1>(tensor); ++ni) {
            if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) {
                tensor(mi, ni) = -INFINITY;
            }
        }
        // if (cute::thread0()) {
        //     printf("ni = %d, j = %d, col_idx = %d, max_seqlen_k = %d\n", ni, j, col_idx, max_seqlen_k);
        //     print(tensor(_, make_coord(j, ni)));
        //     // print(tensor(_, j + ni * size<1, 0>(tensor)));
        // }
    }
}

}  // namespace flash
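As a reference for the window arithmetic in apply_mask_local: the following minimal host-side C++ sketch recomputes the same col_idx_limit_left / col_idx_limit_right bounds per query row. Treating window_size_left < 0 as "unlimited" here is an assumption of this sketch that mirrors what the HasWSLeft=false instantiation (used by apply_mask_causal) does in the kernel.

#include <algorithm>
#include <cstdio>

int main() {
    const int max_seqlen_q = 4, max_seqlen_k = 6;
    const int window_size_left = -1, window_size_right = 0;  // causal: no left limit, right window 0
    for (int row_idx = 0; row_idx < max_seqlen_q; ++row_idx) {
        // The max_seqlen_k - max_seqlen_q shift aligns the causal diagonal with the
        // bottom-right corner of the attention matrix, exactly as in apply_mask_local.
        const int shift = row_idx + max_seqlen_k - max_seqlen_q;
        const int limit_left = window_size_left < 0 ? 0 : std::max(0, shift - window_size_left);
        const int limit_right = std::min(max_seqlen_k, shift + 1 + window_size_right);
        printf("row %d keeps columns [%d, %d)\n", row_idx, limit_left, limit_right);
    }
    return 0;
}

With the values above this prints [0, 3), [0, 4), [0, 5), [0, 6) for rows 0 through 3: the usual causal pattern, shifted so that the last query row attends to the full key sequence.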
@@ -115,101 +115,4 @@ inline __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tens
     }
 }
 
-template <typename Engine, typename Layout>
-inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_seqlen_k,
-                                  const int col_idx_offset_ = 0) {
-    // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
-    static_assert(Layout::rank == 2, "Only support 2D Tensor");
-    const int lane_id = threadIdx.x % 32;
-    const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
-    #pragma unroll
-    for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
-        const int col_idx_base = col_idx_offset + nj * 8;
-        #pragma unroll
-        for (int j = 0; j < size<1, 0>(tensor); ++j) {
-            const int col_idx = col_idx_base + j;
-            if (col_idx >= max_seqlen_k) {
-                // Without the "make_coord" we get wrong results
-                #pragma unroll
-                for (int mi = 0; mi < size<0>(tensor); ++mi) {
-                    tensor(mi, make_coord(j, nj)) = -INFINITY;
-                }
-            }
-        }
-    }
-}
-
-template <bool HasWSLeft=true, typename Engine, typename Layout>
-inline __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
-                                        const int max_seqlen_k, const int row_idx_offset,
-                                        const int max_seqlen_q, const int warp_row_stride,
-                                        const int window_size_left, const int window_size_right) {
-    // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
-    static_assert(Layout::rank == 2, "Only support 2D Tensor");
-    const int lane_id = threadIdx.x % 32;
-    const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
-    #pragma unroll
-    for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
-        const int row_idx_base = row_idx_offset + mi * warp_row_stride;
-        #pragma unroll
-        for (int i = 0; i < size<0, 0>(tensor); ++i) {
-            const int row_idx = row_idx_base + i * 8;
-            const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
-            const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
-            #pragma unroll
-            for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
-                const int col_idx_base = col_idx_offset + nj * 8;
-                #pragma unroll
-                for (int j = 0; j < size<1, 0>(tensor); ++j) {
-                    const int col_idx = col_idx_base + j;
-                    if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) {
-                        tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
-                    }
-                }
-            }
-            // if (cute::thread0()) {
-            //     printf("mi = %d, i = %d, row_idx = %d, max_seqlen_k = %d\n", mi, i, row_idx, max_seqlen_k);
-            //     print(tensor(make_coord(i, mi), _));
-            //     // print(tensor(_, j + nj * size<1, 0>(tensor)));
-            // }
-        }
-    }
-}
-
-template <typename Engine, typename Layout>
-inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
-                                         const int max_seqlen_k, const int row_idx_offset,
-                                         const int max_seqlen_q, const int warp_row_stride) {
-    // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0
-    apply_mask_local</*HasWSLeft=*/false>(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset,
-                                          max_seqlen_q, warp_row_stride, -1, 0);
-}
-
-template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
-inline __device__ void apply_mask_causal_w_idx(
-    Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,
-    const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset)
-{
-    // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
-    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
-    static_assert(Layout1::rank == 2, "Only support 2D Tensor");
-    CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol));
-    CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol));
-    #pragma unroll
-    for (int mi = 0; mi < size<0>(tensor); ++mi) {
-        const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset + get<0>(idx_rowcol(mi, 0)));
-        #pragma unroll
-        for (int ni = 0; ni < size<1, 1>(tensor); ++ni) {
-            if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) {
-                tensor(mi, ni) = -INFINITY;
-            }
-        }
-        // if (cute::thread0()) {
-        //     printf("ni = %d, j = %d, col_idx = %d, max_seqlen_k = %d\n", ni, j, col_idx, max_seqlen_k);
-        //     print(tensor(_, make_coord(j, ni)));
-        //     // print(tensor(_, j + ni * size<1, 0>(tensor)));
-        // }
-    }
-}
-
 }  // namespace flash
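To make the index arithmetic in the relocated apply_mask / apply_mask_local easier to follow, here is a minimal host-side sketch of the lane-to-column ownership those functions assume: each group of 4 lanes covers an 8-wide column tile, with every lane owning 2 adjacent columns per tile. kNTiles below is a made-up stand-in for size<1, 1>(tensor), and col_idx_offset_ is taken as 0.

#include <cstdio>

int main() {
    const int kNTiles = 2;  // hypothetical stand-in for size<1, 1>(tensor)
    for (int lane_id = 0; lane_id < 4; ++lane_id) {  // the column pattern repeats every 4 lanes
        printf("lane %d masks columns:", lane_id);
        for (int nj = 0; nj < kNTiles; ++nj) {
            const int col_idx_base = (lane_id % 4) * 2 + nj * 8;  // same arithmetic as in mask.h
            for (int j = 0; j < 2; ++j) {  // 2 == size<1, 0>(tensor)
                printf(" %d", col_idx_base + j);
            }
        }
        printf("\n");
    }
    return 0;
}

With two tiles this prints 0 1 8 9 for lane 0, 2 3 10 11 for lane 1, 4 5 12 13 for lane 2, and 6 7 14 15 for lane 3; lanes 4 through 31 repeat the same column pattern on the other accumulator rows, since the code only uses lane_id % 4 for the column offset.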