Merge branch 'HazyResearch:main' into enable_cuda_graph_capture

commit 315fd31f0c
Kirthi Shankar Sivamani, 2023-04-12 22:42:24 -07:00 (committed by GitHub)
5 changed files with 12 additions and 8 deletions


@@ -38,7 +38,7 @@ and experiment with. The notations in the Triton implementation are also closer
 to what's used in our paper.
-## Beta release (0.2).
+## Installation and features
 Requirements:
 - CUDA 11.4 and above.


@@ -122,7 +122,9 @@ int gemm_bias_act_lt(
     reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
 // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
 // setting this to 1M.
-size_t workspaceSize = 1024 * 1024;
+// However, Apex sets it to 4M and TransformerEngine sets to 32M for Hopper and 4M for other GPUs
+// https://github.com/NVIDIA/TransformerEngine/blob/a0f0065498bbcfc1da78cf9e8b166f5381613fbc/transformer_engine/pytorch/module.py#L91
+size_t workspaceSize = 1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 32 : 4);
 void* workspace = at::empty(
     {static_cast<int64_t>(workspaceSize)},
     at::device({at::kCUDA, at::cuda::current_device()}).dtype(at::kByte)).data_ptr();
@@ -296,7 +298,8 @@ int gemm_bgradb_lt(
     reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
 // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
 // setting this to 1M.
-size_t workspaceSize = 1024 * 1024;
+// However, Apex sets it to 4M and TransformerEngine sets to 32M for Hopper and 4M for other GPUs
+size_t workspaceSize = 1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 32 : 4);
 void* workspace = at::empty(
     {static_cast<int64_t>(workspaceSize)},
     at::device({at::kCUDA, at::cuda::current_device()}).dtype(at::kByte)).data_ptr();
@@ -449,7 +452,8 @@ int gemm_dact_bgradb_lt(
     reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
 // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
 // setting this to 1M.
-size_t workspaceSize = 1024 * 1024;
+// However, Apex sets it to 4M and TransformerEngine sets to 32M for Hopper and 4M for other GPUs
+size_t workspaceSize = 1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 32 : 4);
 void* workspace = at::empty(
     {static_cast<int64_t>(workspaceSize)},
     at::device({at::kCUDA, at::cuda::current_device()}).dtype(at::kByte)).data_ptr();

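For context, a minimal Python sketch (not part of this commit) of the same heuristic the updated extension code applies: read the device's compute-capability major version and size the cuBLASLt workspace at 32 MiB on Hopper (compute capability 9.x) or 4 MiB on earlier GPUs. The function name cublaslt_workspace_bytes is illustrative only.

    # Illustrative sketch only (assumes a CUDA-enabled PyTorch build).
    import torch

    def cublaslt_workspace_bytes(device_index: int = 0) -> int:
        # Hopper (compute capability 9.x) gets 32 MiB; earlier GPUs get 4 MiB,
        # mirroring the heuristic in the diff above.
        major = torch.cuda.get_device_properties(device_index).major
        return 1024 * 1024 * (32 if major >= 9 else 4)

    # The extension allocates the workspace as a byte tensor via at::empty;
    # the Python equivalent:
    workspace = torch.empty(cublaslt_workspace_bytes(), dtype=torch.uint8, device="cuda")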

@@ -17,7 +17,7 @@ class Mlp(nn.Module):
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
         out_features = out_features or in_features
-        hidden_features = hidden_features or in_features
+        hidden_features = hidden_features or in_features * 4
         self.return_residual = return_residual
         self.fc1 = nn.Linear(in_features, hidden_features, **factory_kwargs)
         self.activation = activation

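Usage note (a sketch, not part of this commit): with this default, constructing the module without an explicit hidden_features now gives the standard 4x MLP expansion. TinyMlp below is a hypothetical, stripped-down illustration, not the repo's Mlp class.

    import torch.nn as nn

    class TinyMlp(nn.Module):
        # Hypothetical minimal module showing the new default hidden width.
        def __init__(self, in_features, hidden_features=None, out_features=None):
            super().__init__()
            out_features = out_features or in_features
            hidden_features = hidden_features or in_features * 4  # was in_features before this change
            self.fc1 = nn.Linear(in_features, hidden_features)
            self.fc2 = nn.Linear(hidden_features, out_features)

    print(TinyMlp(1024).fc1)  # Linear(in_features=1024, out_features=4096, bias=True)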

@@ -162,7 +162,7 @@ ext_modules.append(
 setup(
     name="flash_attn",
-    version="0.2.8",
+    version="1.0.1",
     packages=find_packages(
         exclude=("build", "csrc", "include", "tests", "dist", "docs", "benchmarks", "flash_attn.egg-info",)
     ),


@@ -85,11 +85,11 @@ RUN pip install transformers==4.25.1 datasets==2.8.0 pytorch-lightning==1.8.6 tr
 RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0
 # Install FlashAttention
-RUN pip install flash-attn==0.2.8
+RUN pip install flash-attn==1.0.1
 # Install CUDA extensions for cross-entropy, fused dense, layer norm
 RUN git clone https://github.com/HazyResearch/flash-attention \
-    && cd flash-attention && git checkout v0.2.8 \
+    && cd flash-attention && git checkout v1.0.1 \
     && cd csrc/fused_softmax && pip install . && cd ../../ \
     && cd csrc/rotary && pip install . && cd ../../ \
     && cd csrc/xentropy && pip install . && cd ../../ \
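A quick post-install sanity check (a sketch; it relies only on the standard library's importlib.metadata) to confirm the image picked up the pinned release:

    from importlib.metadata import version

    # The distribution is registered as "flash_attn" by setup.py above.
    assert version("flash_attn") == "1.0.1", version("flash_attn")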