Update training Dockerfile to use flash-attn==0.2.6
commit 984d5204e2
parent 029617179f
@@ -2,7 +2,7 @@
 # ARG COMPAT=0
 ARG PERSONAL=0
 # FROM nvidia/cuda:11.3.1-devel-ubuntu20.04 as base-0
-FROM nvcr.io/nvidia/pytorch:22.11-py3 as base
+FROM nvcr.io/nvidia/pytorch:22.12-py3 as base
 
 ENV HOST docker
 ENV LANG=C.UTF-8 LC_ALL=C.UTF-8
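The base image bump from 22.11 to 22.12 also changes the PyTorch and CUDA builds that every later pin is compiled against. A minimal sketch of a build-time sanity check, not part of this commit, that prints what the new base actually ships before the project's own pins are layered on top:

# Hypothetical check (not in the diff): show the torch build and bundled CUDA
# toolkit version in nvcr.io/nvidia/pytorch:22.12-py3.
RUN python -c "import torch; print(torch.__version__, torch.version.cuda)"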
@@ -67,30 +67,29 @@ ENV PIP_NO_CACHE_DIR=1
 # RUN pip install --upgrade --force-reinstall --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_multihead_attn" --global-option="--fmha" --global-option="--fast_layer_norm" --global-option="--xentropy" git+https://github.com/NVIDIA/apex.git#egg=apex
 
 # xgboost conflicts with deepspeed
-RUN pip uninstall -y xgboost && DS_BUILD_UTILS=1 DS_BUILD_FUSED_LAMB=1 pip install deepspeed==0.7.5
+RUN pip uninstall -y xgboost && DS_BUILD_UTILS=1 DS_BUILD_FUSED_LAMB=1 pip install deepspeed==0.7.7
 
 # General packages that we don't care about the version
 # zstandard to extract the_pile dataset
 # psutil to get the number of cpu physical cores
 # twine to upload package to PyPI
 # ninja is broken for some reason, it returns error code 245
 RUN pip uninstall -y ninja && pip install ninja
-RUN pip install pytest matplotlib jupyter ipython ipdb gpustat scikit-learn spacy munch einops opt_einsum fvcore gsutil cmake pykeops zstandard psutil h5py twine \
+RUN pip install pytest matplotlib jupyter ipython ipdb gpustat scikit-learn spacy munch einops opt_einsum fvcore gsutil cmake pykeops zstandard psutil h5py twine gdown \
     && python -m spacy download en_core_web_sm
 # hydra
-RUN pip install hydra-core==1.2.0 hydra-colorlog==1.2.0 hydra-optuna-sweeper==1.2.0 pyrootutils rich
+RUN pip install hydra-core==1.3.1 hydra-colorlog==1.2.0 hydra-optuna-sweeper==1.2.0 pyrootutils rich
 # Core packages
-RUN pip install transformers==4.24.0 datasets==2.7.1 pytorch-lightning==1.7.7 triton==2.0.0.dev20221120 wandb==0.13.5 timm==0.6.12 torchmetrics==0.10.3
+RUN pip install transformers==4.25.1 datasets==2.8.0 pytorch-lightning==1.8.6 triton==2.0.0.dev20221202 wandb==0.13.7 timm==0.6.12 torchmetrics==0.11.0
 
 # For MLPerf
 RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0
 
 # Install FlashAttention
-RUN pip install flash-attn==0.2.2
+RUN pip install flash-attn==0.2.6.post
 
 # Install CUDA extensions for cross-entropy, fused dense, layer norm
 RUN git clone https://github.com/HazyResearch/flash-attention \
-    && cd flash-attention && git checkout v0.2.2 \
+    && cd flash-attention && git checkout v0.2.6 \
     && cd csrc/fused_softmax && pip install . && cd ../../ \
     && cd csrc/rotary && pip install . && cd ../../ \
     && cd csrc/xentropy && pip install . && cd ../../ \
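Several pins move in the same commit (deepspeed, hydra-core, transformers, datasets, pytorch-lightning, triton, wandb, torchmetrics, flash-attn), so a final verification layer after these installs is one way to catch a bad resolution at build time rather than at first training run. A minimal sketch under assumptions not in the diff: that such a layer is wanted at all, and that DeepSpeed's ds_report script is on PATH in this stage.

# Hypothetical verification layer (not part of this commit): print the resolved
# versions of the moving pins, fail fast if the flash-attn wheel does not import,
# and let ds_report list which DeepSpeed ops the DS_BUILD_* flags pre-compiled.
RUN python -c "import importlib.metadata as md; print([(p, md.version(p)) for p in ['deepspeed', 'transformers', 'datasets', 'pytorch-lightning', 'triton', 'wandb', 'torchmetrics', 'flash-attn']])" \
    && python -c "import flash_attn" \
    && ds_report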