diff --git a/flash_attn/rotary.py b/flash_attn/layers/rotary.py
similarity index 100%
rename from flash_attn/rotary.py
rename to flash_attn/layers/rotary.py
diff --git a/flash_attn/modules/block.py b/flash_attn/modules/block.py
new file mode 100644
index 0000000..073ebc6
--- /dev/null
+++ b/flash_attn/modules/block.py
@@ -0,0 +1,129 @@
+# Copyright (c) 2022, Tri Dao.
+
+from typing import Optional
+from functools import partial
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+
+from torchvision.ops import StochasticDepth
+
+from flash_attn.modules.mha import MHA
+from flash_attn.modules.mlp import Mlp
+
+try:
+    from flash_attn.ops.layer_norm import dropout_add_layer_norm
+except ImportError:
+    dropout_add_layer_norm = None
+
+
+class Block(nn.Module):
+
+    def __init__(self, dim, mixer_cls=None, mlp_cls=None, norm_cls=nn.LayerNorm,
+                 dropout_cls=nn.Dropout, prenorm=True, resid_dropout=0., drop_path=0.,
+                 fused_dropout_add_ln=False):
+        super().__init__()
+        self.prenorm = prenorm
+        self.fused_dropout_add_ln = fused_dropout_add_ln
+        if mixer_cls is None:
+            mixer_cls = partial(MHA, num_heads=dim // 64)
+        if mlp_cls is None:
+            mlp_cls = partial(Mlp, hidden_features=4 * dim)
+        self.mixer = mixer_cls(dim)
+        self.dropout1 = dropout_cls(resid_dropout)
+        self.drop_path1 = StochasticDepth(drop_path, mode='row')
+        self.norm1 = norm_cls(dim)
+        self.mlp = mlp_cls(dim)
+        if not isinstance(self.mlp, nn.Identity):
+            self.dropout2 = dropout_cls(resid_dropout)
+            self.drop_path2 = StochasticDepth(drop_path, mode='row')
+            self.norm2 = norm_cls(dim)
+
+        if self.fused_dropout_add_ln:
+            assert dropout_add_layer_norm is not None, 'dropout_add_ln is not installed'
+            assert isinstance(self.norm1, nn.LayerNorm) and isinstance(self.dropout1, nn.Dropout)
+
+    def forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None,
+                mixer_kwargs=None):
+        r"""Pass the input through the encoder layer.
+
+        Args:
+            hidden_states: the sequence to the encoder layer (required).
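+            mixer_kwargs: optional dict of extra keyword arguments passed through to the mixer.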
+            residual: None if postnorm. If prenorm, the running residual stream, with
+                hidden_states = LayerNorm(residual).
+        """
+        if self.prenorm:
+            assert residual is not None
+            mixer_out = self.mixer(hidden_states,
+                                   **(mixer_kwargs if mixer_kwargs is not None else {}))
+            if not self.fused_dropout_add_ln:
+                residual = self.drop_path1(self.dropout1(mixer_out)) + residual
+                hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
+            else:
+                # Stochastic depth is folded into the fused kernel as a per-row scale.
+                if self.drop_path1.p == 0 or not self.training:
+                    rowscale1 = None
+                else:
+                    rowscale1 = self.drop_path1(torch.ones(
+                        mixer_out.shape[:-1], device=mixer_out.device, dtype=mixer_out.dtype)
+                    )
+                hidden_states, residual = dropout_add_layer_norm(
+                    mixer_out, residual, self.norm1.weight, self.norm1.bias,
+                    self.dropout1.p if self.training else 0.0, self.norm1.eps,
+                    rowscale=rowscale1, prenorm=True
+                )
+            if not isinstance(self.mlp, nn.Identity):
+                mlp_out = self.mlp(hidden_states)
+                if not self.fused_dropout_add_ln:
+                    residual = self.drop_path2(self.dropout2(mlp_out)) + residual
+                    hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
+                else:
+                    if self.drop_path2.p == 0 or not self.training:
+                        rowscale2 = None
+                    else:
+                        rowscale2 = self.drop_path2(torch.ones(
+                            mlp_out.shape[:-1], device=mlp_out.device, dtype=mlp_out.dtype)
+                        )
+                    hidden_states, residual = dropout_add_layer_norm(
+                        mlp_out, residual, self.norm2.weight, self.norm2.bias,
+                        self.dropout2.p if self.training else 0.0, self.norm2.eps,
+                        rowscale=rowscale2, prenorm=True
+                    )
+            return hidden_states, residual
+        else:
+            assert residual is None
+            mixer_out = self.mixer(hidden_states,
+                                   **(mixer_kwargs if mixer_kwargs is not None else {}))
+            if not self.fused_dropout_add_ln:
+                hidden_states = self.norm1((self.drop_path1(self.dropout1(mixer_out))
+                                            + hidden_states).to(dtype=self.norm1.weight.dtype))
+            else:
+                if self.drop_path1.p == 0 or not self.training:
+                    rowscale1 = None
+                else:
+                    rowscale1 = self.drop_path1(torch.ones(
+                        mixer_out.shape[:-1], device=mixer_out.device, dtype=mixer_out.dtype)
+                    )
+                hidden_states = dropout_add_layer_norm(
+                    mixer_out, hidden_states, self.norm1.weight, self.norm1.bias,
+                    self.dropout1.p if self.training else 0.0, self.norm1.eps,
+                    rowscale=rowscale1, prenorm=False
+                )
+            if not isinstance(self.mlp, nn.Identity):
+                mlp_out = self.mlp(hidden_states)
+                if not self.fused_dropout_add_ln:
+                    hidden_states = self.norm2((self.drop_path2(self.dropout2(mlp_out))
+                                                + hidden_states).to(dtype=self.norm2.weight.dtype))
+                else:
+                    if self.drop_path2.p == 0 or not self.training:
+                        rowscale2 = None
+                    else:
+                        rowscale2 = self.drop_path2(torch.ones(
+                            mlp_out.shape[:-1], device=mlp_out.device, dtype=mlp_out.dtype)
+                        )
+                    hidden_states = dropout_add_layer_norm(
+                        mlp_out, hidden_states, self.norm2.weight, self.norm2.bias,
+                        self.dropout2.p if self.training else 0.0, self.norm2.eps,
+                        rowscale=rowscale2, prenorm=False
+                    )
+            return hidden_states
diff --git a/flash_attn/modules/embedding.py b/flash_attn/modules/embedding.py
new file mode 100644
index 0000000..03a8f51
--- /dev/null
+++ b/flash_attn/modules/embedding.py
@@ -0,0 +1,35 @@
+# Copyright (c) 2022, Tri Dao.
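+# GPT-2-style embeddings: token embeddings plus (optional) learned absolute position embeddings.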
+
+import torch
+import torch.nn as nn
+
+from einops import repeat
+
+
+class GPT2Embeddings(nn.Module):
+
+    def __init__(self, embed_dim, vocab_size, max_position_embeddings, padding_idx=None):
+        """
+        If max_position_embeddings <= 0, there are no position embeddings.
+        """
+        super().__init__()
+        self.word_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
+        self.max_position_embeddings = max_position_embeddings
+        if self.max_position_embeddings > 0:
+            self.position_embeddings = nn.Embedding(max_position_embeddings, embed_dim)
+
+    def forward(self, input_ids, position_ids=None):
+        """
+        input_ids: (batch, seqlen)
+        """
+        batch_size, seqlen = input_ids.shape
+        input_embeddings = self.word_embeddings(input_ids)
+        if self.max_position_embeddings > 0:
+            if position_ids is None:
+                position_ids = repeat(torch.arange(seqlen, dtype=torch.long,
+                                                   device=input_ids.device),
+                                      's -> b s', b=batch_size)
+            position_embeddings = self.position_embeddings(position_ids)
+            return input_embeddings + position_embeddings
+        else:
+            return input_embeddings
diff --git a/flash_attn/modules/mha.py b/flash_attn/modules/mha.py
new file mode 100644
index 0000000..d465289
--- /dev/null
+++ b/flash_attn/modules/mha.py
@@ -0,0 +1,319 @@
+# Copyright (c) 2022, Tri Dao.
+
+import math
+from functools import partial
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from einops import rearrange
+
+try:
+    from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
+    from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func
+except ImportError:
+    flash_attn_unpadded_qkvpacked_func, flash_attn_unpadded_kvpacked_func = None, None
+
+try:
+    from flash_attn.ops.flash_attn_triton import flash_attn_qkvpacked_func, flash_attn_kvpacked_func
+except ImportError:
+    flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
+
+try:
+    from flash_attn.ops.fused_dense import FusedDenseTD, FusedDenseResidual
+except ImportError:
+    FusedDenseTD, FusedDenseResidual = None, None
+
+try:
+    from flash_attn.layers.rotary import RotaryEmbedding
+except ImportError:
+    RotaryEmbedding = None
+
+
+class FlashSelfAttention(nn.Module):
+    """Implement the scaled dot product attention with softmax.
+    Arguments
+    ---------
+        softmax_scale: The temperature to use for the softmax attention.
+                       (default: 1/sqrt(d_keys) where d_keys is computed at
+                       runtime)
+        attention_dropout: The dropout rate to apply to the attention
+                           (default: 0.0)
+    """
+    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
+                 triton=False, device=None, dtype=None):
+        super().__init__()
+        if attention_dropout != 0.0 or not triton:
+            assert flash_attn_unpadded_qkvpacked_func is not None, 'FlashAttention is not installed'
+        if attention_dropout == 0.0 and triton:
+            assert flash_attn_qkvpacked_func is not None, 'FlashAttention Triton is not installed'
+        self.causal = causal
+        self.softmax_scale = softmax_scale
+        self.dropout_p = attention_dropout
+        self.triton = triton
+
+    def forward(self, qkv):
+        """Implements the multihead softmax attention.
+        Arguments
+        ---------
+            qkv: The tensor containing the query, key, and value.
+                 (B, S, 3, H, D)
+        """
+        assert qkv.dtype in [torch.float16, torch.bfloat16]
+        assert qkv.is_cuda
+        batch_size, seqlen = qkv.shape[0], qkv.shape[1]
+        if self.triton and (self.dropout_p == 0 or not self.training):  # Triton version doesn't support dropout
+            output = flash_attn_qkvpacked_func(qkv, None, self.causal, self.softmax_scale)
+        else:
+            qkv = rearrange(qkv, 'b s ... -> (b s) ...')
+            max_s = seqlen
+            cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
+                                      device=qkv.device)
+            output = flash_attn_unpadded_qkvpacked_func(
+                qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
+                softmax_scale=self.softmax_scale, causal=self.causal
+            )
+            output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
+        return output
+
+
+class FlashCrossAttention(nn.Module):
+    """Implement the scaled dot product attention with softmax.
+    Arguments
+    ---------
+        softmax_scale: The temperature to use for the softmax attention.
+                       (default: 1/sqrt(d_keys) where d_keys is computed at
+                       runtime)
+        attention_dropout: The dropout rate to apply to the attention
+                           (default: 0.0)
+    """
+    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
+                 triton=False, device=None, dtype=None):
+        super().__init__()
+        if attention_dropout != 0.0 or not triton:
+            assert flash_attn_unpadded_kvpacked_func is not None, 'FlashAttention is not installed'
+        if attention_dropout == 0.0 and triton:
+            assert flash_attn_kvpacked_func is not None, 'FlashAttention Triton is not installed'
+        self.causal = causal
+        self.softmax_scale = softmax_scale
+        self.dropout_p = attention_dropout
+        self.triton = triton
+
+    def forward(self, q, kv):
+        """Implements the multihead softmax attention.
+        Arguments
+        ---------
+            q: The tensor containing the query. (B, Sq, H, D)
+            kv: The tensor containing the key and value. (B, Sk, 2, H, D)
+        """
+        assert q.dtype in [torch.float16, torch.bfloat16]
+        assert q.is_cuda and kv.is_cuda
+        batch_size, seqlen_q = q.shape[0], q.shape[1]
+        seqlen_k = kv.shape[1]
+        assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3]
+        if self.triton and (self.dropout_p == 0.0 or not self.training):  # Triton version doesn't support dropout
+            output = flash_attn_kvpacked_func(q, kv, None, self.causal, self.softmax_scale)
+        else:
+            q = rearrange(q, 'b s ... -> (b s) ...')
+            kv = rearrange(kv, 'b s ... -> (b s) ...')
+            cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q,
+                                        dtype=torch.int32, device=q.device)
+            cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k,
+                                        dtype=torch.int32, device=kv.device)
+            output = flash_attn_unpadded_kvpacked_func(
+                q, kv, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
+                self.dropout_p if self.training else 0.0,
+                softmax_scale=self.softmax_scale, causal=self.causal
+            )
+            output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
+        return output
+
+
+class SelfAttention(nn.Module):
+    """Implement the scaled dot product attention with softmax.
+    Arguments
+    ---------
+        softmax_scale: The temperature to use for the softmax attention.
+                       (default: 1/sqrt(d_keys) where d_keys is computed at
+                       runtime)
+        attention_dropout: The dropout rate to apply to the attention
+                           (default: 0.0)
+    """
+    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
+                 device=None, dtype=None):
+        super().__init__()
+        self.causal = causal
+        self.softmax_scale = softmax_scale
+        self.dropout_p = attention_dropout
+
+    def forward(self, qkv):
+        """Implements the multihead softmax attention.
+        Arguments
+        ---------
+            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
+        """
+        batch_size, seqlen = qkv.shape[0], qkv.shape[1]
+        q, k, v = qkv.unbind(dim=2)
+        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
+        scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale)
+        if self.causal:
+            # "triu_tril_cuda_template" not implemented for 'BFloat16'
+            # So we have to construct the mask in float
+            causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
+            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
+            scores = scores + causal_mask.to(dtype=scores.dtype)
+        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
+        attention_drop = F.dropout(attention, self.dropout_p if self.training else 0.0)
+        output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
+        return output
+
+
+class CrossAttention(nn.Module):
+    """Implement the scaled dot product attention with softmax.
+    Arguments
+    ---------
+        softmax_scale: The temperature to use for the softmax attention.
+                       (default: 1/sqrt(d_keys) where d_keys is computed at
+                       runtime)
+        attention_dropout: The dropout rate to apply to the attention
+                           (default: 0.0)
+    """
+    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
+                 device=None, dtype=None):
+        super().__init__()
+        self.causal = causal
+        self.softmax_scale = softmax_scale
+        self.dropout_p = attention_dropout
+
+    def forward(self, q, kv):
+        """Implements the multihead softmax attention.
+        Arguments
+        ---------
+            q: The tensor containing the query. (B, Sq, H, D)
+            kv: The tensor containing the key and value. (B, Sk, 2, H, D)
+        """
+        batch_size, seqlen_q = q.shape[0], q.shape[1]
+        seqlen_k = kv.shape[1]
+        assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3]
+        k, v = kv.unbind(dim=2)
+        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
+        scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale)
+        if self.causal:
+            # "triu_tril_cuda_template" not implemented for 'BFloat16'
+            # So we have to construct the mask in float
+            causal_mask = torch.triu(torch.full((seqlen_q, seqlen_k), -10000.0,
+                                                device=scores.device), 1)
+            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
+            scores = scores + causal_mask.to(dtype=scores.dtype)
+        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
+        attention_drop = F.dropout(attention, self.dropout_p if self.training else 0.0)
+        output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
+        return output
+
+
+class LinearResidual(nn.Linear):
+    """Wrap nn.Linear to return the residual as well. For compatibility with FusedDenseResidual.
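+    Returns a tuple (output, input) instead of just the output.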
+ """ + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return super().forward(input), input + + +class MHA(nn.Module): + """Multi-head self-attention and cross-attention + """ + + def __init__(self, embed_dim, num_heads, cross_attn=False, bias=True, dropout=0.0, + softmax_scale=None, causal=False, dwconv=False, rotary_emb_dim=0, + fused_bias_fc=False, use_flash_attn=False, return_residual=False, + checkpointing=False, device=None, dtype=None) -> None: + """ + return_residual: whether to return the input x along with the output. This is for + performance reason: for post-norm architecture, returning the input allows us + to fuse the backward of nn.Linear with the residual connection. + """ + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.embed_dim = embed_dim + self.cross_attn = cross_attn + self.causal = causal + self.dwconv = dwconv + self.rotary_emb_dim = rotary_emb_dim + self.return_residual = return_residual + self.checkpointing = checkpointing + + self.num_heads = num_heads + assert self.embed_dim % num_heads == 0, "self.kdim must be divisible by num_heads" + self.head_dim = self.embed_dim // num_heads + + if self.rotary_emb_dim > 0: + assert not cross_attn, 'MHA with rotary embedding does not support cross-attention yet' + assert RotaryEmbedding is not None, 'rotary_emb is not installed' + self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim) + + if fused_bias_fc and FusedDenseTD is None: + raise ImportError('fused_dense is not installed') + linear_cls = nn.Linear if not fused_bias_fc else FusedDenseTD + linear_resid_cls = LinearResidual if not fused_bias_fc else FusedDenseResidual + if not self.cross_attn: + if not self.return_residual: + self.Wqkv = linear_cls(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs) + else: + self.Wqkv = linear_resid_cls(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs) + if self.dwconv: + self.dwconv_qkv = nn.Conv1d(3 * embed_dim, 3 * embed_dim, kernel_size=3, padding=2, + groups=3 * embed_dim) + inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention + else: + # TODO: use the residual linear class for Wq + self.Wq = linear_cls(embed_dim, embed_dim, bias=bias, **factory_kwargs) + self.Wkv = linear_cls(embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs) + if self.dwconv: + self.dwconv_q = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, padding=2, + groups=embed_dim) + self.dwconv_kv = nn.Conv1d(2 * embed_dim, 2 * embed_dim, kernel_size=3, padding=2, + groups=2 * embed_dim) + inner_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention + self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale, + attention_dropout=dropout, **factory_kwargs) + # output projection always have the bias (for now) + self.out_proj = linear_cls(embed_dim, embed_dim, **factory_kwargs) + + def forward(self, x, x_kv=None): + """ + Arguments: + x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) + x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x. 
+ """ + if not self.cross_attn: + if not self.return_residual: + qkv = self.Wqkv(x) + else: + qkv, x = self.Wqkv(x) + if self.dwconv: + qkv = rearrange(self.dwconv_qkv(rearrange(qkv, 'b s d -> b d s'))[..., :-2], + 'b d s -> b s d').contiguous() + qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads) + if self.rotary_emb_dim > 0: + qkv = self.rotary_emb(qkv) + if not self.checkpointing: + context = self.inner_attn(qkv) + else: + # context = torch.utils.checkpoint.checkpoint(self._inner_attention, qkv) + context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv) + else: + q = rearrange(self.Wq(x), 'b s (h d) -> b s h d', h=self.num_heads) + kv = rearrange(self.Wkv(x if x_kv is None else x_kv), 'b s (two h d) -> b s two h d', + two=2, h=self.num_heads) + if self.dwconv: + q = rearrange(self.dwconv_q(rearrange(q, 'b s d -> b d s'))[..., :-2], + 'b d s -> b s d').contiguous() + kv = rearrange(self.dwconv_kv(rearrange(kv, 'b s d -> b d s'))[..., :-2], + 'b d s -> b s d').contiguous() + if not self.checkpointing: + context = self.inner_attn(q, kv) + else: + # context = torch.utils.checkpoint.checkpoint(self._inner_attention, qkv) + context = torch.utils.checkpoint.checkpoint(self.inner_attn, q, kv) + out = self.out_proj(rearrange(context, 'b s h d -> b s (h d)')) + return out if not self.return_residual else (out, x) diff --git a/flash_attn/modules/mlp.py b/flash_attn/modules/mlp.py new file mode 100644 index 0000000..c7916ed --- /dev/null +++ b/flash_attn/modules/mlp.py @@ -0,0 +1,72 @@ +# Copyright (c) 2022, Tri Dao. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +try: + from flash_attn.ops.fused_dense import fused_dense_gelu_dense_function_td + from flash_attn.ops.fused_dense import fused_dense_res_gelu_dense_function_td +except ImportError: + fused_dense_gelu_dense_function_td = None + fused_dense_res_gelu_dense_function_td = None + + +class Mlp(nn.Module): + + def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.gelu, + device=None, dtype=None): + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, **factory_kwargs) + self.activation = activation + self.fc2 = nn.Linear(hidden_features, out_features, **factory_kwargs) + + def forward(self, x): + x = self.fc1(x) + x = self.activation(x) + x = self.fc2(x) + return x + + +class FusedDenseGeluDense(nn.Module): + + def __init__(self, in_features, hidden_features=None, out_features=None, bias=True, + checkpoint_lvl=0, heuristic=0, return_residual=False, device=None, dtype=None): + """ + checkpoint_lvl (increasing lvl means slower but more memory saving): + 0: no recomputation in the bwd + 1: recompute gelu_out in the bwd + 2: recompute gelu_in and gelu_out in the bwd + heuristic: + -1: don't fuse gemm + gelu (separate kernel) + 0..4: use this heuristic for the algo section in the fused gemm + gelu + For CUDA >= 11.8, you'd want heuristic=0 for both fp16 and bf16 for best perf. + For CUDA <= 11.7, you'd want heuristic=1 for fp16 and heuristic=-1 for bf16. + return_residual: whether to return the input x along with the output. This is for + performance reason: for post-norm architecture, returning the input allows us + to fuse the backward of nn.Linear with the residual connection. 
+ """ + assert checkpoint_lvl in [0, 1, 2] + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + assert bias == True, "DenseGeluDense module without bias is currently not supported" + assert (fused_dense_gelu_dense_function_td is not None + and fused_dense_res_gelu_dense_function_td is not None), 'fused_dense_lib is not installed' + self.checkpoint_lvl = checkpoint_lvl + self.heuristic = heuristic + self.return_residual = return_residual + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, **factory_kwargs) + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, **factory_kwargs) + + def forward(self, x): + assert x.dtype in [torch.float16, torch.bfloat16] + assert x.is_cuda + fn = (fused_dense_gelu_dense_function_td if not self.return_residual + else fused_dense_res_gelu_dense_function_td) + return fn(x, self.fc1.weight, self.fc1.bias, self.fc2.weight, self.fc2.bias, + self.checkpoint_lvl, self.heuristic) diff --git a/flash_attn/ops/fused_dense.py b/flash_attn/ops/fused_dense.py index 76ea25f..7b0a739 100644 --- a/flash_attn/ops/fused_dense.py +++ b/flash_attn/ops/fused_dense.py @@ -8,9 +8,7 @@ from torch.cuda.amp import custom_bwd, custom_fwd # import fused_dense_cuda # from apex import fused_dense_lib as fused_dense_cuda -# from src.ops.triton.triton_matmul import matmul_dgelu from flash_attn.ops.gelu_activation import gelu_bwd -# from src.ops.gelu_activation import gelu_bwd, bias_gelu, bias_gelu_back # implements fused GEMM+bias in forward pass using mlp_cuda from apex diff --git a/tests/test_rotary.py b/tests/test_rotary.py index fb29bce..12bd2d4 100644 --- a/tests/test_rotary.py +++ b/tests/test_rotary.py @@ -6,7 +6,7 @@ import pytest from einops import rearrange -from flash_attn.rotary import apply_rotary_emb_func, apply_rotary_emb_torch +from flash_attn.layers.rotary import apply_rotary_emb_func, apply_rotary_emb_torch is_sm8x = torch.cuda.get_device_capability('cuda') >= (8, 0)