small fixes

parent cdbbe844b1
commit 8cbc8a042f
@@ -103,10 +103,10 @@ def unpad_input(hidden_states, attention_mask, unused_mask=None):
         unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
     Return:
         hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
-        indices: (used_nnz), the indices of non-masked tokens from the flattened input sequence.
+        indices: (total_nnz), the indices of masked tokens from the flattened input sequence.
         cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
         max_seqlen_in_batch: int
-        seqused: (batch), optionally returns the number of tokens selected in attention_mask + unused_mask if unused_mask is not None.
+        seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.
     """
     all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
     seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
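Note on the API after this hunk: unpad_input now always returns five values, with seqused reported unconditionally rather than only when unused_mask is passed. A minimal usage sketch, assuming the function lives at flash_attn.bert_padding (shapes and mask values below are illustrative, not from the commit):

    import torch
    from flash_attn.bert_padding import unpad_input

    batch, seqlen, dim = 2, 4, 8
    hidden_states = torch.randn(batch, seqlen, dim)
    # 1 = valid token, 0 = padding
    attention_mask = torch.tensor([[1, 1, 0, 0],
                                   [1, 1, 1, 0]], dtype=torch.int32)
    # 1 = slot allocated but unused (e.g. pre-reserved cache space)
    unused_mask = torch.tensor([[0, 0, 1, 0],
                                [0, 0, 0, 1]], dtype=torch.int32)

    out, indices, cu_seqlens, max_seqlen, seqused = unpad_input(
        hidden_states, attention_mask, unused_mask
    )
    # out has total_nnz = 7 rows here: the tokens selected by attention_mask + unused_mask
    # cu_seqlens == tensor([0, 3, 7]); seqused holds the per-sequence counts documented above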
@@ -1,314 +0,0 @@
-from functools import partial
-import math
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-import time
-
-try:
-    import cudnn
-except ImportError:
-    cudnn = None
-
-
-from einops import rearrange, repeat
-
-# from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler
-from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler
-from flash_attn.flash_attn_interface import flash_attn_func
-from flash_attn_interface import flash_attn_func as flash_attn_func_v3, flash_attn_varlen_func as flash_attn_varlen_func_v3
-
-# Need to install triton nightly:
-# pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly
-
-try:
-    from triton_fused_attention import attention as triton_attention
-except ImportError:
-    triton_attention = None
-
-def flops(batch, nheads, seqlen_q, seqlen_k, headdim, causal=False, mode='fwd'):
-    assert mode in ["fwd", "bwd", "fwd_bwd"]
-    f = 4 * batch * seqlen_q * seqlen_k * nheads * headdim // (2 if causal else 1)
-    return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)
-
-
-def convert_to_cudnn_type(torch_type):
-    if torch_type == torch.float16:
-        return cudnn.data_type.HALF
-    elif torch_type == torch.bfloat16:
-        return cudnn.data_type.BFLOAT16
-    elif torch_type == torch.float32:
-        return cudnn.data_type.FLOAT
-    elif torch_type == torch.int32:
-        return cudnn.data_type.INT32
-    elif torch_type == torch.int64:
-        return cudnn.data_type.INT64
-    else:
-        raise ValueError("Unsupported tensor data type.")
-
-
-def cudnn_sdpa_setup(q, k, v, grad, o, stats, causal=False, varlen=False, seqlens=None):
-    b, nheads, seqlen_q, headdim = q.shape
-    _, nheads_kv, seqlen_k, _ = k.shape
-    assert v.shape == (b, nheads_kv, seqlen_k, headdim)
-    assert cudnn is not None, 'CUDNN is not available'
-    q_gpu, k_gpu, v_gpu = q, k, v
-    o_gpu, stats_gpu = o, stats
-    graph_forward = cudnn.pygraph(
-        io_data_type=convert_to_cudnn_type(q.dtype),
-        intermediate_data_type=cudnn.data_type.FLOAT,
-        compute_data_type=cudnn.data_type.FLOAT,
-    )
-    q_forward = graph_forward.tensor_like(q_gpu.detach())
-    k_forward = graph_forward.tensor_like(k_gpu.detach())
-    v_forward = graph_forward.tensor_like(v_gpu.detach())
-
-    seqlens_reshaped = seqlens if varlen else None
-    seq_len_q = graph_forward.tensor_like(seqlens_reshaped.detach()) if varlen else None
-    seq_len_kv = graph_forward.tensor_like(seqlens_reshaped.detach()) if varlen else None
-
-    o_forward, stats_forward = graph_forward.sdpa(
-        name="sdpa",
-        q=q_forward,
-        k=k_forward,
-        v=v_forward,
-        is_inference=False,
-        attn_scale=1.0 / math.sqrt(headdim),
-        use_causal_mask=causal,
-        use_padding_mask=varlen,
-        seq_len_q=seq_len_q,
-        seq_len_kv=seq_len_kv,
-    )
-
-    o_forward.set_output(True).set_dim(o_gpu.shape).set_stride(o_gpu.stride())
-    stats_forward.set_output(True).set_data_type(cudnn.data_type.FLOAT)
-
-    graph_forward.validate()
-    graph_forward.build_operation_graph()
-    graph_forward.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
-    graph_forward.check_support()
-    graph_forward.build_plans()
-
-    variant_pack_forward = {
-        q_forward: q_gpu,
-        k_forward: k_gpu,
-        v_forward: v_gpu,
-        o_forward: o_gpu,
-        stats_forward: stats_gpu,
-        seq_len_q: seqlens_reshaped,
-        seq_len_kv: seqlens_reshaped,
-    }
-
-    dQ_gpu = torch.empty_like(q_gpu)
-    dK_gpu = torch.empty_like(k_gpu)
-    dV_gpu = torch.empty_like(v_gpu)
-    dO_gpu = grad
-
-    graph_backward = cudnn.pygraph(
-        io_data_type=cudnn.data_type.HALF,
-        intermediate_data_type=cudnn.data_type.FLOAT,
-        compute_data_type=cudnn.data_type.FLOAT,
-    )
-
-    q_backward = graph_backward.tensor_like(q_gpu.detach())
-    k_backward = graph_backward.tensor_like(k_gpu.detach())
-    v_backward = graph_backward.tensor_like(v_gpu.detach())
-    o_backward = graph_backward.tensor_like(o_gpu.detach())
-    dO_backward = graph_backward.tensor_like(dO_gpu.detach())
-    stats_backward = graph_backward.tensor_like(stats_gpu.detach())
-    seq_len_q = graph_backward.tensor_like(seqlens_reshaped.detach()) if varlen else None
-    seq_len_kv = graph_backward.tensor_like(seqlens_reshaped.detach()) if varlen else None
-
-    dQ_backward, dK_backward, dV_backward = graph_backward.sdpa_backward(
-        name="sdpa_backward",
-        q=q_backward,
-        k=k_backward,
-        v=v_backward,
-        o=o_backward,
-        dO=dO_backward,
-        stats=stats_backward,
-        attn_scale=1.0 / math.sqrt(headdim),
-        use_causal_mask=causal,
-        use_padding_mask=varlen,
-        seq_len_q=seq_len_q,
-        seq_len_kv=seq_len_kv,
-    )
-
-    dQ_backward.set_output(True).set_dim(dQ_gpu.size()).set_stride(dQ_gpu.stride())
-    dK_backward.set_output(True).set_dim(dK_gpu.size()).set_stride(dK_gpu.stride())
-    dV_backward.set_output(True).set_dim(dV_gpu.size()).set_stride(dV_gpu.stride())
-
-    graph_backward.validate()
-    graph_backward.build_operation_graph()
-    graph_backward.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
-    graph_backward.check_support()
-    graph_backward.build_plans()
-
-    variant_pack_backward = {
-        q_backward: q_gpu,
-        k_backward: k_gpu,
-        v_backward: v_gpu,
-        o_backward: o_gpu,
-        dO_backward: dO_gpu,
-        stats_backward: stats_gpu,
-        dQ_backward: dQ_gpu,
-        dK_backward: dK_gpu,
-        dV_backward: dV_gpu,
-        seq_len_q: seqlens_reshaped,
-        seq_len_kv: seqlens_reshaped,
-    }
-
-    workspace = torch.empty(
-        max(graph_forward.get_workspace_size(), graph_backward.get_workspace_size()),
-        device="cuda", dtype=torch.uint8
-    )
-
-    def run_fwd(*args, **kwargs):
-        graph_forward.execute(variant_pack_forward, workspace)
-        return o_gpu, stats_gpu
-
-    def run_bwd(*args, **kwargs):
-        graph_backward.execute(variant_pack_backward, workspace)
-        return dQ_gpu, dK_gpu, dV_gpu
-
-    return run_fwd, run_bwd
-
-
-torch.manual_seed(0)
-repeats = 100
-dropout_p = 0.0
-causal = False
-dtype = torch.float16
-device = 'cuda'
-verbose = False
-batch_size = 2
-# seqlen = 2048
-seqlen = 8192
-# seqlen = 4096
-# seqlen = 2047
-dim = 2048
-# headdim = 128
-# headdim = 64
-headdim = 256

-for mode in ['fwd', 'bwd']:
-# for mode in ['bwd']:
-    for headdim in [64, 128, 256]:
-    # for headdim in [128]:
-        for seqlen in [1024, 2048, 4096, 8192, 16384, 32768]:
-        # for seqlen in [8192]:
-            nheads = dim // headdim
-            # nheads = 24
-            # headdim = 64
-            # batch_size = 64
-            # seqlen = 512
-            # nheads = 8
-            # headdim = 128
-            # nheads = 16
-            # headdim = 128
-            nheads_kv = nheads
-            # nheads_kv = 1
-
-            qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
-                              requires_grad=True)
-            q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
-            k = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype, requires_grad=True)
-            v = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype, requires_grad=True)
-            q_t = q.transpose(1, 2).contiguous().detach().requires_grad_()
-            k_t = k.transpose(1, 2).contiguous().detach().requires_grad_()
-            v_t = v.transpose(1, 2).contiguous().detach().requires_grad_()
-            grad = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype)
-            grad_t = grad.transpose(1, 2).contiguous()
-            o_t = torch.empty_like(q.transpose(1, 2))
-            stats = torch.empty(batch_size, nheads, seqlen, 1, dtype=torch.float32, device=q.device)
-
-            bench_fn = benchmark_forward if mode == 'fwd' else partial(benchmark_backward, grad=grad)
-
-            for causal in [False, True]:
-            # for causal in [True]:
-                print(f"\n### {mode = }, {batch_size = }, {headdim = }, {seqlen = }, {causal = } ###")
-                # For var-seq-len
-                lens = torch.full([q.shape[0]], seqlen, dtype=torch.int32)
-                seqlens_cudnn = lens.reshape(batch_size, 1, 1, 1).contiguous().cuda()
-                cu_seqlens = torch.cat([torch.tensor([0], dtype=torch.int32), torch.cumsum(lens, dim=0, dtype=torch.int32)]).cuda()
-                if headdim <= 128 and cudnn is not None:
-                    cudnn_sdpa_fwd, cudnn_sdpa_bwd = cudnn_sdpa_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), grad.transpose(1, 2), o_t, stats, causal=causal)
-                    cudnn_sdpa_fwd_varlen, cudnn_sdpa_bwd_varlen = cudnn_sdpa_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), grad.transpose(1, 2), o_t, stats, causal=causal, varlen=True, seqlens=seqlens_cudnn)
-                f = flops(batch_size, nheads, seqlen, seqlen, headdim, causal=causal, mode=mode)
-                ref_o = flash_attn_func(q, k, v, dropout_p, causal=causal)
-                _, m0 = bench_fn(flash_attn_func, q, k, v, dropout_p, causal=causal, repeats=repeats, verbose=verbose, desc='Fav2')
-                if mode == 'bwd':
-                    ref_dv, v.grad = v.grad.clone(), None
-                    ref_dk, k.grad = k.grad.clone(), None
-                    ref_dq, q.grad = q.grad.clone(), None
-                # pytorch_profiler(flash_attn_func, q, k, v, dropout_p, causal=causal, backward=False)
-                if headdim <= 128:
-                    if triton_attention is not None and nheads_kv == nheads:
-                        if mode == 'fwd':
-                            time.sleep(1)  # Sleep to avoid residual power throttling from the previous benchmark
-                            _, m3 = benchmark_forward(triton_attention, q_t, k_t, v_t, causal, 1 / math.sqrt(headdim), repeats=repeats, verbose=verbose, desc='Triton')
-                        # TODO: fix Triton numeric errors.
-                        # if mode == 'bwd':
-                        #     dv, v_t.grad = v_t.grad.clone(), None
-                        #     dk, k_t.grad = k_t.grad.clone(), None
-                        #     dq, q_t.grad = q_t.grad.clone(), None
-                        #     torch.testing.assert_close(ref_dv, dv.transpose(1, 2), atol=0.05, rtol=0.05)
-                        #     torch.testing.assert_close(ref_dk, dk.transpose(1, 2), atol=0.05, rtol=0.05)
-                        #     torch.testing.assert_close(ref_dq, dq.transpose(1, 2), atol=0.05, rtol=0.05)
-                    if cudnn is not None:
-                        time.sleep(1)  # Sleep to avoid residual power throttling from the previous benchmark
-                        if mode == 'fwd':
-                            _, m2 = benchmark_forward(cudnn_sdpa_fwd, repeats=repeats, verbose=verbose, desc='CuDNN')
-                            _, m2_var = benchmark_forward(cudnn_sdpa_fwd_varlen, repeats=repeats, verbose=verbose, desc='CuDNN')
-                            cudnn_sdpa_fwd()
-                            torch.testing.assert_close(ref_o, o_t.transpose(1, 2), atol=0.05, rtol=0.05)
-                            cudnn_sdpa_fwd_varlen()
-                            torch.testing.assert_close(ref_o, o_t.transpose(1, 2), atol=0.05, rtol=0.05)
-                        else:
-                            cudnn_sdpa_fwd()
-                            _, m2 = benchmark_forward(cudnn_sdpa_bwd, repeats=repeats, verbose=verbose, desc='CuDNN')
-                            _, m2_var = benchmark_forward(cudnn_sdpa_bwd_varlen, repeats=repeats, verbose=verbose, desc='CuDNN')
-                            dq, dk, dv = cudnn_sdpa_bwd()
-                            torch.testing.assert_close(ref_dv, dv.transpose(1, 2), atol=0.05, rtol=0.05)
-                            torch.testing.assert_close(ref_dk, dk.transpose(1, 2), atol=0.05, rtol=0.05)
-                            torch.testing.assert_close(ref_dq, dq.transpose(1, 2), atol=0.05, rtol=0.05)
-                            dq, dk, dv = cudnn_sdpa_bwd_varlen()
-                            torch.testing.assert_close(ref_dv, dv.transpose(1, 2), atol=0.05, rtol=0.05)
-                            torch.testing.assert_close(ref_dk, dk.transpose(1, 2), atol=0.05, rtol=0.05)
-                            torch.testing.assert_close(ref_dq, dq.transpose(1, 2), atol=0.05, rtol=0.05)
-                        # pytorch_profiler(cudnn_sdpa, backward=False)
-
-                if headdim <= 128 or mode == 'fwd':
-                    time.sleep(1)
-                    _, m1 = bench_fn(flash_attn_func_v3, q, k, v, causal=causal, repeats=repeats, verbose=verbose, desc='Fav3')
-                    q_var = q.reshape(-1, q.shape[-2], q.shape[-1])
-                    k_var = k.reshape(-1, k.shape[-2], k.shape[-1])
-                    v_var = v.reshape(-1, v.shape[-2], v.shape[-1])
-                    time.sleep(1)
-                    if mode == 'bwd':
-                        dv, v.grad = v.grad.clone(), None
-                        dk, k.grad = k.grad.clone(), None
-                        dq, q.grad = q.grad.clone(), None
-                        torch.testing.assert_close(ref_dv, dv, atol=0.05, rtol=0.05)
-                        torch.testing.assert_close(ref_dk, dk, atol=0.05, rtol=0.05)
-                        torch.testing.assert_close(ref_dq, dq, atol=0.05, rtol=0.05)
-
-                    bench_var_fn = bench_fn
-                    if mode == 'bwd':
-                        grad_var = grad.reshape(-1, grad.shape[-2], grad.shape[-1])
-                        bench_var_fn = partial(benchmark_backward, grad=grad_var)
-                    _, m1_var = bench_var_fn(flash_attn_varlen_func_v3, q_var, k_var, v_var, cu_seqlens, cu_seqlens, seqlen, seqlen, causal=causal, repeats=repeats, verbose=verbose, desc='Fav3 var len')
-
-                # pytorch_profiler(flash_attn_func_v3, q, k, v, causal=causal, backward=False)
-                print(f'Fav2: {m0.mean * 1e3:.3f}ms, {(f / m0.mean * 1e-12):.1f} TFLOPS')
-                if headdim <= 128:
-                    if mode == 'fwd' and triton_attention is not None and nheads_kv == nheads:
-                        print(f'Triton: {m3.mean * 1e3:.3f}ms, {(f / m3.mean * 1e-12):.1f} TFLOPS')
-                    if cudnn is not None:
-                        print(f'CuDNN: {m2.mean * 1e3:.3f}ms, {(f / m2.mean * 1e-12):.1f} TFLOPS')
-                        print(f'CuDNN varlen: {m2_var.mean * 1e3:.3f}ms, {(f / m2_var.mean * 1e-12):.1f} TFLOPS')
-                if headdim <= 128 or mode == 'fwd':
-                    print(f'Fav3: {m1.mean * 1e3:.3f}ms, {(f / m1.mean * 1e-12):.1f} TFLOPS')
-                    print(f'Fav3 varlen: {m1_var.mean * 1e3:.3f}ms, {(f / m1_var.mean * 1e-12):.1f} TFLOPS')
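For context on the deleted benchmark's reporting: the flops helper converts a measured mean runtime into TFLOPS using the standard attention FLOP count, 4 * batch * seqlen_q * seqlen_k * nheads * headdim, halved for causal masking and scaled by 2.5x for the backward pass. A quick back-of-the-envelope check (plain arithmetic; the 5 ms timing is a hypothetical value, not a measurement from the commit):

    batch, nheads, headdim = 2, 16, 128      # dim = nheads * headdim = 2048
    seqlen = 8192
    f = 4 * batch * seqlen * seqlen * nheads * headdim
    # f = 2**40 ≈ 1.10e12 FLOPs for one non-causal forward pass
    mean_time = 5e-3                         # hypothetical 5 ms mean per iteration
    print(f"{f / mean_time * 1e-12:.1f} TFLOPS")   # ≈ 219.9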
@@ -199,8 +199,8 @@ def test_flash_attn_output(
 # @pytest.mark.parametrize("causal", [False])
 @pytest.mark.parametrize("deterministic", [False, True])
 # @pytest.mark.parametrize("deterministic", [False])
-# @pytest.mark.parametrize("add_unused_qkv", [False, True])
-@pytest.mark.parametrize("add_unused_qkv", [True])
+@pytest.mark.parametrize("add_unused_qkv", [False, True])
+# @pytest.mark.parametrize("add_unused_qkv", [True])
 # @pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [128])
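The swap above re-enables the full add_unused_qkv matrix in place of the [True]-only debugging setting. As a reminder of generic pytest behavior (not code from this repo), stacked parametrize decorators take the cross product, so each re-enabled axis multiplies the test count:

    import pytest

    @pytest.mark.parametrize("deterministic", [False, True])
    @pytest.mark.parametrize("add_unused_qkv", [False, True])
    def test_matrix(deterministic, add_unused_qkv):
        # pytest generates 2 x 2 = 4 test cases from the two decorators
        assert isinstance(add_unused_qkv, bool)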
@@ -310,6 +310,7 @@ def test_flash_attn_varlen_output(
         seqused_k=seqused_k,
     )
     out = output_pad_fn(out_unpad)
+    if query_unused_mask is not None:
         q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1")
         out.masked_fill_(q_zero_masking, 0.0)
     dropout_mask = None
@@ -347,6 +348,7 @@ def test_flash_attn_varlen_output(
     ) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g)
     dk = dk_pad_fn(dk_unpad)
     dv = dk_pad_fn(dv_unpad)
+    if key_unused_mask is not None:
         k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1")
         dk.masked_fill_(k_zero_masking, 0.0)
         dv.masked_fill_(k_zero_masking, 0.0)
@@ -366,6 +368,7 @@ def test_flash_attn_varlen_output(
     dk_pt.masked_fill_(zero_masking, 0.0)
     dv_pt.masked_fill_(zero_masking, 0.0)
     dq = dq_pad_fn(dq_unpad)
+    if query_unused_mask is not None:
         dq.masked_fill_(q_zero_masking, 0.0)
     print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
     print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
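The three hunks above all add the same guard: zero-masking of rows flagged as unused is now skipped when the corresponding mask is None. A self-contained sketch of the masking idiom itself (illustrative shapes, not the test's actual tensors):

    import torch
    from einops import rearrange

    batch, seqlen, nheads, headdim = 2, 5, 3, 4
    out = torch.randn(batch, seqlen, nheads, headdim)
    query_unused_mask = torch.zeros(batch, seqlen, dtype=torch.bool)
    query_unused_mask[0, -2:] = True   # mark the last two rows of sequence 0 as unused

    if query_unused_mask is not None:
        # broadcast (b, s) -> (b, s, 1, 1) so masked_fill_ zeroes every head and dim
        q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1")
        out.masked_fill_(q_zero_masking, 0.0)

    assert (out[0, -2:] == 0).all()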
@@ -51,7 +51,7 @@ def generate_qkv(
     assert not qkvpacked

     if query_padding_mask is not None:
-        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, seqused_q, _ = unpad_input(
+        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, seqused_q = unpad_input(
             q, query_padding_mask, query_unused_mask,
         )
         output_pad_fn = lambda output_unpad: pad_input(
@@ -69,8 +69,8 @@ def generate_qkv(
     )

     if key_padding_mask is not None:
-        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, seqused_k, _ = unpad_input(k, key_padding_mask, key_unused_mask)
-        v_unpad, _, _, _, _, _ = unpad_input(v, key_padding_mask, key_unused_mask)
+        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, seqused_k = unpad_input(k, key_padding_mask, key_unused_mask)
+        v_unpad, _, _, _, _ = unpad_input(v, key_padding_mask, key_unused_mask)
     else:
         k_unpad = rearrange(k, "b s h d -> (b s) h d")
         v_unpad = rearrange(v, "b s h d -> (b s) h d")
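Both generate_qkv hunks are mechanical fallout of the arity change in unpad_input (six return values down to five). Callers that must run against either version could unpack defensively; a hypothetical compatibility shim, not part of the commit:

    def unpad_input_compat(*args, **kwargs):
        # Accepts both the old 6-tuple and the new 5-tuple return.
        ret = unpad_input(*args, **kwargs)
        return ret[:5]

    k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, seqused_k = unpad_input_compat(
        k, key_padding_mask, key_unused_mask
    )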