Add Alibi to MHA, test with Baichuan-13B

Tri Dao 2023-12-21 22:49:55 -08:00
parent 701b51bfc3
commit c3b2196652
5 changed files with 84 additions and 41 deletions

View File

@@ -109,29 +109,14 @@ def remap_state_dict_hf_baichuan(state_dict, config):
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
for l in range(config.n_layer):
# pop rotary_emb.inv_freq from state dict
state_dict.pop(f"transformer.layers.{l}.self_attn.rotary_emb.inv_freq")
state_dict.pop(f"transformer.layers.{l}.self_attn.rotary_emb.inv_freq", None)
return state_dict
def config_from_checkpoint(checkpoint_path: str, model_name: str) -> PretrainedConfig:
"""Load a BaiChuanConfig from a checkpoint path."""
config = AutoConfig.from_pretrained(
Path(checkpoint_path) / model_name, trust_remote_code=True
)
return config
def state_dicts_from_checkpoint(checkpoint_path: str, model_name: str) -> dict:
# Need to sort, otherwise we mess up the ordering and the weights are wrong
return [
torch.load(path, map_location="cpu")
for path in sorted(
(Path(checkpoint_path) / model_name).glob("pytorch_model*.bin")
)
]
def baichuan_config_to_gpt2_config(baichuan_config: PretrainedConfig) -> GPT2Config:
# HACK: the config doesn't say whether it's rotary or alibi.
# So we have to infer from the hidden size (7B -> rotary, 13B -> alibi).
use_rotary = baichuan_config.hidden_size < 5000
return GPT2Config(
vocab_size=baichuan_config.vocab_size,
n_positions=0, # No absolute position embedding
@@ -151,8 +136,10 @@ def baichuan_config_to_gpt2_config(baichuan_config: PretrainedConfig) -> GPT2Config:
# These are new arguments not in the original GPT2Config
pad_token_id=baichuan_config.pad_token_id, # Idk if this does anything
rms_norm=True,
rotary_emb_fraction=1.0,
rotary_emb_fraction=1.0 if use_rotary else 0.0,
rotary_emb_interleaved=False,
use_alibi=not use_rotary,
use_flash_attn=not use_rotary, # Alibi code path requires flash_attn
tie_word_embeddings=False,
qkv_proj_bias=False,
out_proj_bias=False,
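
For reference, the two released checkpoints land on opposite sides of that threshold (the hidden sizes are the published Baichuan values, not read from this diff): Baichuan-7B has hidden_size 4096 and keeps rotary embeddings, while Baichuan-13B has hidden_size 5120 and gets ALiBi plus the flash_attn code path. A minimal sketch of the conversion, assuming access to the Hugging Face Hub:

from transformers import AutoConfig
from flash_attn.models.baichuan import baichuan_config_to_gpt2_config

# 5120 is not < 5000, so use_rotary is False and the converter picks ALiBi.
config = baichuan_config_to_gpt2_config(
    AutoConfig.from_pretrained("baichuan-inc/Baichuan-13B-Base", trust_remote_code=True)
)
assert config.use_alibi and config.use_flash_attn
assert config.rotary_emb_fraction == 0.0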

View File

@@ -85,6 +85,7 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
rotary_emb_base = getattr(config, "rotary_emb_base", 10000.0)
rotary_emb_scale_base = getattr(config, "rotary_emb_scale_base", None)
rotary_emb_interleaved = getattr(config, "rotary_emb_interleaved", False)
use_alibi = getattr(config, "use_alibi", False)
use_flash_attn = getattr(config, "use_flash_attn", False)
fused_bias_fc = getattr(config, "fused_bias_fc", False)
if not fused_bias_fc:
@@ -116,6 +117,7 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
rotary_emb_base=rotary_emb_base,
rotary_emb_scale_base=rotary_emb_scale_base,
rotary_emb_interleaved=rotary_emb_interleaved,
use_alibi=use_alibi,
use_flash_attn=use_flash_attn,
**serial_kwargs,
**parallel_kwargs,
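
create_mixer_cls reads use_alibi (like use_flash_attn) off the config with getattr and a False default, so existing configs are untouched and ALiBi is strictly opt-in. A minimal sketch of a config that exercises the new path, with illustrative sizes, assuming flash-attn is installed and the model runs in fp16 on a CUDA device (GPT2Config keeps the extra keyword arguments as plain attributes):

import torch
from transformers import GPT2Config
from flash_attn.models.gpt import GPTLMHeadModel

config = GPT2Config(
    n_embd=1024,
    n_head=16,
    n_layer=4,
    n_positions=0,            # no absolute position embedding
    rotary_emb_fraction=0.0,  # ALiBi instead of rotary
    use_alibi=True,
    use_flash_attn=True,      # the ALiBi path requires flash_attn
)
model = GPTLMHeadModel(config, device="cuda", dtype=torch.float16)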

View File

@@ -33,6 +33,23 @@ except ImportError:
RotaryEmbedding = None
# From https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
def get_alibi_slopes(nheads):
def get_slopes_power_of_2(nheads):
start = 2 ** (-(2 ** -(math.log2(nheads) - 3)))
ratio = start
return [start * ratio**i for i in range(nheads)]
if math.log2(nheads).is_integer():
return get_slopes_power_of_2(nheads)
else:
closest_power_of_2 = 2 ** math.floor(math.log2(nheads))
return (
get_slopes_power_of_2(closest_power_of_2)
+ get_alibi_slopes(2 * closest_power_of_2)[0::2][: nheads - closest_power_of_2]
)
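
For intuition, a quick worked example of the slopes (not part of the commit): with a power-of-two head count they form a geometric sequence, and otherwise the remainder is borrowed from the next power of two at every other index.

from flash_attn.modules.mha import get_alibi_slopes  # defined just above

# nheads = 8: start = ratio = 2**-1, so the slopes are
#   [2**-1, 2**-2, ..., 2**-8] = [0.5, 0.25, 0.125, ..., 0.00390625]
assert get_alibi_slopes(8) == [2 ** -(i + 1) for i in range(8)]

# nheads = 40 (Baichuan-13B uses 40 heads): 32 slopes come from the power-of-two
# formula and the remaining 8 are the even-indexed slopes of the nheads = 64 sequence.
assert len(get_alibi_slopes(40)) == 40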
class FlashSelfAttention(nn.Module):
"""Implement the scaled dot product attention with softmax.
Arguments
@@ -44,13 +61,14 @@ class FlashSelfAttention(nn.Module):
(default: 0.0)
"""
def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, alibi_slopes=None):
super().__init__()
assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
self.causal = causal
self.softmax_scale = softmax_scale
self.drop = nn.Dropout(attention_dropout)
self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
"""Implements the multihead softmax attention.
@@ -84,6 +102,7 @@ class FlashSelfAttention(nn.Module):
self.drop.p if self.training else 0.0,
softmax_scale=self.softmax_scale,
causal=causal,
alibi_slopes=self.alibi_slopes,
)
else:
return flash_attn_qkvpacked_func(
@@ -91,6 +110,7 @@ class FlashSelfAttention(nn.Module):
self.drop.p if self.training else 0.0,
softmax_scale=self.softmax_scale,
causal=causal,
alibi_slopes=self.alibi_slopes,
)
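
The alibi_slopes buffer is handed straight to the FlashAttention kernels, which apply the bias inside the fused softmax. For intuition, here is an unfused pure-PyTorch sketch of the causal, fixed-length case (a reference for what the bias does, not the kernel's actual code): each head h adds slopes[h] * (j - i) to the pre-softmax score of query i attending to key j, so older keys are penalized linearly with distance.

import math
import torch

def causal_attention_with_alibi_ref(q, k, v, slopes):
    # q, k, v: (batch, seqlen, nheads, head_dim); slopes: (nheads,)
    _, seqlen, nheads, head_dim = q.shape
    scores = torch.einsum("bqhd,bkhd->bhqk", q, k) / math.sqrt(head_dim)
    pos = torch.arange(seqlen, device=q.device)
    rel = pos[None, :] - pos[:, None]                     # entry [i, j] = j - i
    scores = scores + slopes.view(1, nheads, 1, 1) * rel  # ALiBi bias, 0 on the diagonal
    scores = scores.masked_fill(rel > 0, float("-inf"))   # causal mask: no future keys
    attn = torch.softmax(scores, dim=-1)
    return torch.einsum("bhqk,bkhd->bqhd", attn, v)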
@@ -105,13 +125,14 @@ class FlashCrossAttention(nn.Module):
(default: 0.0)
"""
def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, alibi_slopes=None):
super().__init__()
assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
self.causal = causal
self.softmax_scale = softmax_scale
self.drop = nn.Dropout(attention_dropout)
self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
def forward(
self,
@@ -158,6 +179,7 @@ class FlashCrossAttention(nn.Module):
self.drop.p if self.training else 0.0,
softmax_scale=self.softmax_scale,
causal=causal,
alibi_slopes=self.alibi_slopes,
)
else:
batch_size, seqlen_q = q.shape[0], q.shape[1]
@@ -169,6 +191,7 @@ class FlashCrossAttention(nn.Module):
self.drop.p if self.training else 0.0,
causal=causal,
softmax_scale=self.softmax_scale,
alibi_slopes=self.alibi_slopes,
)
@@ -315,8 +338,8 @@ def _update_kv_cache(kv, inference_params, layer_idx):
batch_end = batch_start + kv.shape[0]
sequence_start = inference_params.seqlen_offset
sequence_end = sequence_start + kv.shape[1]
assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0])
assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2])
assert batch_end <= kv_cache.shape[0]
assert sequence_end <= kv_cache.shape[1]
assert kv_cache is not None
kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
return kv_cache[batch_start:batch_end, :sequence_end, ...]
@@ -342,6 +365,7 @@ class MHA(nn.Module):
rotary_emb_base=10000.0,
rotary_emb_scale_base=None,
rotary_emb_interleaved=False,
use_alibi=False,
fused_bias_fc=False,
use_flash_attn=False,
return_residual=False,
@@ -366,6 +390,11 @@ class MHA(nn.Module):
self.use_flash_attn = use_flash_attn
self.return_residual = return_residual
self.checkpointing = checkpointing
if use_alibi:
assert use_flash_attn, "ALiBi code path requires flash_attn"
alibi_slopes = torch.tensor(get_alibi_slopes(num_heads), device=device)
else:
alibi_slopes = None
self.num_heads = num_heads
self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
@@ -395,8 +424,16 @@ class MHA(nn.Module):
LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
)
wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
inner_attn_cls = (
partial(FlashSelfAttention, alibi_slopes=alibi_slopes)
if use_flash_attn
else SelfAttention
)
inner_cross_attn_cls = (
partial(FlashCrossAttention, alibi_slopes=alibi_slopes)
if use_flash_attn
else CrossAttention
)
if not self.cross_attn:
self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)
else:
@@ -413,7 +450,9 @@ class MHA(nn.Module):
)
self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim)
self.inner_attn = inner_attn_cls(
causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
causal=causal,
softmax_scale=softmax_scale,
attention_dropout=dropout,
)
self.inner_cross_attn = inner_cross_attn_cls(
causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
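
With the partials above wiring alibi_slopes into the inner attention classes, turning ALiBi on in MHA is a single flag. A minimal usage sketch with illustrative sizes, assuming a CUDA device, fp16, and a flash-attn build whose kernels accept alibi_slopes:

import torch
from flash_attn.modules.mha import MHA

mha = MHA(
    embed_dim=1024,
    num_heads=16,
    causal=True,
    use_alibi=True,       # requires use_flash_attn=True
    use_flash_attn=True,
    device="cuda",
    dtype=torch.float16,
)
x = torch.randn(2, 512, 1024, device="cuda", dtype=torch.float16)
out = mha(x)              # (2, 512, 1024)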
@@ -672,6 +711,7 @@ class ParallelMHA(nn.Module):
rotary_emb_base=10000.0,
rotary_emb_scale_base=None,
rotary_emb_interleaved=False,
use_alibi=False,
use_flash_attn=False,
checkpointing=False,
sequence_parallel=True,
@@ -707,6 +747,18 @@ class ParallelMHA(nn.Module):
self.head_dim = self.embed_dim // num_heads
qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
if use_alibi:
assert use_flash_attn, "ALiBi code path requires flash_attn"
num_heads_local = math.ceil(self.num_heads / self.world_size)
alibi_slopes = torch.tensor(
get_alibi_slopes(num_heads)[
self.local_rank * num_heads_local : (self.local_rank + 1) * num_heads_local
],
device=device,
)
else:
alibi_slopes = None
if self.rotary_emb_dim > 0:
assert RotaryEmbedding is not None, "rotary_emb is not installed"
self.rotary_emb = RotaryEmbedding(
@@ -728,8 +780,16 @@ class ParallelMHA(nn.Module):
multiple_of=self.head_dim * (self.num_heads // self.num_heads_kv + 2),
**factory_kwargs,
)
inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
inner_attn_cls = (
partial(FlashSelfAttention, alibi_slopes=alibi_slopes)
if use_flash_attn
else SelfAttention
)
inner_cross_attn_cls = (
partial(FlashCrossAttention, alibi_slopes=alibi_slopes)
if use_flash_attn
else CrossAttention
)
self.inner_attn = inner_attn_cls(
causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
)
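
In the tensor-parallel case each rank owns a contiguous block of heads, so it keeps only the matching block of slopes. A standalone sketch of that partitioning (no process group needed here; 40 heads across 2 GPUs is the setting exercised by the Baichuan-13B parallel test):

import math
from flash_attn.modules.mha import get_alibi_slopes

num_heads, world_size = 40, 2
slopes = get_alibi_slopes(num_heads)
num_heads_local = math.ceil(num_heads / world_size)   # 20
for rank in range(world_size):
    shard = slopes[rank * num_heads_local : (rank + 1) * num_heads_local]
    print(rank, len(shard))   # rank 0 -> slopes for heads 0..19, rank 1 -> heads 20..39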

View File

@@ -105,7 +105,7 @@ class GatedMlp(nn.Module):
activation=F.sigmoid,
bias1=True,
bias2=True,
multiple_of=256,
multiple_of=128,
return_residual=False,
device=None,
dtype=None,
@@ -148,7 +148,7 @@ class ParallelGatedMlp(nn.Module):
activation=F.sigmoid,
bias1=True,
bias2=True,
multiple_of=256,
multiple_of=128,
sequence_parallel=True,
device=None,
dtype=None,
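
The multiple_of default drops to 128 because Baichuan-13B's MLP width is a multiple of 128 but not of 256. A quick check of the arithmetic (assuming GatedMlp rounds its hidden width up to a multiple of multiple_of, and taking 13696 as the published Baichuan-13B intermediate size):

intermediate_size = 13696                   # Baichuan-13B (assumed published value)

def round_up(x, multiple_of):
    return (x + multiple_of - 1) // multiple_of * multiple_of

print(round_up(intermediate_size, 256))     # 13824 -- would not match the checkpoint weights
print(round_up(intermediate_size, 128))     # 13696 -- unchanged, so the weights load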

View File

@@ -2,8 +2,6 @@ import os
import time
from pathlib import Path
current_dir = Path(__file__).parent.absolute()
import torch
import pytest
@@ -20,16 +18,12 @@ from flash_attn.models.baichuan import (
remap_state_dict_hf_baichuan,
baichuan_config_to_gpt2_config,
)
from flash_attn.models.baichuan import (
config_from_checkpoint,
state_dicts_from_checkpoint,
)
from flash_attn.utils.distributed import all_gather_raw
from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.utils.generation import update_graph_cache
@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B"])
@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B", "baichuan-inc/Baichuan-13B-Base"])
def test_baichuan_state_dict(model_name):
config = baichuan_config_to_gpt2_config(
AutoConfig.from_pretrained(model_name, trust_remote_code=True)
@@ -45,7 +39,7 @@ def test_baichuan_state_dict(model_name):
assert state_dict[k].shape == pretrained_state_dict[k].shape
@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B"])
@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B", "baichuan-inc/Baichuan-13B-Base"])
def test_baichuan_optimized(model_name):
"""Check that our implementation of Baichuan (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
@@ -122,7 +116,7 @@ def test_baichuan_optimized(model_name):
# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_baichuan.py -k "test_baichuan_parallel_forward"
@pytest.mark.parametrize("world_size", [2])
@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B"])
@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B", "baichuan-inc/Baichuan-13B-Base"])
def test_baichuan_parallel_forward(model_name, world_size):
"""Check that our implementation of Baichuan (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
@@ -217,7 +211,7 @@ def test_baichuan_parallel_forward(model_name, world_size):
).abs().max().item()
@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B"])
@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B", "baichuan-inc/Baichuan-13B-Base"])
def test_baichuan_generation(model_name):
dtype = torch.float16
device = "cuda"