Add MLP, MHA, Block, Embedding modules

This commit is contained in:
Tri Dao 2022-11-13 22:06:44 -08:00
parent fa6d1ce44f
commit d4b320b31f
7 changed files with 556 additions and 3 deletions

129
flash_attn/modules/block.py Normal file
View File

@ -0,0 +1,129 @@
# Copyright (c) 2022, Tri Dao.
from typing import Optional
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torchvision.ops import StochasticDepth
from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import Mlp
try:
from flash_attn.ops.layer_norm import dropout_add_layer_norm
except ImportError:
dropout_add_layer_norm = None
class Block(nn.Module):
def __init__(self, dim, mixer_cls=None, mlp_cls=None, norm_cls=nn.LayerNorm,
dropout_cls=nn.Dropout, prenorm=True, resid_dropout=0., drop_path=0.,
fused_dropout_add_ln=False):
super().__init__()
self.prenorm = prenorm
self.fused_dropout_add_ln = fused_dropout_add_ln
if mixer_cls is None:
mixer_cls = partial(MHA, num_heads=dim // 64)
if mlp_cls is None:
mlp_cls = partial(Mlp, hidden_features=4 * dim)
self.mixer = mixer_cls(dim)
self.dropout1 = dropout_cls(resid_dropout)
self.drop_path1 = StochasticDepth(drop_path, mode='row')
self.norm1 = norm_cls(dim)
self.mlp = mlp_cls(dim)
if not isinstance(self.mlp, nn.Identity):
self.dropout2 = dropout_cls(resid_dropout)
self.drop_path2 = StochasticDepth(drop_path, mode='row')
self.norm2 = norm_cls(dim)
if self.fused_dropout_add_ln:
assert dropout_add_layer_norm is not None, 'dropout_add_ln is not installed'
assert isinstance(self.norm1, nn.LayerNorm) and isinstance(self.dropout1, nn.Dropout)
def forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None,
mixer_kwargs=None):
r"""Pass the input through the encoder layer.
Args:
hidden_states: the sequence to the encoder layer (required).
residual: if postnorm, residual=None, If prenorm, hidden_states = LayerNorm(residual)
"""
if self.prenorm:
assert residual is not None
mixer_out = self.mixer(hidden_states,
**(mixer_kwargs if mixer_kwargs is not None else {}))
if not self.fused_dropout_add_ln:
residual = self.drop_path1(self.dropout1(mixer_out)) + residual
hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
else:
if self.drop_path1.p == 0 or not self.training:
rowscale1 = None
else:
rowscale1 = self.drop_path1(torch.ones(
mixer_out.shape[:-1], device=mixer_out.device, dtype=mixer_out.dtype)
)
hidden_states, residual = dropout_add_layer_norm(
mixer_out, residual, self.norm1.weight, self.norm1.bias,
self.dropout1.p if self.training else 0.0, self.norm1.eps,
rowscale=rowscale1, prenorm=True
)
if not isinstance(self.mlp, nn.Identity):
mlp_out = self.mlp(hidden_states)
if not self.fused_dropout_add_ln:
residual = self.drop_path2(self.dropout2(mlp_out)) + residual
hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
else:
if self.drop_path2.p == 0 or not self.training:
rowscale2 = None
else:
rowscale2 = self.drop_path2(torch.ones(
mlp_out.shape[:-1], device=mlp_out.device, dtype=mlp_out.dtype)
)
hidden_states, residual = dropout_add_layer_norm(
mlp_out, residual, self.norm2.weight, self.norm2.bias,
self.dropout2.p if self.training else 0.0, self.norm2.eps,
rowscale=rowscale2, prenorm=True
)
return hidden_states, residual
else:
assert residual is None
mixer_out = self.mixer(hidden_states,
**(mixer_kwargs if mixer_kwargs is not None else {}))
if not self.fused_dropout_add_ln:
hidden_states = self.norm1((self.drop_path1(self.dropout1(mixer_out))
+ hidden_states).to(dtype=self.norm1.weight.dtype))
else:
if self.drop_path1.p == 0 or not self.training:
rowscale1 = None
else:
rowscale1 = self.drop_path1(torch.ones(
mixer_out.shape[:-1], device=mixer_out.device, dtype=mixer_out.dtype)
)
hidden_states = dropout_add_layer_norm(
mixer_out, hidden_states, self.norm1.weight, self.norm1.bias,
self.dropout1.p if self.training else 0.0, self.norm1.eps,
rowscale=rowscale1, prenorm=False
)
if not isinstance(self.mlp, nn.Identity):
mlp_out = self.mlp(hidden_states)
if not self.fused_dropout_add_ln:
hidden_states = self.norm2((self.drop_path2(self.dropout2(mlp_out))
+ hidden_states).to(dtype=self.norm2.weight.dtype))
else:
if self.drop_path2.p == 0 or not self.training:
rowscale2 = None
else:
rowscale2 = self.drop_path2(torch.ones(
mlp_out.shape[:-1], device=mlp_out.device, dtype=mlp_out.dtype)
)
hidden_states = dropout_add_layer_norm(
mlp_out, hidden_states, self.norm2.weight, self.norm2.bias,
self.dropout2.p if self.training else 0.0, self.norm2.eps,
rowscale=rowscale2, prenorm=False
)
return hidden_states

View File

@ -0,0 +1,35 @@
# Copyright (c) 2022, Tri Dao.
import torch
import torch.nn as nn
from einops import repeat
class GPT2Embeddings(nn.Module):
def __init__(self, embed_dim, vocab_size, max_position_embeddings, padding_idx=None):
"""
If max_position_embeddings <= 0, there's no position embeddings
"""
super().__init__()
self.word_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
self.max_position_embeddings = max_position_embeddings
if self.max_position_embeddings > 0:
self.position_embeddings = nn.Embedding(max_position_embeddings, embed_dim)
def forward(self, input_ids, position_ids=None):
"""
input_ids: (batch, seqlen)
"""
batch_size, seqlen = input_ids.shape
input_embeddings = self.word_embeddings(input_ids)
if self.max_position_embeddings > 0:
if position_ids is None:
position_ids = repeat(torch.arange(seqlen, dtype=torch.long,
device=input_ids.device),
's -> b s', b=batch_size)
position_embeddings = self.position_embeddings(position_ids)
return input_embeddings + position_embeddings
else:
return input_embeddings

319
flash_attn/modules/mha.py Normal file
View File

@ -0,0 +1,319 @@
# Copyright (c) 2022, Tri Dao.
import math
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
try:
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func
except ImportError:
flash_attn_unpadded_qkvpacked_func, flash_attn_unpadded_kvpacked_func = None, None
try:
from flash_attn.ops.flash_attn_triton import flash_attn_qkvpacked_func, flash_attn_kvpacked_func
except ImportError:
flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
try:
from flash_attn.ops.fused_dense import FusedDenseTD, FusedDenseResidual
except ImportError:
FusedDenseTD, FusedDenseResidual = None, None
try:
from flash_attn.layers.rotary import RotaryEmbedding
except ImportError:
RotaryEmbedding = None
class FlashSelfAttention(nn.Module):
"""Implement the scaled dot product attention with softmax.
Arguments
---------
softmax_scale: The temperature to use for the softmax attention.
(default: 1/sqrt(d_keys) where d_keys is computed at
runtime)
attention_dropout: The dropout rate to apply to the attention
(default: 0.0)
"""
def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
triton=False, device=None, dtype=None):
super().__init__()
if attention_dropout != 0.0 or not triton:
assert flash_attn_unpadded_qkvpacked_func is not None, 'FlashAttention is not installed'
if attention_dropout == 0.0 and triton:
assert flash_attn_qkvpacked_func is not None, 'FlashAttention Triton is not installed'
self.causal = causal
self.softmax_scale = softmax_scale
self.dropout_p = attention_dropout
self.triton = triton
def forward(self, qkv):
"""Implements the multihead softmax attention.
Arguments
---------
qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
"""
assert qkv.dtype in [torch.float16, torch.bfloat16]
assert qkv.is_cuda
batch_size, seqlen = qkv.shape[0], qkv.shape[1]
if self.triton and (self.dropout_p == 0 or not self.training): # Triton version doesn't support dropout
output = flash_attn_qkvpacked_func(qkv, None, self.causal, self.softmax_scale)
else:
qkv = rearrange(qkv, 'b s ... -> (b s) ...')
max_s = seqlen
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
device=qkv.device)
output = flash_attn_unpadded_qkvpacked_func(
qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
softmax_scale=self.softmax_scale, causal=self.causal
)
output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
return output
class FlashCrossAttention(nn.Module):
"""Implement the scaled dot product attention with softmax.
Arguments
---------
softmax_scale: The temperature to use for the softmax attention.
(default: 1/sqrt(d_keys) where d_keys is computed at
runtime)
attention_dropout: The dropout rate to apply to the attention
(default: 0.0)
"""
def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
triton=False, device=None, dtype=None):
super().__init__()
if attention_dropout != 0.0 or not triton:
assert flash_attn_unpadded_kvpacked_func is not None, 'FlashAttention is not installed'
if attention_dropout == 0.0 and triton:
assert flash_attn_kvpacked_func is not None, 'FlashAttention Triton is not installed'
self.causal = causal
self.softmax_scale = softmax_scale
self.dropout_p = attention_dropout
self.triton = triton
def forward(self, q, kv):
"""Implements the multihead softmax attention.
Arguments
---------
q: The tensor containing the query. (B, Sq, H, D)
kv: The tensor containing the key and value. (B, Sk, 2, H, D)
"""
assert q.dtype in [torch.float16, torch.bfloat16]
assert q.is_cuda and kv.is_cuda
batch_size, seqlen_q = q.shape[0], q.shape[1]
seqlen_k = kv.shape[1]
assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3]
if self.triton and (self.dropout_p == 0.0 or not self.training): # Triton version doesn't support dropout
output = flash_attn_kvpacked_func(q, kv, None, self.causal, self.softmax_scale)
else:
q = rearrange(q, 'b s ... -> (b s) ...')
kv = rearrange(kv, 'b s ... -> (b s) ...')
cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q,
dtype=torch.int32, device=q.device)
cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k,
dtype=torch.int32, device=kv.device)
output = flash_attn_unpadded_kvpacked_func(
q, kv, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
self.dropout_p if self.training else 0.0,
softmax_scale=self.softmax_scale, causal=self.causal
)
output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
return output
class SelfAttention(nn.Module):
"""Implement the scaled dot product attention with softmax.
Arguments
---------
softmax_scale: The temperature to use for the softmax attention.
(default: 1/sqrt(d_keys) where d_keys is computed at
runtime)
attention_dropout: The dropout rate to apply to the attention
(default: 0.0)
"""
def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
device=None, dtype=None):
super().__init__()
self.causal = causal
self.softmax_scale = softmax_scale
self.dropout_p = attention_dropout
def forward(self, qkv):
"""Implements the multihead softmax attention.
Arguments
---------
qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
"""
batch_size, seqlen = qkv.shape[0], qkv.shape[1]
q, k, v = qkv.unbind(dim=2)
softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale)
if self.causal:
# "triu_tril_cuda_template" not implemented for 'BFloat16'
# So we have to construct the mask in float
causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
# TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
scores = scores + causal_mask.to(dtype=scores.dtype)
attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
attention_drop = F.dropout(attention, self.dropout_p if self.training else 0.0)
output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
return output
class CrossAttention(nn.Module):
"""Implement the scaled dot product attention with softmax.
Arguments
---------
softmax_scale: The temperature to use for the softmax attention.
(default: 1/sqrt(d_keys) where d_keys is computed at
runtime)
attention_dropout: The dropout rate to apply to the attention
(default: 0.0)
"""
def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
device=None, dtype=None):
super().__init__()
self.causal = causal
self.softmax_scale = softmax_scale
self.dropout_p = attention_dropout
def forward(self, q, kv):
"""Implements the multihead softmax attention.
Arguments
---------
q: The tensor containing the query. (B, Sq, H, D)
kv: The tensor containing the key and value. (B, Sk, 2, H, D)
"""
batch_size, seqlen_q = q.shape[0], q.shape[1]
seqlen_k = kv.shape[1]
assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3]
k, v = kv.unbind(dim=2)
softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale)
if self.causal:
# "triu_tril_cuda_template" not implemented for 'BFloat16'
# So we have to construct the mask in float
causal_mask = torch.triu(torch.full((seqlen_q, seqlen_k), -10000.0,
device=scores.device), 1)
# TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
scores = scores + causal_mask.to(dtype=scores.dtype)
attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
attention_drop = F.dropout(attention, self.dropout_p if self.training else 0.0)
output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
return output
class LinearResidual(nn.Linear):
"""Wrap nn.Linear to return the residual as well. For compatibility with FusedDenseResidual.
"""
def forward(self, input: torch.Tensor) -> torch.Tensor:
return super().forward(input), input
class MHA(nn.Module):
"""Multi-head self-attention and cross-attention
"""
def __init__(self, embed_dim, num_heads, cross_attn=False, bias=True, dropout=0.0,
softmax_scale=None, causal=False, dwconv=False, rotary_emb_dim=0,
fused_bias_fc=False, use_flash_attn=False, return_residual=False,
checkpointing=False, device=None, dtype=None) -> None:
"""
return_residual: whether to return the input x along with the output. This is for
performance reason: for post-norm architecture, returning the input allows us
to fuse the backward of nn.Linear with the residual connection.
"""
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.embed_dim = embed_dim
self.cross_attn = cross_attn
self.causal = causal
self.dwconv = dwconv
self.rotary_emb_dim = rotary_emb_dim
self.return_residual = return_residual
self.checkpointing = checkpointing
self.num_heads = num_heads
assert self.embed_dim % num_heads == 0, "self.kdim must be divisible by num_heads"
self.head_dim = self.embed_dim // num_heads
if self.rotary_emb_dim > 0:
assert not cross_attn, 'MHA with rotary embedding does not support cross-attention yet'
assert RotaryEmbedding is not None, 'rotary_emb is not installed'
self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim)
if fused_bias_fc and FusedDenseTD is None:
raise ImportError('fused_dense is not installed')
linear_cls = nn.Linear if not fused_bias_fc else FusedDenseTD
linear_resid_cls = LinearResidual if not fused_bias_fc else FusedDenseResidual
if not self.cross_attn:
if not self.return_residual:
self.Wqkv = linear_cls(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
else:
self.Wqkv = linear_resid_cls(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
if self.dwconv:
self.dwconv_qkv = nn.Conv1d(3 * embed_dim, 3 * embed_dim, kernel_size=3, padding=2,
groups=3 * embed_dim)
inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
else:
# TODO: use the residual linear class for Wq
self.Wq = linear_cls(embed_dim, embed_dim, bias=bias, **factory_kwargs)
self.Wkv = linear_cls(embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs)
if self.dwconv:
self.dwconv_q = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, padding=2,
groups=embed_dim)
self.dwconv_kv = nn.Conv1d(2 * embed_dim, 2 * embed_dim, kernel_size=3, padding=2,
groups=2 * embed_dim)
inner_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale,
attention_dropout=dropout, **factory_kwargs)
# output projection always have the bias (for now)
self.out_proj = linear_cls(embed_dim, embed_dim, **factory_kwargs)
def forward(self, x, x_kv=None):
"""
Arguments:
x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim)
x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
"""
if not self.cross_attn:
if not self.return_residual:
qkv = self.Wqkv(x)
else:
qkv, x = self.Wqkv(x)
if self.dwconv:
qkv = rearrange(self.dwconv_qkv(rearrange(qkv, 'b s d -> b d s'))[..., :-2],
'b d s -> b s d').contiguous()
qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
if self.rotary_emb_dim > 0:
qkv = self.rotary_emb(qkv)
if not self.checkpointing:
context = self.inner_attn(qkv)
else:
# context = torch.utils.checkpoint.checkpoint(self._inner_attention, qkv)
context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv)
else:
q = rearrange(self.Wq(x), 'b s (h d) -> b s h d', h=self.num_heads)
kv = rearrange(self.Wkv(x if x_kv is None else x_kv), 'b s (two h d) -> b s two h d',
two=2, h=self.num_heads)
if self.dwconv:
q = rearrange(self.dwconv_q(rearrange(q, 'b s d -> b d s'))[..., :-2],
'b d s -> b s d').contiguous()
kv = rearrange(self.dwconv_kv(rearrange(kv, 'b s d -> b d s'))[..., :-2],
'b d s -> b s d').contiguous()
if not self.checkpointing:
context = self.inner_attn(q, kv)
else:
# context = torch.utils.checkpoint.checkpoint(self._inner_attention, qkv)
context = torch.utils.checkpoint.checkpoint(self.inner_attn, q, kv)
out = self.out_proj(rearrange(context, 'b s h d -> b s (h d)'))
return out if not self.return_residual else (out, x)

72
flash_attn/modules/mlp.py Normal file
View File

@ -0,0 +1,72 @@
# Copyright (c) 2022, Tri Dao.
import torch
import torch.nn as nn
import torch.nn.functional as F
try:
from flash_attn.ops.fused_dense import fused_dense_gelu_dense_function_td
from flash_attn.ops.fused_dense import fused_dense_res_gelu_dense_function_td
except ImportError:
fused_dense_gelu_dense_function_td = None
fused_dense_res_gelu_dense_function_td = None
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.gelu,
device=None, dtype=None):
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features, **factory_kwargs)
self.activation = activation
self.fc2 = nn.Linear(hidden_features, out_features, **factory_kwargs)
def forward(self, x):
x = self.fc1(x)
x = self.activation(x)
x = self.fc2(x)
return x
class FusedDenseGeluDense(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, bias=True,
checkpoint_lvl=0, heuristic=0, return_residual=False, device=None, dtype=None):
"""
checkpoint_lvl (increasing lvl means slower but more memory saving):
0: no recomputation in the bwd
1: recompute gelu_out in the bwd
2: recompute gelu_in and gelu_out in the bwd
heuristic:
-1: don't fuse gemm + gelu (separate kernel)
0..4: use this heuristic for the algo section in the fused gemm + gelu
For CUDA >= 11.8, you'd want heuristic=0 for both fp16 and bf16 for best perf.
For CUDA <= 11.7, you'd want heuristic=1 for fp16 and heuristic=-1 for bf16.
return_residual: whether to return the input x along with the output. This is for
performance reason: for post-norm architecture, returning the input allows us
to fuse the backward of nn.Linear with the residual connection.
"""
assert checkpoint_lvl in [0, 1, 2]
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
assert bias == True, "DenseGeluDense module without bias is currently not supported"
assert (fused_dense_gelu_dense_function_td is not None
and fused_dense_res_gelu_dense_function_td is not None), 'fused_dense_lib is not installed'
self.checkpoint_lvl = checkpoint_lvl
self.heuristic = heuristic
self.return_residual = return_residual
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, **factory_kwargs)
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, **factory_kwargs)
def forward(self, x):
assert x.dtype in [torch.float16, torch.bfloat16]
assert x.is_cuda
fn = (fused_dense_gelu_dense_function_td if not self.return_residual
else fused_dense_res_gelu_dense_function_td)
return fn(x, self.fc1.weight, self.fc1.bias, self.fc2.weight, self.fc2.bias,
self.checkpoint_lvl, self.heuristic)

View File

@ -8,9 +8,7 @@ from torch.cuda.amp import custom_bwd, custom_fwd
# import fused_dense_cuda # from apex
import fused_dense_lib as fused_dense_cuda
# from src.ops.triton.triton_matmul import matmul_dgelu
from flash_attn.ops.gelu_activation import gelu_bwd
# from src.ops.gelu_activation import gelu_bwd, bias_gelu, bias_gelu_back
# implements fused GEMM+bias in forward pass using mlp_cuda from apex

View File

@ -6,7 +6,7 @@ import pytest
from einops import rearrange
from flash_attn.rotary import apply_rotary_emb_func, apply_rotary_emb_torch
from flash_attn.layers.rotary import apply_rotary_emb_func, apply_rotary_emb_torch
is_sm8x = torch.cuda.get_device_capability('cuda') >= (8, 0)