[Misc] Use vllm-flash-attn instead of flash-attn (#4686)
parent 230c4b38c1 · commit 89579a201f
Dockerfile · 21 lines changed

@@ -87,23 +87,6 @@ RUN --mount=type=cache,target=/root/.cache/pip \
     pip cache remove vllm_nccl*
 #################### EXTENSION Build IMAGE ####################
-
-#################### FLASH_ATTENTION Build IMAGE ####################
-FROM dev as flash-attn-builder
-# max jobs used for build
-ARG max_jobs=2
-ENV MAX_JOBS=${max_jobs}
-# flash attention version
-ARG flash_attn_version=v2.5.8
-ENV FLASH_ATTN_VERSION=${flash_attn_version}
-
-WORKDIR /usr/src/flash-attention-v2
-
-# Download the wheel or build it if a pre-compiled release doesn't exist
-RUN pip --verbose wheel flash-attn==${FLASH_ATTN_VERSION} \
-    --no-build-isolation --no-deps --no-cache-dir
-
-#################### FLASH_ATTENTION Build IMAGE ####################
 
 #################### vLLM installation IMAGE ####################
 # image with vLLM installed
 FROM nvidia/cuda:12.4.1-base-ubuntu22.04 AS vllm-base

@@ -122,10 +105,6 @@ RUN ldconfig /usr/local/cuda-12.4/compat/
 RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \
     --mount=type=cache,target=/root/.cache/pip \
     pip install dist/*.whl --verbose
 
-RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \
-    --mount=type=cache,target=/root/.cache/pip \
-    pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir
 #################### vLLM installation IMAGE ####################
 
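With the flash-attn-builder stage gone, the final vllm-base image gets its FlashAttention kernels from the prebuilt vllm-flash-attn wheel pinned in requirements-cuda.txt (next hunk) instead of a wheel compiled during the Docker build. A quick sanity check one could run inside the built image, a sketch of mine rather than anything this commit adds; it relies only on the import that the diff itself introduces:

# sanity_check_flash_attn.py, a hypothetical helper that is not part of the repo.
# It confirms the pip-installed vllm-flash-attn wheel exposes the kernel entry
# point that vLLM's attention backends import after this change.
from vllm_flash_attn import flash_attn_varlen_func

print("vllm-flash-attn import OK:", callable(flash_attn_varlen_func))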
requirements-cuda.txt

@@ -7,3 +7,4 @@ nvidia-ml-py # for pynvml package
 vllm-nccl-cu12>=2.18,<2.19 # for downloading nccl library
 torch == 2.3.0
 xformers == 0.0.26.post1 # Requires PyTorch 2.3.0
+vllm-flash-attn == 2.5.8.post1 # Requires PyTorch 2.3.0
setup.py · 12 lines changed

@@ -355,13 +355,17 @@ def get_requirements() -> List[str]:
 
     if _is_cuda():
         requirements = _read_requirements("requirements-cuda.txt")
-        cuda_major = torch.version.cuda.split(".")[0]
+        cuda_major, cuda_minor = torch.version.cuda.split(".")
         modified_requirements = []
         for req in requirements:
             if "vllm-nccl-cu12" in req:
-                modified_requirements.append(
-                    req.replace("vllm-nccl-cu12", f"vllm-nccl-cu{cuda_major}"))
-            else:
-                modified_requirements.append(req)
+                req = req.replace("vllm-nccl-cu12",
+                                  f"vllm-nccl-cu{cuda_major}")
+            elif ("vllm-flash-attn" in req
+                  and not (cuda_major == "12" and cuda_minor == "1")):
+                # vllm-flash-attn is built only for CUDA 12.1.
+                # Skip for other versions.
+                continue
+            modified_requirements.append(req)
         requirements = modified_requirements
     elif _is_hip():
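In plain terms, get_requirements() now does two things to the CUDA requirement list: it retargets vllm-nccl-cu12 at the detected CUDA major version, and it drops the vllm-flash-attn pin entirely unless the build sees CUDA 12.1, since the wheel is only published for that version. A standalone restatement of that loop, with an illustrative function name that does not exist in setup.py:

# Illustrative sketch only; this helper is not in setup.py, it restates the hunk above.
from typing import List

def adjust_cuda_requirements(requirements: List[str], cuda_version: str) -> List[str]:
    cuda_major, cuda_minor = cuda_version.split(".")[:2]
    adjusted = []
    for req in requirements:
        if "vllm-nccl-cu12" in req:
            # Retarget the NCCL helper package at the detected CUDA major version.
            req = req.replace("vllm-nccl-cu12", f"vllm-nccl-cu{cuda_major}")
        elif "vllm-flash-attn" in req and not (cuda_major == "12" and cuda_minor == "1"):
            # vllm-flash-attn wheels exist only for CUDA 12.1; drop the pin elsewhere.
            continue
        adjusted.append(req)
    return adjusted

# Example: adjust_cuda_requirements(["vllm-flash-attn == 2.5.8.post1"], "11.8") returns [].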
vllm/attention/backends/flash_attn.py

@@ -8,7 +8,7 @@ from dataclasses import dataclass
 from typing import List, Optional, Tuple, Type
 
 import torch
-from flash_attn import flash_attn_varlen_func
+from vllm_flash_attn import flash_attn_varlen_func
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata,
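The backend keeps calling flash_attn_varlen_func exactly as before; only the providing package changes. For orientation, a usage sketch of the varlen entry point, assuming vllm_flash_attn keeps the upstream flash-attn 2.5.x signature (the diff only shows the import moving) and that a CUDA GPU with fp16 tensors is available:

# Hedged usage sketch, not taken from vLLM. Two variable-length sequences are packed
# into one tensor and described by cumulative sequence-length offsets.
import torch
from vllm_flash_attn import flash_attn_varlen_func

num_heads, head_dim = 8, 64
seq_lens = [5, 3]
total = sum(seq_lens)
q = torch.randn(total, num_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)
cu_seqlens = torch.tensor([0, 5, 8], dtype=torch.int32, device="cuda")

out = flash_attn_varlen_func(q, k, v,
                             cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
                             max_seqlen_q=max(seq_lens), max_seqlen_k=max(seq_lens),
                             causal=True)
print(out.shape)  # (total, num_heads, head_dim)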
vllm/attention/backends/flashinfer.py

@@ -3,8 +3,8 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Type
 
 import flashinfer
 import torch
-from flash_attn import flash_attn_varlen_func
 from flashinfer import BatchDecodeWithPagedKVCacheWrapper
+from vllm_flash_attn import flash_attn_varlen_func
 
 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
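Both attention backends now import the kernel from the renamed package. Code outside vLLM that has to tolerate either package name, for example a plugin written before this change, could use a fallback import; this is my own compatibility sketch, not something the commit adds:

# Fallback import shim (not part of this commit).
try:
    # The wheel vLLM now depends on; see requirements-cuda.txt above.
    from vllm_flash_attn import flash_attn_varlen_func
except ImportError:
    # Older environments that still ship the upstream flash-attn package.
    from flash_attn import flash_attn_varlen_func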
vllm/attention/selector.py

@@ -76,11 +76,12 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
         return _Backend.XFORMERS
 
     try:
-        import flash_attn  # noqa: F401
+        import vllm_flash_attn  # noqa: F401
     except ImportError:
         logger.info(
-            "Cannot use FlashAttention-2 backend because the flash_attn "
-            "package is not found. Please install it for better performance.")
+            "Cannot use FlashAttention-2 backend because the vllm_flash_attn "
+            "package is not found. `pip install vllm-flash-attn` for better "
+            "performance.")
         return _Backend.XFORMERS
 
     backend_by_env_var = envs.VLLM_ATTENTION_BACKEND
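The selector's behaviour is unchanged apart from the package name: if vllm_flash_attn cannot be imported, it logs the hint above and falls back to the XFORMERS backend. A small standalone check that mirrors this fallback, handy in a deployment script (a hypothetical helper, not part of vLLM):

# availability_check.py, hypothetical; mirrors the selector's try/except above.
import importlib.util

def vllm_flash_attn_available() -> bool:
    # find_spec avoids importing the CUDA extension just to test for presence.
    return importlib.util.find_spec("vllm_flash_attn") is not None

if not vllm_flash_attn_available():
    print("vllm-flash-attn is missing; vLLM will fall back to the XFORMERS backend. "
          "Run `pip install vllm-flash-attn` to enable FlashAttention-2.")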