flash-attention/hopper/test_kvcache.py

235 lines
7.4 KiB
Python
Raw Normal View History

import torch
#from flash_attn_interface import flash_attn_func, flash_attn_varlen_func, flash_attn_with_kvcache
import flash_attn_interface as fa3
import flash_attn as fa2
import torch.utils.benchmark as benchmark
import time
import argparse
import math
parser = argparse.ArgumentParser(description='Process some integers.')
parser.add_argument('--causal', action='store_true')
parser.add_argument('--splits', type=int, default=1)
parser.add_argument('--repeats', type=int, default=10)
parser.add_argument('--validate', action='store_true')
parser.add_argument('--gqa', action='store_true')
args = parser.parse_args()
def benchmark_fa_kv_old(fn, repeats=10, desc='', verbose=True, **kwinputs):
"""Use Pytorch Benchmark on the forward pass of an arbitrary function."""
if verbose:
print(desc, '- Forward pass')
t = benchmark.Timer(
stmt='fn(**kwinputs)',
globals={'fn': fn, 'kwinputs': kwinputs},
num_threads=torch.get_num_threads(),
)
m = t.timeit(repeats)
if verbose:
print(desc, m)
return t, m
def benchmark_fa_kv(fn, repeats=10, *args, **kwargs):
# warmup
for _ in range(5):
fn(*args, **kwargs)
niters = repeats
torch.cuda.synchronize()
start = time.time()
for _ in range(niters):
fn(*args, **kwargs)
torch.cuda.synchronize()
end = time.time()
return (end - start) / niters
def main():
# *SAMPLE CONFIG*
# Model arch params:
nheads_q = 64
nheads_kv = 8
headdim = 128
#dtype = torch.bfloat16
dtype = torch.float16
# Cache settings:
num_caches = 8
cache_seqlen = 1024 * 16
# Batching settings
ntokens = 1024
max_queries_per_batch = 4
small_request_ntokens = 16
# Input settings
query_seqlens = [900, 12, 1]
num_queries = len(query_seqlens)
# Need to add empty queries to fill out `max_queries_per_batch`
num_padding_queries = max_queries_per_batch - num_queries
context_seqlens = [4096, 5120*2, 6145*2]
#context_seqlens = [4096, 5120*2, 6152*2]
# Validation
assert sum(query_seqlens) <= ntokens
assert all(s < small_request_ntokens for s in query_seqlens[1:])
assert num_queries <= max_queries_per_batch
assert all(s < cache_seqlen for s in context_seqlens)
torch.manual_seed(5434)
# Allocate some tensors
k_cache = torch.randn(
(num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=dtype
)
v_cache = torch.randn(
(num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=dtype
)
q_buf_large = torch.randn(
(1, ntokens, nheads_q, headdim), device="cuda", dtype=dtype
)
cache_seqlen_large = torch.tensor(
[context_seqlens[0]], dtype=torch.int32, device="cuda"
)
cache_idx_large = torch.tensor([1], dtype=torch.int32, device="cuda")
q_buf_small = torch.randn(
(max_queries_per_batch - 1, small_request_ntokens, nheads_q, headdim),
device="cuda",
dtype=dtype,
)
cache_seqlens_small = torch.tensor(
context_seqlens[1:] + [0] * num_padding_queries, dtype=torch.int32, device="cuda"
)
cache_idxs_small = torch.randperm(num_caches, dtype=torch.int32, device="cuda")[
: max_queries_per_batch - 1
]
if args.validate:
# Call flash attn
# First for the single full-sized query
out0, lse0 = fa3.flash_attn_with_kvcache(
q=q_buf_large,
k_cache=k_cache,
v_cache=v_cache,
cache_seqlens=cache_seqlen_large,
cache_batch_idx=cache_idx_large,
causal=bool(args.causal),
num_splits=args.splits,
return_softmax_lse=True,
#num_splits=1
)
# Second for n-1 small queries
out1_split1, lse1_split1 = fa3.flash_attn_with_kvcache(
q=q_buf_small,
k_cache=k_cache,
v_cache=v_cache,
cache_seqlens=cache_seqlens_small,
cache_batch_idx=cache_idxs_small,
causal=bool(args.causal),
num_splits=1,
gqa_decoding=bool(args.gqa),
return_softmax_lse=True,
)
# Second for n-1 small queries
out1, lse1 = fa3.flash_attn_with_kvcache(
q=q_buf_small,
k_cache=k_cache,
v_cache=v_cache,
cache_seqlens=cache_seqlens_small,
cache_batch_idx=cache_idxs_small,
causal=bool(args.causal),
num_splits=args.splits,
gqa_decoding=bool(args.gqa),
return_softmax_lse=True,
)
# Call flash attn
# First for the single full-sized query
out2 = fa2.flash_attn_with_kvcache(
q=q_buf_large,
k_cache=k_cache,
v_cache=v_cache,
cache_seqlens=cache_seqlen_large,
cache_batch_idx=cache_idx_large,
causal=bool(args.causal),
num_splits=args.splits,
)
print ('big')
print ('diff-max', (out0 - out2).abs().max().item(), cache_seqlens_small)
print ('diff-mean', (out0 - out2).abs().mean().item())
# Second for n-1 small queries
out3, lse_fa2 = fa2.flash_attn_with_kvcache(
q=q_buf_small,
k_cache=k_cache,
v_cache=v_cache,
cache_seqlens=cache_seqlens_small,
cache_batch_idx=cache_idxs_small,
causal=bool(args.causal),
num_splits=args.splits,
return_softmax_lse=True,
#num_splits=1
)
print ('small') #, out1)
print ('lse', lse1, lse_fa2, (lse1 - lse_fa2).abs(), out1.shape)
print ('lse-dif-max', (lse1 - lse_fa2).abs().max().item())
print ('diff-max', (out1 - out3).abs().max().item())
print ('diff-mean', (out1 - out3).abs().mean().item())
print ('fa3', args.repeats)
time_fa3_big = benchmark_fa_kv(fa3.flash_attn_with_kvcache, repeats=args.repeats,
q=q_buf_large,
k_cache=k_cache,
v_cache=v_cache,
cache_seqlens=cache_seqlen_large,
cache_batch_idx=cache_idx_large,
causal=bool(args.causal),
num_splits=args.splits,
)
time_fa3_small = benchmark_fa_kv(fa3.flash_attn_with_kvcache, repeats=args.repeats,
q=q_buf_small,
k_cache=k_cache,
v_cache=v_cache,
cache_seqlens=cache_seqlens_small,
cache_batch_idx=cache_idxs_small,
causal=bool(args.causal),
num_splits=args.splits,
)
print ('fa2 ')
time_fa2_big = benchmark_fa_kv(fa2.flash_attn_with_kvcache, repeats=args.repeats,
q=q_buf_large,
k_cache=k_cache,
v_cache=v_cache,
cache_seqlens=cache_seqlen_large,
cache_batch_idx=cache_idx_large,
causal=bool(args.causal),
num_splits=args.splits
)
time_fa2_small = benchmark_fa_kv(fa2.flash_attn_with_kvcache, repeats=args.repeats,
q=q_buf_small,
k_cache=k_cache,
v_cache=v_cache,
cache_seqlens=cache_seqlens_small,
cache_batch_idx=cache_idxs_small,
causal=bool(args.causal),
num_splits=args.splits
)
print ('big (split, fa3, fa2, ratio):', args.splits, time_fa3_big * 1000000, time_fa2_big * 1000000, time_fa3_big / time_fa2_big)
print ('small (split, fa3, fa2, ratio):', args.splits, time_fa3_small * 1000000, time_fa2_small * 1000000, time_fa3_small / time_fa2_small)
if __name__ == "__main__":
main()