235 lines
7.4 KiB
Python
235 lines
7.4 KiB
Python
|
|
import torch
|
||
|
|
#from flash_attn_interface import flash_attn_func, flash_attn_varlen_func, flash_attn_with_kvcache
|
||
|
|
import flash_attn_interface as fa3
|
||
|
|
import flash_attn as fa2
|
||
|
|
import torch.utils.benchmark as benchmark
|
||
|
|
import time
|
||
|
|
|
||
|
|
import argparse
|
||
|
|
import math
|
||
|
|
|
||
|
|
parser = argparse.ArgumentParser(description='Process some integers.')
|
||
|
|
parser.add_argument('--causal', action='store_true')
|
||
|
|
parser.add_argument('--splits', type=int, default=1)
|
||
|
|
parser.add_argument('--repeats', type=int, default=10)
|
||
|
|
parser.add_argument('--validate', action='store_true')
|
||
|
|
parser.add_argument('--gqa', action='store_true')
|
||
|
|
|
||
|
|
args = parser.parse_args()
|
||
|
|
|
||
|
|
def benchmark_fa_kv_old(fn, repeats=10, desc='', verbose=True, **kwinputs):
|
||
|
|
"""Use Pytorch Benchmark on the forward pass of an arbitrary function."""
|
||
|
|
if verbose:
|
||
|
|
print(desc, '- Forward pass')
|
||
|
|
t = benchmark.Timer(
|
||
|
|
stmt='fn(**kwinputs)',
|
||
|
|
globals={'fn': fn, 'kwinputs': kwinputs},
|
||
|
|
num_threads=torch.get_num_threads(),
|
||
|
|
)
|
||
|
|
m = t.timeit(repeats)
|
||
|
|
if verbose:
|
||
|
|
print(desc, m)
|
||
|
|
return t, m
|
||
|
|
|
||
|
|
def benchmark_fa_kv(fn, repeats=10, *args, **kwargs):
|
||
|
|
# warmup
|
||
|
|
for _ in range(5):
|
||
|
|
fn(*args, **kwargs)
|
||
|
|
niters = repeats
|
||
|
|
torch.cuda.synchronize()
|
||
|
|
start = time.time()
|
||
|
|
for _ in range(niters):
|
||
|
|
fn(*args, **kwargs)
|
||
|
|
torch.cuda.synchronize()
|
||
|
|
end = time.time()
|
||
|
|
return (end - start) / niters
|
||
|
|
|
||
|
|
def main():
|
||
|
|
# *SAMPLE CONFIG*
|
||
|
|
# Model arch params:
|
||
|
|
nheads_q = 64
|
||
|
|
nheads_kv = 8
|
||
|
|
headdim = 128
|
||
|
|
#dtype = torch.bfloat16
|
||
|
|
dtype = torch.float16
|
||
|
|
|
||
|
|
# Cache settings:
|
||
|
|
num_caches = 8
|
||
|
|
cache_seqlen = 1024 * 16
|
||
|
|
|
||
|
|
# Batching settings
|
||
|
|
ntokens = 1024
|
||
|
|
max_queries_per_batch = 4
|
||
|
|
small_request_ntokens = 16
|
||
|
|
|
||
|
|
# Input settings
|
||
|
|
query_seqlens = [900, 12, 1]
|
||
|
|
num_queries = len(query_seqlens)
|
||
|
|
# Need to add empty queries to fill out `max_queries_per_batch`
|
||
|
|
num_padding_queries = max_queries_per_batch - num_queries
|
||
|
|
context_seqlens = [4096, 5120*2, 6145*2]
|
||
|
|
#context_seqlens = [4096, 5120*2, 6152*2]
|
||
|
|
|
||
|
|
# Validation
|
||
|
|
assert sum(query_seqlens) <= ntokens
|
||
|
|
assert all(s < small_request_ntokens for s in query_seqlens[1:])
|
||
|
|
assert num_queries <= max_queries_per_batch
|
||
|
|
assert all(s < cache_seqlen for s in context_seqlens)
|
||
|
|
|
||
|
|
torch.manual_seed(5434)
|
||
|
|
|
||
|
|
# Allocate some tensors
|
||
|
|
k_cache = torch.randn(
|
||
|
|
(num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=dtype
|
||
|
|
)
|
||
|
|
v_cache = torch.randn(
|
||
|
|
(num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=dtype
|
||
|
|
)
|
||
|
|
|
||
|
|
q_buf_large = torch.randn(
|
||
|
|
(1, ntokens, nheads_q, headdim), device="cuda", dtype=dtype
|
||
|
|
)
|
||
|
|
cache_seqlen_large = torch.tensor(
|
||
|
|
[context_seqlens[0]], dtype=torch.int32, device="cuda"
|
||
|
|
)
|
||
|
|
cache_idx_large = torch.tensor([1], dtype=torch.int32, device="cuda")
|
||
|
|
|
||
|
|
q_buf_small = torch.randn(
|
||
|
|
(max_queries_per_batch - 1, small_request_ntokens, nheads_q, headdim),
|
||
|
|
device="cuda",
|
||
|
|
dtype=dtype,
|
||
|
|
)
|
||
|
|
cache_seqlens_small = torch.tensor(
|
||
|
|
context_seqlens[1:] + [0] * num_padding_queries, dtype=torch.int32, device="cuda"
|
||
|
|
)
|
||
|
|
cache_idxs_small = torch.randperm(num_caches, dtype=torch.int32, device="cuda")[
|
||
|
|
: max_queries_per_batch - 1
|
||
|
|
]
|
||
|
|
|
||
|
|
if args.validate:
|
||
|
|
# Call flash attn
|
||
|
|
# First for the single full-sized query
|
||
|
|
out0, lse0 = fa3.flash_attn_with_kvcache(
|
||
|
|
q=q_buf_large,
|
||
|
|
k_cache=k_cache,
|
||
|
|
v_cache=v_cache,
|
||
|
|
cache_seqlens=cache_seqlen_large,
|
||
|
|
cache_batch_idx=cache_idx_large,
|
||
|
|
causal=bool(args.causal),
|
||
|
|
num_splits=args.splits,
|
||
|
|
return_softmax_lse=True,
|
||
|
|
#num_splits=1
|
||
|
|
)
|
||
|
|
|
||
|
|
# Second for n-1 small queries
|
||
|
|
out1_split1, lse1_split1 = fa3.flash_attn_with_kvcache(
|
||
|
|
q=q_buf_small,
|
||
|
|
k_cache=k_cache,
|
||
|
|
v_cache=v_cache,
|
||
|
|
cache_seqlens=cache_seqlens_small,
|
||
|
|
cache_batch_idx=cache_idxs_small,
|
||
|
|
causal=bool(args.causal),
|
||
|
|
num_splits=1,
|
||
|
|
gqa_decoding=bool(args.gqa),
|
||
|
|
return_softmax_lse=True,
|
||
|
|
)
|
||
|
|
|
||
|
|
# Second for n-1 small queries
|
||
|
|
out1, lse1 = fa3.flash_attn_with_kvcache(
|
||
|
|
q=q_buf_small,
|
||
|
|
k_cache=k_cache,
|
||
|
|
v_cache=v_cache,
|
||
|
|
cache_seqlens=cache_seqlens_small,
|
||
|
|
cache_batch_idx=cache_idxs_small,
|
||
|
|
causal=bool(args.causal),
|
||
|
|
num_splits=args.splits,
|
||
|
|
gqa_decoding=bool(args.gqa),
|
||
|
|
return_softmax_lse=True,
|
||
|
|
)
|
||
|
|
|
||
|
|
# Call flash attn
|
||
|
|
# First for the single full-sized query
|
||
|
|
out2 = fa2.flash_attn_with_kvcache(
|
||
|
|
q=q_buf_large,
|
||
|
|
k_cache=k_cache,
|
||
|
|
v_cache=v_cache,
|
||
|
|
cache_seqlens=cache_seqlen_large,
|
||
|
|
cache_batch_idx=cache_idx_large,
|
||
|
|
causal=bool(args.causal),
|
||
|
|
num_splits=args.splits,
|
||
|
|
)
|
||
|
|
|
||
|
|
print ('big')
|
||
|
|
print ('diff-max', (out0 - out2).abs().max().item(), cache_seqlens_small)
|
||
|
|
print ('diff-mean', (out0 - out2).abs().mean().item())
|
||
|
|
|
||
|
|
|
||
|
|
# Second for n-1 small queries
|
||
|
|
out3, lse_fa2 = fa2.flash_attn_with_kvcache(
|
||
|
|
q=q_buf_small,
|
||
|
|
k_cache=k_cache,
|
||
|
|
v_cache=v_cache,
|
||
|
|
cache_seqlens=cache_seqlens_small,
|
||
|
|
cache_batch_idx=cache_idxs_small,
|
||
|
|
causal=bool(args.causal),
|
||
|
|
num_splits=args.splits,
|
||
|
|
return_softmax_lse=True,
|
||
|
|
#num_splits=1
|
||
|
|
)
|
||
|
|
|
||
|
|
print ('small') #, out1)
|
||
|
|
print ('lse', lse1, lse_fa2, (lse1 - lse_fa2).abs(), out1.shape)
|
||
|
|
print ('lse-dif-max', (lse1 - lse_fa2).abs().max().item())
|
||
|
|
print ('diff-max', (out1 - out3).abs().max().item())
|
||
|
|
print ('diff-mean', (out1 - out3).abs().mean().item())
|
||
|
|
|
||
|
|
|
||
|
|
print ('fa3', args.repeats)
|
||
|
|
time_fa3_big = benchmark_fa_kv(fa3.flash_attn_with_kvcache, repeats=args.repeats,
|
||
|
|
q=q_buf_large,
|
||
|
|
k_cache=k_cache,
|
||
|
|
v_cache=v_cache,
|
||
|
|
cache_seqlens=cache_seqlen_large,
|
||
|
|
cache_batch_idx=cache_idx_large,
|
||
|
|
causal=bool(args.causal),
|
||
|
|
num_splits=args.splits,
|
||
|
|
)
|
||
|
|
|
||
|
|
time_fa3_small = benchmark_fa_kv(fa3.flash_attn_with_kvcache, repeats=args.repeats,
|
||
|
|
q=q_buf_small,
|
||
|
|
k_cache=k_cache,
|
||
|
|
v_cache=v_cache,
|
||
|
|
cache_seqlens=cache_seqlens_small,
|
||
|
|
cache_batch_idx=cache_idxs_small,
|
||
|
|
causal=bool(args.causal),
|
||
|
|
num_splits=args.splits,
|
||
|
|
)
|
||
|
|
|
||
|
|
print ('fa2 ')
|
||
|
|
|
||
|
|
time_fa2_big = benchmark_fa_kv(fa2.flash_attn_with_kvcache, repeats=args.repeats,
|
||
|
|
q=q_buf_large,
|
||
|
|
k_cache=k_cache,
|
||
|
|
v_cache=v_cache,
|
||
|
|
cache_seqlens=cache_seqlen_large,
|
||
|
|
cache_batch_idx=cache_idx_large,
|
||
|
|
causal=bool(args.causal),
|
||
|
|
num_splits=args.splits
|
||
|
|
)
|
||
|
|
|
||
|
|
time_fa2_small = benchmark_fa_kv(fa2.flash_attn_with_kvcache, repeats=args.repeats,
|
||
|
|
q=q_buf_small,
|
||
|
|
k_cache=k_cache,
|
||
|
|
v_cache=v_cache,
|
||
|
|
cache_seqlens=cache_seqlens_small,
|
||
|
|
cache_batch_idx=cache_idxs_small,
|
||
|
|
causal=bool(args.causal),
|
||
|
|
num_splits=args.splits
|
||
|
|
)
|
||
|
|
|
||
|
|
print ('big (split, fa3, fa2, ratio):', args.splits, time_fa3_big * 1000000, time_fa2_big * 1000000, time_fa3_big / time_fa2_big)
|
||
|
|
print ('small (split, fa3, fa2, ratio):', args.splits, time_fa3_small * 1000000, time_fa2_small * 1000000, time_fa3_small / time_fa2_small)
|
||
|
|
|
||
|
|
if __name__ == "__main__":
|
||
|
|
main()
|