Fix spurious re-compilations of rotary_kernel (#911)
Triton specializes all integer parameters by default, so the two parameters removed in this commit could trigger kernel re-compilations even though they were completely unused.
parent 23e8fa5a26
commit f692b98d80
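For context, a minimal sketch of the failure mode (illustrative only: copy_kernel and UNUSED_KEY are made-up names, not the repo's rotary_kernel). Triton folds ordinary integer arguments into its specialization/caching key (e.g. whether a value equals 1 or is divisible by 16), so launches that differ only in an argument the kernel body never reads can still compile separate binaries.

import torch
import triton
import triton.language as tl


@triton.jit
def copy_kernel(X, Y, n_elements, UNUSED_KEY, BLOCK: tl.constexpr):
    # UNUSED_KEY is never read below, but as a plain integer argument it is
    # still specialized by Triton and therefore part of the cache key.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    tl.store(Y + offs, tl.load(X + offs, mask=mask), mask=mask)


x = torch.randn(1024, device="cuda")
y = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 256),)
for unused in (1, 2, 16):  # values with different specializations (== 1, % 16 == 0)
    copy_kernel[grid](x, y, x.numel(), unused, BLOCK=256)
# Each distinct specialization of UNUSED_KEY can force a fresh compile; dropping
# such unused parameters is how this commit avoids the spurious re-compilations.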
@@ -8,15 +8,6 @@ import triton
 import triton.language as tl
 
 
-# @triton.autotune(
-#     configs=[
-#         triton.Config({"BLOCK_M": 2}),
-#         triton.Config({"BLOCK_M": 4}),
-#         triton.Config({"BLOCK_M": 8}),
-#         triton.Config({"BLOCK_M": 16}),
-#     ],
-#     key=["CACHE_KEY_SEQLEN", "BLOCK_K", "INTERLEAVED"],
-# )
 @triton.jit
 def rotary_kernel(
     OUT,  # Pointers to matrices
@@ -27,10 +18,8 @@ def rotary_kernel(
     SEQLEN_OFFSETS,  # this could be int or a pointer
     # Matrix dimensions
     seqlen,
-    nheads,
     rotary_dim,
     seqlen_ro,
-    CACHE_KEY_SEQLEN,
     # strides
     stride_out_batch,
     stride_out_seqlen,
@@ -218,10 +207,8 @@ def apply_rotary(
             cu_seqlens,
             seqlen_offsets,
             seqlen,  # shapes
-            nheads,
             rotary_dim,
             seqlen_ro,
-            seqlen // 128,  # key for triton cache (limit number of compilations)
             output.stride(0) if not is_varlen else 0,  # batch_strides if not varlen else 0
             output.stride(-3),  # seqlen_stride or total_seqlen_stride
             output.stride(-2),  # nheads_stride
@@ -252,3 +252,40 @@ def test_rotary_emb_varlen_func(inplace, interleaved, rotary_fraction, seqlen_of
     assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
     atol = ((x_pt.grad + 0.3 - 0.3) - x_pt.grad).abs().max().item()
     assert torch.allclose(x_grad, x_pt.grad, rtol=rtol, atol=2 * atol)
+
+
+def test_compilation_count():
+    batch_size = 1
+    headdim = 128
+    device = "cuda"
+    dtype = torch.float16
+    torch.manual_seed(42)
+
+    from triton.runtime.jit import JITFunction
+    from flash_attn.ops.triton.rotary import rotary_kernel
+    compilation_count = 0
+
+    def count_compilations(*args, **kwargs):
+        nonlocal compilation_count
+        compilation_count += 1
+
+    old_cache_func = JITFunction.cache_hook
+
+    try:
+        rotary_kernel.cache.clear()
+        JITFunction.cache_hook = count_compilations
+
+        for seqlen in (128, 256):
+            for nheads in (4, 32):
+                x = torch.randn(batch_size, seqlen, nheads, headdim, dtype=dtype, device=device)
+                x.requires_grad_()
+                cos, sin = generate_cos_sin(seqlen, headdim, device, dtype)
+                out = apply_rotary_emb(x, cos, sin)
+                out.backward(torch.randn_like(out))
+
+        # Only two kernels are expected to be compiled:
+        # * for the forward pass (conjugate=False)
+        # * for the backward pass (conjugate=True)
+        assert compilation_count == 2
+    finally:
+        JITFunction.cache_hook = old_cache_func