[Bugfix] Try to handle older versions of pytorch (#9086)
parent de24046fcd · commit bd37b9fbe2
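In short, the change stops calling torch.library.register_fake directly and instead routes every fake-op registration through a local register_fake alias that falls back to torch.library.impl_abstract on PyTorch builds that predate the newer name. A minimal sketch of that pattern outside the vLLM code base follows; the "demo::add_one" op and its CPU implementation are purely illustrative and are only there so the sketch is self-contained:

    import torch
    import torch.library

    # Prefer the newer fake-registration API; fall back to the older name
    # shipped by earlier PyTorch releases.
    try:
        from torch.library import register_fake
    except ImportError:
        from torch.library import impl_abstract as register_fake

    # Hypothetical op, defined here only to make the sketch runnable.
    _lib = torch.library.Library("demo", "DEF")
    _lib.define("add_one(Tensor x) -> Tensor")


    def _add_one(x: torch.Tensor) -> torch.Tensor:
        return x + 1


    _lib.impl("add_one", _add_one, "CPU")


    @register_fake("demo::add_one")
    def _add_one_fake(x: torch.Tensor) -> torch.Tensor:
        # Fake (meta) implementation: describes shapes/dtypes, computes nothing.
        return torch.empty_like(x)

With either import path the decorator is used identically, which is why the rest of the diff can simply swap @torch.library.register_fake(...) for @register_fake(...). The TYPE_CHECKING stub added below presumably exists so static type checkers see a consistent decorator-factory signature regardless of the installed PyTorch.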
@@ -1,11 +1,14 @@
 import os
 
+import pytest
 import torch
 
 from tests.kernels.utils import opcheck
 from vllm import _custom_ops as ops  # noqa: F401
 
 
+@pytest.mark.skipif(not hasattr(torch.ops._C, "awq_dequantize"),
+                    reason="AWQ is not supported on this GPU type.")
 def test_awq_dequantize_opcheck():
     os.environ["VLLM_USE_TRITON_AWQ"] = "0"
     qweight = torch.randint(-2000000000,
@@ -21,6 +24,8 @@ def test_awq_dequantize_opcheck():
             (qweight, scales, zeros, split_k_iters, thx, thy))
 
 
+@pytest.mark.skipif(not hasattr(torch.ops._C, "awq_gemm"),
+                    reason="AWQ is not supported on this GPU type.")
 def test_awq_gemm_opcheck():
     os.environ["VLLM_USE_TRITON_AWQ"] = "0"
     input = torch.rand((2, 8192), device='cuda', dtype=torch.float16)
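The new skip markers gate only on whether the compiled extension actually registered the op for the current platform. A small, self-contained sketch of that guard (no vLLM build required; without the compiled _C extension the test simply skips — the op name mirrors the diff, the assertion body is illustrative):

    import pytest
    import torch

    # In vLLM, importing vllm._custom_ops loads the compiled _C extension when
    # available; torch.ops._C then exposes only the kernels that were built
    # for this GPU/platform, so hasattr() doubles as a capability probe.
    requires_awq = pytest.mark.skipif(
        not hasattr(torch.ops._C, "awq_gemm"),
        reason="AWQ is not supported on this GPU type.")


    @requires_awq
    def test_awq_gemm_is_registered():
        # Only runs when the custom op exists and can be looked up by name.
        assert callable(torch.ops._C.awq_gemm)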
@@ -7,6 +7,7 @@ import torch
 from tests.kernels.utils import (compute_max_diff, stack_and_dev, torch_moe,
                                  torch_moe_single)
+from vllm import _custom_ops as ops
 from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
     fused_marlin_moe, single_marlin_moe)
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
@@ -21,6 +22,9 @@ from vllm.scalar_type import scalar_types
 @pytest.mark.parametrize("e", [8, 64])
 @pytest.mark.parametrize("topk", [2, 6])
 @pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
+@pytest.mark.skipif(not (ops.supports_moe_ops
+                         and hasattr(torch.ops._moe_C, "marlin_gemm_moe")),
+                    reason="Marlin is not supported on this GPU type.")
 def test_fused_marlin_moe_awq(
     m: int,
     n: int,
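The ops.supports_moe_ops flag referenced by the new skip marker is the module-level boolean that the custom-ops module sets while optionally importing the MoE extension, as the context lines of a later hunk show. A short sketch of how that flag is derived (the False initialization is an assumption; the suppress block matches the hunk below):

    import contextlib

    # The flag flips to True only if the optional MoE extension imports cleanly.
    supports_moe_ops = False
    with contextlib.suppress(ImportError):
        import vllm._moe_C  # noqa: F401
        supports_moe_ops = True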
@@ -1,8 +1,9 @@
 import contextlib
 import functools
-from typing import List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
 import torch
+import torch.library
 
 import vllm.envs as envs
 from vllm._core_ext import ScalarType
@@ -25,6 +26,16 @@ with contextlib.suppress(ImportError):
     import vllm._moe_C  # noqa: F401
     supports_moe_ops = True
 
+if TYPE_CHECKING:
+
+    def register_fake(fn):
+        return lambda name: fn
+else:
+    try:
+        from torch.library import register_fake
+    except ImportError:
+        from torch.library import impl_abstract as register_fake
+
 
 def hint_on_error(fn):
 
@@ -266,7 +277,7 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
 
 if hasattr(torch.ops._C, "gptq_gemm"):
 
-    @torch.library.register_fake("_C::gptq_gemm")
+    @register_fake("_C::gptq_gemm")
     def _gptq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                         b_gptq_qzeros: torch.Tensor,
                         b_gptq_scales: torch.Tensor, b_g_idx: torch.Tensor,
@@ -301,7 +312,7 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
 
 if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
 
-    @torch.library.register_fake("_C::gptq_marlin_24_gemm")
+    @register_fake("_C::gptq_marlin_24_gemm")
     def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                                   b_meta: torch.Tensor, b_scales: torch.Tensor,
                                   workspace: torch.Tensor,
@@ -309,7 +320,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
                                   size_n: int, size_k: int) -> torch.Tensor:
         return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
 
-    @torch.library.register_fake("_C::gptq_marlin_gemm")
+    @register_fake("_C::gptq_marlin_gemm")
     def _gptq_marlin_gemm_fake(a: torch.Tensor,
                                b_q_weight: torch.Tensor,
                                b_scales: torch.Tensor,
@@ -326,12 +337,12 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
                                use_fp32_reduce: bool = False) -> torch.Tensor:
         return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
 
-    @torch.library.register_fake("_C::ggml_dequantize")
+    @register_fake("_C::ggml_dequantize")
     def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int, m: int,
                               n: int) -> torch.Tensor:
         return torch.empty((m, n), dtype=torch.float16, device=W.device)
 
-    @torch.library.register_fake("_C::ggml_mul_mat_vec_a8")
+    @register_fake("_C::ggml_mul_mat_vec_a8")
     def _ggml_mul_mat_vec_a8_fake(
         W: torch.Tensor,
         X: torch.Tensor,
@@ -340,7 +351,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
     ) -> torch.Tensor:
         return torch.empty((1, row), dtype=torch.float16, device=W.device)
 
-    @torch.library.register_fake("_C::ggml_mul_mat_a8")
+    @register_fake("_C::ggml_mul_mat_a8")
     def _ggml_mul_mat_a8_fake(
         W: torch.Tensor,
         X: torch.Tensor,
@@ -350,7 +361,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
         batch = X.size(0)
         return torch.empty((batch, row), dtype=torch.float16, device=W.device)
 
-    @torch.library.register_fake("_C::marlin_qqq_gemm")
+    @register_fake("_C::marlin_qqq_gemm")
     def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                               s_tok: torch.Tensor, s_ch: torch.Tensor,
                               s_group: torch.Tensor, workspace: torch.Tensor,
@@ -360,7 +371,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
                            dtype=torch.float16,
                            device=a.device)
 
-    @torch.library.register_fake("_C::marlin_gemm")
+    @register_fake("_C::marlin_gemm")
     def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                           b_scales: torch.Tensor, workspace: torch.Tensor,
                           size_m: int, size_n: int,
@@ -369,7 +380,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
                            dtype=torch.float16,
                            device=a.device)
 
-    @torch.library.register_fake("_C::awq_dequantize")
+    @register_fake("_C::awq_dequantize")
     def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor,
                              zeros: torch.Tensor, split_k_iters: int, thx: int,
                              thy: int) -> torch.Tensor:
@@ -380,7 +391,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
                            dtype=scales.dtype,
                            device=scales.device)
 
-    @torch.library.register_fake("_C::awq_gemm")
+    @register_fake("_C::awq_gemm")
     def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor,
                        qzeros: torch.Tensor, scales: torch.Tensor,
                        split_k_iters: int) -> torch.Tensor:
@@ -389,7 +400,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
                            dtype=input.dtype,
                            device=input.device).sum(0)
 
-    @torch.library.register_fake("_C::aqlm_gemm")
+    @register_fake("_C::aqlm_gemm")
     def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor,
                         codebooks: torch.Tensor, scales: torch.Tensor,
                         codebook_partition_sizes: List[int],
@@ -405,7 +416,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
             output_sizes.append(-1)
         return flat_output.reshape(tuple(output_sizes))
 
-    @torch.library.register_fake("_C::aqlm_dequant")
+    @register_fake("_C::aqlm_dequant")
     def _aqlm_dequant_fake(
             codes: torch.Tensor, codebooks: torch.Tensor,
             codebook_partition_sizes: List[int]) -> torch.Tensor:
@@ -415,14 +426,14 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
                            dtype=codebooks.dtype,
                            device=codebooks.device)
 
-    @torch.library.register_fake("_C::fp8_marlin_gemm")
+    @register_fake("_C::fp8_marlin_gemm")
     def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                               b_scales: torch.Tensor, workspace: torch.Tensor,
                               num_bits: int, size_m: int, size_n: int,
                               size_k: int) -> torch.Tensor:
         return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device)
 
-    @torch.library.register_fake("_C::machete_gemm")
+    @register_fake("_C::machete_gemm")
     def machete_gemm_fake(
         a: torch.Tensor,
         # Should be the tensor returned by machete_prepack_B
@@ -440,13 +451,13 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
         n = b_q.size(1)
         return torch.empty((m, n), device=a.device, dtype=a.dtype)
 
-    @torch.library.register_fake("_C::machete_prepack_B")
+    @register_fake("_C::machete_prepack_B")
     def machete_prepack_B_fake(b_q_weight: torch.Tensor,
                                b_type: ScalarType) -> torch.Tensor:
         return torch.empty_like(b_q_weight,
                                 memory_format=torch.contiguous_format)
 
-    @torch.library.register_fake("_C::causal_conv1d_fwd")
+    @register_fake("_C::causal_conv1d_fwd")
     def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor,
                                bias_: Optional[torch.Tensor],
                                conv_states: Optional[torch.Tensor],
@@ -456,7 +467,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
                                silu_activation: bool) -> torch.Tensor:
         return torch.empty_like(x)
 
-    @torch.library.register_fake("_C::causal_conv1d_update")
+    @register_fake("_C::causal_conv1d_update")
     def causal_conv1d_update_fake(
             x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor,
             bias_: Optional[torch.Tensor], silu_activation: bool,
@@ -464,7 +475,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
             conv_state_indices: Optional[torch.Tensor]) -> torch.Tensor:
         return torch.empty_like(x)
 
-    @torch.library.register_fake("_C::selective_scan_fwd")
+    @register_fake("_C::selective_scan_fwd")
     def selective_scan_fwd_fake(u: torch.Tensor, delta: torch.Tensor,
                                 A: torch.Tensor, B: torch.Tensor,
                                 C: torch.Tensor, D_: Optional[torch.Tensor],
@@ -639,7 +650,7 @@ def machete_prepack_B(b_q_weight: torch.Tensor,
 
 if hasattr(torch.ops._C, "permute_cols"):
 
-    @torch.library.register_fake("_C::permute_cols")
+    @register_fake("_C::permute_cols")
     def _permute_cols_fake(a: torch.Tensor,
                            perm: torch.Tensor) -> torch.Tensor:
         return torch.empty_like(a)
@@ -837,7 +848,7 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
 
 if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):
 
-    @torch.library.register_fake("_moe_C::marlin_gemm_moe")
+    @register_fake("_moe_C::marlin_gemm_moe")
     def marlin_gemm_moe_fake(a: torch.Tensor, b_q_weights: torch.Tensor,
                              sorted_ids: torch.Tensor,
                              topk_weights: torch.Tensor,