Update pip install instructions, bump to 0.2

This commit is contained in:
Tri Dao 2022-11-15 14:10:36 -08:00
parent 56aa49037d
commit 4040256b5e
2 changed files with 10 additions and 5 deletions

View File

@ -24,9 +24,14 @@ and experiment with. The notations in the Triton implementation are also closer
to what's used in our paper. to what's used in our paper.
## Alpha release (0.1). ## Beta release (0.2).
To compile (requiring CUDA 11, NVCC, and an Turing or Ampere GPU): To install (requiring CUDA 11, NVCC, and an Turing or Ampere GPU):
```sh
pip install flash-attn
```
Alternatively you can compile from source:
``` ```
python setup.py install python setup.py install
``` ```
@ -44,7 +49,7 @@ FlashAttention currently supports:
3. Head dimensions that are multiples of 8, up to 128 (e.g., 8, 16, 24, ..., 128). Head dim > 64 backward requires A100. 3. Head dimensions that are multiples of 8, up to 128 (e.g., 8, 16, 24, ..., 128). Head dim > 64 backward requires A100.
Our tentative roadmap: Our tentative roadmap:
1. [Jun 2022] Make package pip-installable. 1. ~~[Jun 2022] Make package pip-installable~~[Done, thanks to lucidrains].
2. ~~[Jun 2022] Support SM86 GPUs (e.g., RTX 3080, 3090)~~[Done]. 2. ~~[Jun 2022] Support SM86 GPUs (e.g., RTX 3080, 3090)~~[Done].
3. [Jun 2022] Refactor to use Cutlass. 3. [Jun 2022] Refactor to use Cutlass.
4. ~~[Jun 2022] Support SM75 GPUs (e.g. T4)~~[Done]. 4. ~~[Jun 2022] Support SM75 GPUs (e.g. T4)~~[Done].

View File

@ -152,7 +152,7 @@ ext_modules.append(
setup( setup(
name="flash_attn", name="flash_attn",
version="0.1", version="0.2",
packages=find_packages( packages=find_packages(
exclude=("build", "csrc", "include", "tests", "dist", "docs", "benchmarks", "flash_attn.egg-info",) exclude=("build", "csrc", "include", "tests", "dist", "docs", "benchmarks", "flash_attn.egg-info",)
), ),
@ -164,7 +164,7 @@ setup(
url="https://github.com/HazyResearch/flash-attention", url="https://github.com/HazyResearch/flash-attention",
classifiers=[ classifiers=[
"Programming Language :: Python :: 3", "Programming Language :: Python :: 3",
"License :: OSI Approved :: Apache Software License", "License :: OSI Approved :: BSD License",
"Operating System :: Unix", "Operating System :: Unix",
], ],
ext_modules=ext_modules, ext_modules=ext_modules,