Merge branch 'HazyResearch:main' into enable_cuda_graph_capture
This commit is contained in:
commit
315fd31f0c
@ -38,7 +38,7 @@ and experiment with. The notations in the Triton implementation are also closer
|
||||
to what's used in our paper.
|
||||
|
||||
|
||||
## Beta release (0.2).
|
||||
## Installation and features
|
||||
|
||||
Requirements:
|
||||
- CUDA 11.4 and above.
|
||||
|
||||
@ -122,7 +122,9 @@ int gemm_bias_act_lt(
|
||||
reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
|
||||
// See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
|
||||
// setting this to 1M.
|
||||
size_t workspaceSize = 1024 * 1024;
|
||||
// However, Apex sets it to 4M and TransformerEngine sets to 32M for Hopper and 4M for other GPUs
|
||||
// https://github.com/NVIDIA/TransformerEngine/blob/a0f0065498bbcfc1da78cf9e8b166f5381613fbc/transformer_engine/pytorch/module.py#L91
|
||||
size_t workspaceSize = 1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 32 : 4);
|
||||
void* workspace = at::empty(
|
||||
{static_cast<int64_t>(workspaceSize)},
|
||||
at::device({at::kCUDA, at::cuda::current_device()}).dtype(at::kByte)).data_ptr();
|
||||
@ -296,7 +298,8 @@ int gemm_bgradb_lt(
|
||||
reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
|
||||
// See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
|
||||
// setting this to 1M.
|
||||
size_t workspaceSize = 1024 * 1024;
|
||||
// However, Apex sets it to 4M and TransformerEngine sets to 32M for Hopper and 4M for other GPUs
|
||||
size_t workspaceSize = 1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 32 : 4);
|
||||
void* workspace = at::empty(
|
||||
{static_cast<int64_t>(workspaceSize)},
|
||||
at::device({at::kCUDA, at::cuda::current_device()}).dtype(at::kByte)).data_ptr();
|
||||
@ -449,7 +452,8 @@ int gemm_dact_bgradb_lt(
|
||||
reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
|
||||
// See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
|
||||
// setting this to 1M.
|
||||
size_t workspaceSize = 1024 * 1024;
|
||||
// However, Apex sets it to 4M and TransformerEngine sets to 32M for Hopper and 4M for other GPUs
|
||||
size_t workspaceSize = 1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 32 : 4);
|
||||
void* workspace = at::empty(
|
||||
{static_cast<int64_t>(workspaceSize)},
|
||||
at::device({at::kCUDA, at::cuda::current_device()}).dtype(at::kByte)).data_ptr();
|
||||
|
||||
@ -17,7 +17,7 @@ class Mlp(nn.Module):
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
hidden_features = hidden_features or in_features * 4
|
||||
self.return_residual = return_residual
|
||||
self.fc1 = nn.Linear(in_features, hidden_features, **factory_kwargs)
|
||||
self.activation = activation
|
||||
|
||||
2
setup.py
2
setup.py
@ -162,7 +162,7 @@ ext_modules.append(
|
||||
|
||||
setup(
|
||||
name="flash_attn",
|
||||
version="0.2.8",
|
||||
version="1.0.1",
|
||||
packages=find_packages(
|
||||
exclude=("build", "csrc", "include", "tests", "dist", "docs", "benchmarks", "flash_attn.egg-info",)
|
||||
),
|
||||
|
||||
@ -85,11 +85,11 @@ RUN pip install transformers==4.25.1 datasets==2.8.0 pytorch-lightning==1.8.6 tr
|
||||
RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0
|
||||
|
||||
# Install FlashAttention
|
||||
RUN pip install flash-attn==0.2.8
|
||||
RUN pip install flash-attn==1.0.1
|
||||
|
||||
# Install CUDA extensions for cross-entropy, fused dense, layer norm
|
||||
RUN git clone https://github.com/HazyResearch/flash-attention \
|
||||
&& cd flash-attention && git checkout v0.2.8 \
|
||||
&& cd flash-attention && git checkout v1.0.1 \
|
||||
&& cd csrc/fused_softmax && pip install . && cd ../../ \
|
||||
&& cd csrc/rotary && pip install . && cd ../../ \
|
||||
&& cd csrc/xentropy && pip install . && cd ../../ \
|
||||
|
||||
Loading…
Reference in New Issue
Block a user