[LoRA][Kernel] Remove the unused libentry module (#10214)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Jee Jee Li 2024-11-11 17:43:23 +08:00 committed by GitHub
parent 58170d6503
commit 36e4acd02a
7 changed files with 49 additions and 276 deletions

View File

@ -4,8 +4,6 @@ hidden_sizes included in the LoRA models currently supported by vLLM. It tests
whether the corresponding Triton kernel can run normally when tensor parallelism
is set to [1, 2, 4, 8, 16, 32, 64].
"""
from unittest.mock import patch
import pytest
import torch
@ -16,7 +14,6 @@ from vllm.lora.ops.sgmv_expand import sgmv_expand
from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
from vllm.lora.ops.sgmv_shrink import sgmv_shrink
from vllm.platforms import current_platform
from vllm.triton_utils.libentry import LibEntry
from .utils import (generate_data, generate_data_for_expand_nslices,
ref_torch_groupgemm)
@ -235,9 +232,6 @@ def test_punica_bgmv(
seed: int,
device: str,
):
from vllm.lora.ops.bgmv_expand import _bgmv_expand_kernel
from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel
torch.set_default_device(device)
current_platform.seed_everything(seed)
@ -262,33 +256,21 @@ def test_punica_bgmv(
device,
)
if op_type == "shrink":
# The current _bgmv_shrink_kernel does not require the libentry
# decoration. The purpose of adding this patch is to test the
# correctness of libentry.
with patch(
"vllm.lora.ops.bgmv_shrink._bgmv_shrink_kernel",
LibEntry(_bgmv_shrink_kernel),
):
bgmv_shrink(
inputs_tensor,
lora_weights,
our_out_tensor,
indices,
scaling,
)
bgmv_shrink(
inputs_tensor,
lora_weights,
our_out_tensor,
indices,
scaling,
)
else:
# ditto
with patch(
"vllm.lora.ops.bgmv_expand._bgmv_expand_kernel",
LibEntry(_bgmv_expand_kernel),
):
bgmv_expand(
inputs_tensor,
lora_weights,
our_out_tensor,
indices,
add_inputs=True,
)
bgmv_expand(
inputs_tensor,
lora_weights,
our_out_tensor,
indices,
add_inputs=True,
)
ref_torch_groupgemm(
ref_out_tensor,
inputs_tensor,
@ -324,7 +306,6 @@ def test_punica_expand_nslices(
seed: int,
device: str,
):
from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel
torch.set_default_device(device)
current_platform.seed_everything(seed)
@ -374,22 +355,16 @@ def test_punica_expand_nslices(
add_inputs=True,
)
else:
# The current _bgmv_expand_slice_kernel does not require the
# libentry decoration. The purpose of adding this patch is to test
# the correctness of libentry.
with patch(
"vllm.lora.ops.bgmv_expand_slice._bgmv_expand_slice_kernel",
LibEntry(_bgmv_expand_slice_kernel),
):
bgmv_expand_slice(
inputs_tensor,
lora_weights,
our_outputs,
indices,
slice_offset,
slice_size=hidden_size,
add_inputs=True,
)
bgmv_expand_slice(
inputs_tensor,
lora_weights,
our_outputs,
indices,
slice_offset,
slice_size=hidden_size,
add_inputs=True,
)
ref_torch_groupgemm(
ref_outputs[:, slice_offset:slice_offset + hidden_size],
inputs_tensor,

View File

@ -3,8 +3,6 @@ This script is mainly used to test whether Triton kernels can run normally
under different conditions, including various batches, numbers of LoRAs, and
maximum ranks.
"""
from unittest.mock import patch
import pytest
import torch
@ -15,7 +13,6 @@ from vllm.lora.ops.sgmv_expand import sgmv_expand
from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
from vllm.lora.ops.sgmv_shrink import sgmv_shrink
from vllm.platforms import current_platform
from vllm.triton_utils.libentry import LibEntry
from .utils import (generate_data, generate_data_for_expand_nslices,
ref_torch_groupgemm)
@ -150,8 +147,6 @@ def test_punica_bgmv(
seed: int,
device: str,
):
from vllm.lora.ops.bgmv_expand import _bgmv_expand_kernel
from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel
torch.set_default_device(device)
current_platform.seed_everything(seed)
@ -177,33 +172,22 @@ def test_punica_bgmv(
device,
)
if op_type == "shrink":
# The current _bgmv_shrink_kernel does not require the libentry
# decoration. The purpose of adding this patch is to test the
# correctness of libentry.
with patch(
"vllm.lora.ops.bgmv_shrink._bgmv_shrink_kernel",
LibEntry(_bgmv_shrink_kernel),
):
bgmv_shrink(
inputs_tensor,
lora_weights,
our_out_tensor,
indices,
scaling,
)
bgmv_shrink(
inputs_tensor,
lora_weights,
our_out_tensor,
indices,
scaling,
)
else:
# ditto
with patch(
"vllm.lora.ops.bgmv_expand._bgmv_expand_kernel",
LibEntry(_bgmv_expand_kernel),
):
bgmv_expand(
inputs_tensor,
lora_weights,
our_out_tensor,
indices,
add_inputs=True,
)
bgmv_expand(
inputs_tensor,
lora_weights,
our_out_tensor,
indices,
add_inputs=True,
)
ref_torch_groupgemm(
ref_out_tensor,
inputs_tensor,
@ -239,8 +223,6 @@ def test_punica_expand_nslices(
seed: int,
device: str,
):
from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel
torch.set_default_device(device)
current_platform.seed_everything(seed)
@ -289,22 +271,15 @@ def test_punica_expand_nslices(
add_inputs=True,
)
else:
# The current _bgmv_expand_slice_kernel does not require the
# libentry decoration. The purpose of adding this patch is to test
# the correctness of libentry.
with patch(
"vllm.lora.ops.bgmv_expand_slice._bgmv_expand_slice_kernel",
LibEntry(_bgmv_expand_slice_kernel),
):
bgmv_expand_slice(
inputs_tensor,
lora_weights,
our_outputs,
indices,
slice_offset,
slice_size=hidden_size,
add_inputs=True,
)
bgmv_expand_slice(
inputs_tensor,
lora_weights,
our_outputs,
indices,
slice_offset,
slice_size=hidden_size,
add_inputs=True,
)
ref_torch_groupgemm(
ref_outputs[:, slice_offset:slice_offset + hidden_size],
inputs_tensor,

View File

@ -9,10 +9,7 @@ import torch
import triton
import triton.language as tl
from vllm.triton_utils import libentry
@libentry()
@triton.jit
def _sgmv_expand_kernel(
input_ptr,

View File

@ -9,10 +9,7 @@ import torch
import triton
import triton.language as tl
from vllm.triton_utils import libentry
@libentry()
@triton.jit
def _sgmv_expand_slice_kernel(
input_ptr,

View File

@ -9,10 +9,7 @@ import torch
import triton
import triton.language as tl
from vllm.triton_utils import libentry
@libentry()
@triton.jit
def _sgmv_shrink_kernel(
input_ptr,

View File

@ -6,6 +6,5 @@ if HAS_TRITON:
from vllm.triton_utils.custom_cache_manager import (
maybe_set_triton_cache_manager)
from vllm.triton_utils.libentry import libentry
__all__ += ["maybe_set_triton_cache_manager", "libentry"]
__all__ += ["maybe_set_triton_cache_manager"]

View File

@ -1,167 +0,0 @@
# Copied From https://github.com/FlagOpen/FlagGems
import inspect
import triton
class LibEntry(triton.KernelInterface):
def __init__(
self,
fn,
):
self.fn = fn
self.arg_names = fn.arg_names
self.divisibility = 16
self.kernel_cache = dict()
fn = self.fn
while not isinstance(fn, triton.runtime.JITFunction):
fn = fn.fn
self.jit_function: triton.runtime.JITFunction = fn
self.specialize_indices = [
p.num for p in self.jit_function.params
if not p.is_constexpr and not p.do_not_specialize
]
self.do_not_specialize_indices = [
p.num for p in self.jit_function.params
if not p.is_constexpr and p.do_not_specialize
]
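# Build LibEntry's own cache key: tensor args contribute their dtype and
# whether their data pointer is 16-byte aligned; other specialize args
# contribute (type, value); do-not-specialize args contribute a dtype or
# integer-width tag; constexpr args are appended by value.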
def key(self, spec_args, dns_args, const_args):
spec_key = [(arg.dtype, arg.data_ptr() %
self.divisibility == 0) if hasattr(arg, "data_ptr") else
(type(arg), arg) for arg in spec_args]
dns_key = [
arg.dtype if hasattr(
arg, "data_ptr") else type(arg) if not isinstance(arg, int)
else "i32" if arg >= -(2**31) and arg <= 2**31 -
1 else "u64" if arg >= 2**63 and arg <= 2**64 - 1 else "i64"
for arg in dns_args
]
# const args passed by position
return tuple(spec_key + dns_key + const_args)
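# Replacement for JITFunction.run: split the call arguments into specialize /
# do-not-specialize / constexpr groups, look the compiled kernel up in the
# local kernel_cache, and on a hit launch it directly, skipping JITFunction's
# slower argument binding and key computation.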
def run(self, *args, **kwargs):
grid = kwargs["grid"]
# collect all the arguments
spec_args = [] # specialize arguments
dns_args = [] # do not specialize arguments
const_args = [] # constexpr arguments
k_args = [] # kernel arguments
for i, arg in enumerate(args):
if i in self.specialize_indices:
k_args.append(arg)
spec_args.append(arg)
elif i in self.do_not_specialize_indices:
k_args.append(arg)
dns_args.append(arg)
else:
const_args.append(arg)
for p in self.jit_function.params[len(args):]:
if p.name in kwargs:
val = kwargs[p.name]
elif p.default is inspect._empty:
continue
else:
val = p.default
if p.is_constexpr:
const_args.append(val)
elif p.do_not_specialize:
dns_args.append(val)
k_args.append(val)
else:
spec_args.append(val)
k_args.append(val)
entry_key = self.key(spec_args, dns_args, const_args)
if entry_key not in self.kernel_cache:
# compiling the kernel via JITFunction.run also performs this launch's computation
kernel = self.fn.run(*args, **kwargs)
fn = self.fn
# collect constexpr arguments for grid computation
constexprs = {}
while not isinstance(fn, triton.runtime.JITFunction):
if isinstance(fn, triton.runtime.Autotuner):
config = fn.best_config
constexprs["num_warps"] = config.num_warps
constexprs["num_stages"] = config.num_stages
constexprs["num_ctas"] = config.num_ctas
constexprs = {**constexprs, **config.kwargs}
elif isinstance(fn, triton.runtime.Heuristics):
for v, heur in fn.values.items():
constexprs[v] = heur({
**dict(zip(fn.arg_names, args)),
**kwargs,
**constexprs,
})
else:
raise RuntimeError("Invalid Runtime Function")
fn = fn.fn
# In vLLM, certain kernels like fused_moe_kernel get their
# best_config (as kwargs) from a configuration JSON file rather
# than from Autotuner & Heuristics. Therefore, all of their constexprs
# (tl.constexpr) are assigned values through the following loop.
for p in self.jit_function.params:
if p.is_constexpr and p.name not in constexprs:
constexprs[p.name] = p.default  # may be inspect._empty
self.kernel_cache[entry_key] = (kernel, constexprs)
else:
# load kernel from cache directly
kernel, constexprs = self.kernel_cache[entry_key]
if callable(grid):
# collect all arguments to the grid fn, i.e.:
# 1. args,
# 2. kwargs,
# 3. all other captured arguments in CompiledKernel from
# Autotuner & Heuristics; when kwargs & captured args conflict,
# captured args have higher priority
# 4. we must first filter out captured args whose value is still the
# default (inspect._empty)
constexprs = {
k: v
for k, v in constexprs.items() if v is not inspect._empty
}
meta = {
**dict(zip(self.arg_names, args)),
**kwargs,
**constexprs,
}
grid = grid(meta)
if isinstance(grid, tuple):
grid = grid + (1, 1)
elif isinstance(grid, list):
grid = grid + [1, 1]
kernel[grid[0:3]](*k_args)
# maintain the same return type as JITFunction.run
return kernel
def libentry():
"""
Decorator for triton library entries.
Motivation:
The runtime overhead of launching Triton kernels is the reason for the
lower performance of small kernels, which is particularly evident with
smaller models. Using this decorator can reduce that overhead.
How:
The `run` function of JITFunction needs to accomplish:
- Parameter binding using inspect
- KernelArg type wrapping
- Cache key calculation
For small problem sizes, these steps can become bottlenecks in the
Triton runtime. LibEntry simplifies them to reduce the runtime overhead,
thereby lowering the launch cost of small kernels.
NOTE:
When Triton is upgraded to version 3.0.0, libentry can be removed,
see: https://github.com/vllm-project/vllm/pull/5036#issuecomment-2243396245
"""
def decorator(fn):
return LibEntry(fn)
return decorator
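For reference, a minimal sketch of how the decorator was used before this commit, mirroring the removed import and decoration pattern in the sgmv kernel files above. The toy kernel, its launch grid, and the CUDA device are illustrative placeholders rather than vLLM code, and the sketch assumes a pre-3.0 Triton where the subscript launch syntax forwards the grid to LibEntry.run as the grid keyword argument.

import torch
import triton
import triton.language as tl

from vllm.triton_utils import libentry  # import removed by this commit


@libentry()  # wraps the JITFunction so repeated launches hit LibEntry's cache
@triton.jit
def _toy_copy_kernel(input_ptr, output_ptr, n_elements,
                     BLOCK_SIZE: tl.constexpr):
    # Illustrative kernel only: copies input to output in BLOCK_SIZE chunks.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    vals = tl.load(input_ptr + offsets, mask=mask)
    tl.store(output_ptr + offsets, vals, mask=mask)


x = torch.randn(1024, device="cuda")  # requires a CUDA device
y = torch.empty_like(x)
# The usual Triton launch syntax still works; the grid reaches LibEntry.run
# as kwargs["grid"], and a second identical launch is served from
# LibEntry.kernel_cache instead of going through JITFunction.run again.
grid = (triton.cdiv(x.numel(), 128), )
_toy_copy_kernel[grid](x, y, x.numel(), BLOCK_SIZE=128)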