[Bugfix] Allow vllm to still work if triton is not installed. (#6786)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
commit 9a7e2d0534
parent 7f8d612d24
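
This change removes vllm's hard dependency on triton and replaces it with a runtime check: a new vllm/triton_utils package probes whether triton is importable and exposes a HAS_TRITON flag, and every triton-backed import elsewhere is gated behind that flag. A condensed, self-contained sketch of the pattern (it mirrors the new vllm/triton_utils/importing.py added below; the else branch here is illustrative only):

    from importlib.util import find_spec

    # Probe for triton without importing it, as the new helper module does.
    HAS_TRITON = find_spec("triton") is not None

    if HAS_TRITON:
        import triton  # triton-only code paths run only when it is installed
    else:
        print("triton not found; triton-backed kernels will be skipped")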
@@ -4,4 +4,3 @@
 # Dependencies for x86_64 CPUs
 torch == 2.4.0; platform_machine != "ppc64le"
 torchvision; platform_machine != "ppc64le" # required for the image processor of phi3v, this must be updated alongside torch
-triton >= 2.2.0 # FIXME(woosuk): This is a hack to avoid import error.

@@ -5,5 +5,3 @@
 torch >= 2.1.2
 openvino ~= 2024.3.0.dev
 optimum-intel[openvino] >= 1.18.1
-
-triton >= 2.2.0 # FIXME(woosuk): This is a hack to avoid import error.

@@ -5,4 +5,3 @@
 # Currently, the TPU backend uses a nightly version of PyTorch XLA.
 # You can install the dependencies in Dockerfile.tpu.
 ray
-triton # To avoid import errors

@@ -5,11 +5,12 @@ import torch
 import triton
 import triton.language as tl
 
-from vllm.model_executor.layers.ops.sample import (
-    MAX_TRITON_N_COLS, _uniform_to_exponential, get_num_triton_sampler_splits,
-    sample)
+from vllm.model_executor.layers.ops.sample import (_uniform_to_exponential,
+                                                   sample)
 from vllm.model_executor.sampling_metadata import SamplingTensors
 from vllm.model_executor.utils import set_random_seed
+from vllm.triton_utils.sample import (MAX_TRITON_N_COLS,
+                                      get_num_triton_sampler_splits)
 
 SINGLE_SPLIT_VOCAB_SIZE = 32000 # llama/mistral/mixtral vocab size
 MULTI_SPLIT_VOCAB_SIZE = MAX_TRITON_N_COLS + 100

@@ -4,7 +4,10 @@ from typing import List, Optional, Tuple
 import torch
 
 from vllm import _custom_ops as ops
-from vllm.attention.ops.prefix_prefill import context_attention_fwd
+from vllm.triton_utils import HAS_TRITON
+
+if HAS_TRITON:
+    from vllm.attention.ops.prefix_prefill import context_attention_fwd
 
 # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
 _PARTITION_SIZE = 512

@@ -1,14 +1,22 @@
-from vllm.model_executor.layers.fused_moe.fused_moe import (
-    fused_experts, fused_moe, fused_topk, get_config_file_name, grouped_topk)
 from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
                                                         FusedMoEMethodBase)
+from vllm.triton_utils import HAS_TRITON
 
 __all__ = [
-    "fused_moe",
-    "fused_topk",
-    "fused_experts",
-    "get_config_file_name",
-    "grouped_topk",
     "FusedMoE",
     "FusedMoEMethodBase",
 ]
+
+if HAS_TRITON:
+
+    from vllm.model_executor.layers.fused_moe.fused_moe import (
+        fused_experts, fused_moe, fused_topk, get_config_file_name,
+        grouped_topk)
+
+    __all__ += [
+        "fused_moe",
+        "fused_topk",
+        "fused_experts",
+        "get_config_file_name",
+        "grouped_topk",
+    ]

@@ -1,4 +1,3 @@
-import math
 from typing import Optional, Tuple
 
 import torch

@@ -6,21 +5,10 @@ import triton
 import triton.language as tl
 
 from vllm.model_executor.layers.ops.rand import seeded_uniform
+from vllm.triton_utils.sample import get_num_triton_sampler_splits
 
 _EPS = 1e-6
-
-# This is a hardcoded limit in Triton (max block size).
-MAX_TRITON_N_COLS = 131072
-
-
-def get_num_triton_sampler_splits(n_cols: int) -> int:
-    """Get the number of splits to use for Triton sampling.
-
-    Triton has a limit on the number of columns it can handle, so we need to
-    split the tensor and call the kernel multiple times if it's too large.
-    """
-    return math.ceil(n_cols / MAX_TRITON_N_COLS)
 
 
 def _multi_split_sample(
     probs: torch.Tensor,

@@ -6,8 +6,7 @@ from torch.nn.parameter import Parameter
 
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
-                                                  fused_moe)
+from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization.base_config import (

@@ -404,6 +403,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
               num_expert_group: Optional[int] = None,
               topk_group: Optional[int] = None) -> torch.Tensor:
 
+        from vllm.model_executor.layers.fused_moe import fused_moe
         return fused_moe(x,
                          layer.w13_weight,
                          layer.w2_weight,

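The second hunk above defers the fused_moe import to the point of use, so importing the fp8 quantization module no longer pulls in triton at import time. A minimal sketch of that deferred-import pattern, with json standing in for the heavy optional dependency (names here are illustrative, not vllm APIs):

    def apply_kernel(payload: dict) -> str:
        # Deferred import: the dependency is resolved only when the method
        # actually runs, so importing the defining module cannot fail early.
        import json
        return json.dumps(payload)
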
@@ -6,7 +6,11 @@ from typing import Dict, List, Optional, Tuple
 import torch
 import torch.nn as nn
 
-from vllm.model_executor.layers.ops.sample import sample as sample_triton
+from vllm.triton_utils import HAS_TRITON
+
+if HAS_TRITON:
+    from vllm.model_executor.layers.ops.sample import sample as sample_triton
+
 from vllm.model_executor.sampling_metadata import (SamplingMetadata,
                                                    SamplingTensors,
                                                    SequenceGroupToSample)

@@ -5,9 +5,9 @@ from typing import Dict, List, Optional, Tuple
 
 import torch
 
-from vllm.model_executor.layers.ops.sample import get_num_triton_sampler_splits
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import SequenceData, SequenceGroupMetadata
+from vllm.triton_utils.sample import get_num_triton_sampler_splits
 from vllm.utils import (async_tensor_h2d, is_pin_memory_available,
                         make_tensor_with_pad, maybe_expand_dim)
 

@@ -1,6 +1,10 @@
-from vllm.triton_utils.custom_cache_manager import (
-    maybe_set_triton_cache_manager)
+from vllm.triton_utils.importing import HAS_TRITON
 
-__all__ = [
-    "maybe_set_triton_cache_manager",
-]
+__all__ = ["HAS_TRITON"]
+
+if HAS_TRITON:
+
+    from vllm.triton_utils.custom_cache_manager import (
+        maybe_set_triton_cache_manager)
+
+    __all__ += ["maybe_set_triton_cache_manager"]

vllm/triton_utils/importing.py (new file, 11 lines)
@@ -0,0 +1,11 @@
+from importlib.util import find_spec
+
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+HAS_TRITON = find_spec("triton") is not None
+
+if not HAS_TRITON:
+    logger.info("Triton not installed; certain GPU-related functions"
+                " will not be available.")

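With this helper in place, callers can inspect the flag directly; a quick check in an environment without triton (assuming vllm itself imports cleanly, which is the point of this patch):

    from vllm.triton_utils import HAS_TRITON

    print(HAS_TRITON)  # False when triton is absent; the info log above fires once on import
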
vllm/triton_utils/sample.py (new file, 13 lines)
@@ -0,0 +1,13 @@
+import math
+
+# This is a hardcoded limit in Triton (max block size).
+MAX_TRITON_N_COLS = 131072
+
+
+def get_num_triton_sampler_splits(n_cols: int) -> int:
+    """Get the number of splits to use for Triton sampling.
+
+    Triton has a limit on the number of columns it can handle, so we need to
+    split the tensor and call the kernel multiple times if it's too large.
+    """
+    return math.ceil(n_cols / MAX_TRITON_N_COLS)

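The split count is a ceiling division against the Triton block-size limit. For the two vocab sizes exercised by the updated test above, a quick check (plain math, no vllm needed):

    import math

    MAX_TRITON_N_COLS = 131072

    # SINGLE_SPLIT_VOCAB_SIZE = 32000 fits in one Triton call.
    print(math.ceil(32000 / MAX_TRITON_N_COLS))  # 1
    # MULTI_SPLIT_VOCAB_SIZE = MAX_TRITON_N_COLS + 100 needs two calls.
    print(math.ceil((MAX_TRITON_N_COLS + 100) / MAX_TRITON_N_COLS))  # 2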