From f692b98d805850983f14deec7a9104583c58b107 Mon Sep 17 00:00:00 2001
From: Ivan Komarov
Date: Fri, 5 Apr 2024 22:40:41 +0200
Subject: [PATCH] Fix spurious re-compilations of `rotary_kernel` (#911)

All integer parameters are specialized by default, so the two parameters
removed in this commit could lead to kernel re-compilation, even if they
were completely unused.
---
 flash_attn/ops/triton/rotary.py | 13 ------------
 tests/test_rotary.py            | 37 +++++++++++++++++++++++++++++++++
 2 files changed, 37 insertions(+), 13 deletions(-)

diff --git a/flash_attn/ops/triton/rotary.py b/flash_attn/ops/triton/rotary.py
index 8d2e09b..6c04a52 100644
--- a/flash_attn/ops/triton/rotary.py
+++ b/flash_attn/ops/triton/rotary.py
@@ -8,15 +8,6 @@ import triton
 import triton.language as tl
 
 
-# @triton.autotune(
-#     configs=[
-#         triton.Config({"BLOCK_M": 2}),
-#         triton.Config({"BLOCK_M": 4}),
-#         triton.Config({"BLOCK_M": 8}),
-#         triton.Config({"BLOCK_M": 16}),
-#     ],
-#     key=["CACHE_KEY_SEQLEN", "BLOCK_K", "INTERLEAVED"],
-# )
 @triton.jit
 def rotary_kernel(
     OUT,  # Pointers to matrices
@@ -27,10 +18,8 @@ def rotary_kernel(
     SEQLEN_OFFSETS,  # this could be int or a pointer
     # Matrix dimensions
     seqlen,
-    nheads,
     rotary_dim,
     seqlen_ro,
-    CACHE_KEY_SEQLEN,
     # strides
     stride_out_batch,
     stride_out_seqlen,
@@ -218,10 +207,8 @@ def apply_rotary(
             cu_seqlens,
             seqlen_offsets,
             seqlen,  # shapes
-            nheads,
             rotary_dim,
             seqlen_ro,
-            seqlen // 128,  # key for triton cache (limit number of compilations)
             output.stride(0) if not is_varlen else 0,  # batch_strides if not varlen else 0
             output.stride(-3),  # seqlen_stride or total_seqlen_stride
             output.stride(-2),  # nheads_stride
diff --git a/tests/test_rotary.py b/tests/test_rotary.py
index 574d052..6f2a5fa 100644
--- a/tests/test_rotary.py
+++ b/tests/test_rotary.py
@@ -252,3 +252,40 @@ def test_rotary_emb_varlen_func(inplace, interleaved, rotary_fraction, seqlen_of
     assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
     atol = ((x_pt.grad + 0.3 - 0.3) - x_pt.grad).abs().max().item()
     assert torch.allclose(x_grad, x_pt.grad, rtol=rtol, atol=2 * atol)
+
+
+def test_compilation_count():
+    batch_size = 1
+    headdim = 128
+    device = "cuda"
+    dtype = torch.float16
+    torch.manual_seed(42)
+
+    from triton.runtime.jit import JITFunction
+    from flash_attn.ops.triton.rotary import rotary_kernel
+    compilation_count = 0
+
+    def count_compilations(*args, **kwargs):
+        nonlocal compilation_count
+        compilation_count += 1
+
+    old_cache_func = JITFunction.cache_hook
+
+    try:
+        rotary_kernel.cache.clear()
+        JITFunction.cache_hook = count_compilations
+
+        for seqlen in (128, 256):
+            for nheads in (4, 32):
+                x = torch.randn(batch_size, seqlen, nheads, headdim, dtype=dtype, device=device)
+                x.requires_grad_()
+                cos, sin = generate_cos_sin(seqlen, headdim, device, dtype)
+                out = apply_rotary_emb(x, cos, sin)
+                out.backward(torch.randn_like(out))
+
+        # Only two kernels are expected to be compiled:
+        # * for the forward pass (conjugate=False)
+        # * for the backward pass (conjugate=True)
+        assert compilation_count == 2
+    finally:
+        JITFunction.cache_hook = old_cache_func
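
For context, a minimal sketch (separate from the patch above) of the behaviour the commit message describes. Triton specializes integer kernel arguments by default (recent releases key the compilation cache on properties such as whether the value equals 1 or is divisible by 16), so an integer argument the kernel never reads still contributes to the cache key. The kernel name and the `unused_int` argument below are illustrative assumptions, not flash-attn code; only the `JITFunction.cache_hook` counting trick is taken from the test added in this patch.

import torch
import triton
import triton.language as tl
from triton.runtime.jit import JITFunction


@triton.jit
def _copy_kernel(X, OUT, unused_int, BLOCK: tl.constexpr):
    # `unused_int` is never read, but as a plain integer argument it is still
    # specialized and therefore becomes part of the compilation cache key.
    offs = tl.arange(0, BLOCK)
    tl.store(OUT + offs, tl.load(X + offs))


compilations = 0


def _count(*args, **kwargs):
    global compilations
    compilations += 1


JITFunction.cache_hook = _count
x = torch.randn(16, device="cuda")
out = torch.empty_like(x)
_copy_kernel[(1,)](x, out, 4, BLOCK=16)   # 4: not divisible by 16
_copy_kernel[(1,)](x, out, 32, BLOCK=16)  # 32: divisible by 16 -> new cache entry
print(compilations)  # typically 2, even though `unused_int` is never used
JITFunction.cache_hook = None

Depending on the Triton version, an alternative to deleting such arguments is to exclude them from specialization via the `do_not_specialize` argument of `triton.jit`; the patch takes the simpler route of removing them, since they were unused anyway.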