Add Alibi to MHA, test with Baichuan-13B
parent 701b51bfc3
commit c3b2196652
@@ -109,29 +109,14 @@ def remap_state_dict_hf_baichuan(state_dict, config):
     state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
     for l in range(config.n_layer):
         # pop rotary_emb.inv_freq from state dict
-        state_dict.pop(f"transformer.layers.{l}.self_attn.rotary_emb.inv_freq")
+        state_dict.pop(f"transformer.layers.{l}.self_attn.rotary_emb.inv_freq", None)
     return state_dict


-def config_from_checkpoint(checkpoint_path: str, model_name: str) -> PretrainedConfig:
-    """Load a BaiChuanConfig from a checkpoint path."""
-    config = AutoConfig.from_pretrained(
-        Path(checkpoint_path) / model_name, trust_remote_code=True
-    )
-    return config
-
-
-def state_dicts_from_checkpoint(checkpoint_path: str, model_name: str) -> dict:
-    # Need to sort, otherwise we mess up the ordering and the weights are wrong
-    return [
-        torch.load(path, map_location="cpu")
-        for path in sorted(
-            (Path(checkpoint_path) / model_name).glob("pytorch_model*.bin")
-        )
-    ]
-
-
 def baichuan_config_to_gpt2_config(baichuan_config: PretrainedConfig) -> GPT2Config:
+    # HACK: the config doesn't say whether it's rotary or alibi.
+    # So we have to infer from the hidden size (7B -> rotary, 13B -> alibi).
+    use_rotary = baichuan_config.hidden_size < 5000
     return GPT2Config(
         vocab_size=baichuan_config.vocab_size,
         n_positions=0,  # No absolute position embedding
@@ -151,8 +136,10 @@ def baichuan_config_to_gpt2_config(baichuan_config: PretrainedConfig) -> GPT2Config:
         # These are new arguments not in the original GPT2Config
         pad_token_id=baichuan_config.pad_token_id,  # Idk if this does anything
         rms_norm=True,
-        rotary_emb_fraction=1.0,
+        rotary_emb_fraction=1.0 if use_rotary else 0.0,
         rotary_emb_interleaved=False,
+        use_alibi=not use_rotary,
+        use_flash_attn=not use_rotary,  # Alibi code path requires flash_attn
         tie_word_embeddings=False,
         qkv_proj_bias=False,
         out_proj_bias=False,
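The heuristic above deserves a spelled-out example: Baichuan's HF configs do not record the position-embedding type, so the converter keys off hidden_size. A minimal sketch of the decision follows; the hidden sizes (4096 for Baichuan-7B, 5120 for Baichuan-13B-Base) are assumptions taken from the published checkpoints, not values carried in this diff.

# Sketch of the rotary-vs-ALiBi inference used by baichuan_config_to_gpt2_config.
# The hidden sizes below are assumed from the public Baichuan configs.
def infer_position_embedding(hidden_size: int) -> str:
    use_rotary = hidden_size < 5000
    return "rotary" if use_rotary else "alibi"

assert infer_position_embedding(4096) == "rotary"  # Baichuan-7B
assert infer_position_embedding(5120) == "alibi"   # Baichuan-13B-Base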
@@ -85,6 +85,7 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
     rotary_emb_base = getattr(config, "rotary_emb_base", 10000.0)
     rotary_emb_scale_base = getattr(config, "rotary_emb_scale_base", None)
     rotary_emb_interleaved = getattr(config, "rotary_emb_interleaved", False)
+    use_alibi = getattr(config, "use_alibi", False)
     use_flash_attn = getattr(config, "use_flash_attn", False)
     fused_bias_fc = getattr(config, "fused_bias_fc", False)
     if not fused_bias_fc:
@@ -116,6 +117,7 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
             rotary_emb_base=rotary_emb_base,
             rotary_emb_scale_base=rotary_emb_scale_base,
             rotary_emb_interleaved=rotary_emb_interleaved,
+            use_alibi=use_alibi,
             use_flash_attn=use_flash_attn,
             **serial_kwargs,
             **parallel_kwargs,
@@ -33,6 +33,23 @@ except ImportError:
     RotaryEmbedding = None


+# From https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
+def get_alibi_slopes(nheads):
+    def get_slopes_power_of_2(nheads):
+        start = 2 ** (-(2 ** -(math.log2(nheads) - 3)))
+        ratio = start
+        return [start * ratio**i for i in range(nheads)]
+
+    if math.log2(nheads).is_integer():
+        return get_slopes_power_of_2(nheads)
+    else:
+        closest_power_of_2 = 2 ** math.floor(math.log2(nheads))
+        return (
+            get_slopes_power_of_2(closest_power_of_2)
+            + get_alibi_slopes(2 * closest_power_of_2)[0::2][: nheads - closest_power_of_2]
+        )
+
+
 class FlashSelfAttention(nn.Module):
     """Implement the scaled dot product attention with softmax.
     Arguments
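For readers unfamiliar with ALiBi: for a power-of-two head count n the slopes above form the geometric sequence 2^(-8/n), 2^(-16/n), ..., and each head's slope scales a linear penalty on the query-key distance that is added to the raw attention scores. A small worked sketch, not part of the diff, with the bias written out explicitly:

import torch

# For nheads = 8, get_alibi_slopes(8) reduces to 2**-1, 2**-2, ..., 2**-8.
nheads, seqlen = 8, 5
slopes = torch.tensor([2.0 ** -(i + 1) for i in range(nheads)])

# ALiBi bias for causal attention: slope * (j - i), i.e. zero on the diagonal and
# increasingly negative for keys further in the past. FlashAttention applies an
# equivalent bias inside the kernel when alibi_slopes is passed.
distance = torch.arange(seqlen)[None, :] - torch.arange(seqlen)[:, None]  # j - i
bias = slopes[:, None, None] * distance  # shape (nheads, seqlen, seqlen)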
@@ -44,13 +61,14 @@ class FlashSelfAttention(nn.Module):
            (default: 0.0)
     """

-    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
+    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, alibi_slopes=None):
         super().__init__()
         assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
         assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
         self.causal = causal
         self.softmax_scale = softmax_scale
         self.drop = nn.Dropout(attention_dropout)
+        self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)

     def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
         """Implements the multihead softmax attention.
@@ -84,6 +102,7 @@ class FlashSelfAttention(nn.Module):
                 self.drop.p if self.training else 0.0,
                 softmax_scale=self.softmax_scale,
                 causal=causal,
+                alibi_slopes=self.alibi_slopes,
             )
         else:
             return flash_attn_qkvpacked_func(
@@ -91,6 +110,7 @@ class FlashSelfAttention(nn.Module):
                 self.drop.p if self.training else 0.0,
                 softmax_scale=self.softmax_scale,
                 causal=causal,
+                alibi_slopes=self.alibi_slopes,
             )
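Taken together, the FlashSelfAttention changes amount to: accept the slopes at construction time, keep them as a non-persistent buffer (so they follow the module across devices but are not written to checkpoints), and forward them to both FlashAttention kernels. A hedged usage sketch follows; it assumes a CUDA device, a flash-attn build whose qkvpacked kernels accept alibi_slopes, and the packed (batch, seqlen, 3, nheads, headdim) layout the module expects.

import torch
from flash_attn.modules.mha import FlashSelfAttention, get_alibi_slopes

nheads, headdim = 8, 64
slopes = torch.tensor(get_alibi_slopes(nheads), device="cuda")
attn = FlashSelfAttention(causal=True, alibi_slopes=slopes)

# Packed QKV in fp16 on GPU, as required by the flash kernels.
qkv = torch.randn(2, 128, 3, nheads, headdim, device="cuda", dtype=torch.float16)
out = attn(qkv)  # (2, 128, nheads, headdim), scores biased by the registered slopes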
@@ -105,13 +125,14 @@ class FlashCrossAttention(nn.Module):
            (default: 0.0)
     """

-    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
+    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, alibi_slopes=None):
         super().__init__()
         assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
         assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
         self.causal = causal
         self.softmax_scale = softmax_scale
         self.drop = nn.Dropout(attention_dropout)
+        self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)

     def forward(
         self,
@@ -158,6 +179,7 @@ class FlashCrossAttention(nn.Module):
                 self.drop.p if self.training else 0.0,
                 softmax_scale=self.softmax_scale,
                 causal=causal,
+                alibi_slopes=self.alibi_slopes,
             )
         else:
             batch_size, seqlen_q = q.shape[0], q.shape[1]
@@ -169,6 +191,7 @@ class FlashCrossAttention(nn.Module):
                 self.drop.p if self.training else 0.0,
                 causal=causal,
                 softmax_scale=self.softmax_scale,
+                alibi_slopes=self.alibi_slopes,
             )
@@ -315,8 +338,8 @@ def _update_kv_cache(kv, inference_params, layer_idx):
     batch_end = batch_start + kv.shape[0]
     sequence_start = inference_params.seqlen_offset
     sequence_end = sequence_start + kv.shape[1]
-    assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0])
-    assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2])
+    assert batch_end <= kv_cache.shape[0]
+    assert sequence_end <= kv_cache.shape[1]
     assert kv_cache is not None
     kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
     return kv_cache[batch_start:batch_end, :sequence_end, ...]
@@ -342,6 +365,7 @@ class MHA(nn.Module):
         rotary_emb_base=10000.0,
         rotary_emb_scale_base=None,
         rotary_emb_interleaved=False,
+        use_alibi=False,
         fused_bias_fc=False,
         use_flash_attn=False,
         return_residual=False,
@@ -366,6 +390,11 @@ class MHA(nn.Module):
         self.use_flash_attn = use_flash_attn
         self.return_residual = return_residual
         self.checkpointing = checkpointing
+        if use_alibi:
+            assert use_flash_attn, "ALiBi code path requires flash_attn"
+            alibi_slopes = torch.tensor(get_alibi_slopes(num_heads), device=device)
+        else:
+            alibi_slopes = None

         self.num_heads = num_heads
         self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
@@ -395,8 +424,16 @@ class MHA(nn.Module):
             LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
         )
         wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
-        inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
-        inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
+        inner_attn_cls = (
+            partial(FlashSelfAttention, alibi_slopes=alibi_slopes)
+            if use_flash_attn
+            else SelfAttention
+        )
+        inner_cross_attn_cls = (
+            partial(FlashCrossAttention, alibi_slopes=alibi_slopes)
+            if use_flash_attn
+            else CrossAttention
+        )
         if not self.cross_attn:
             self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)
         else:
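At the MHA level the new flag is all a caller needs: with use_alibi=True the module builds the slope tensor itself and threads it into the flash-attention wrappers via partial, while the assert enforces that the flash-only ALiBi path is actually available. A hedged construction sketch, sized roughly like a Baichuan-13B layer (5120 hidden, 40 heads; those numbers are assumptions, not taken from this diff), assuming flash-attn and a CUDA device:

import torch
from flash_attn.modules.mha import MHA

mha = MHA(
    embed_dim=5120,
    num_heads=40,
    qkv_proj_bias=False,
    out_proj_bias=False,
    causal=True,
    rotary_emb_dim=0,     # ALiBi takes the place of rotary embeddings
    use_alibi=True,       # builds alibi_slopes via get_alibi_slopes(num_heads)
    use_flash_attn=True,  # required: the ALiBi path exists only in the flash kernels
    device="cuda",
    dtype=torch.float16,
)
x = torch.randn(1, 16, 5120, device="cuda", dtype=torch.float16)
out = mha(x)  # (1, 16, 5120)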
@@ -413,7 +450,9 @@ class MHA(nn.Module):
             )
             self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim)
         self.inner_attn = inner_attn_cls(
-            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
+            causal=causal,
+            softmax_scale=softmax_scale,
+            attention_dropout=dropout,
         )
         self.inner_cross_attn = inner_cross_attn_cls(
             causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
@@ -672,6 +711,7 @@ class ParallelMHA(nn.Module):
         rotary_emb_base=10000.0,
         rotary_emb_scale_base=None,
         rotary_emb_interleaved=False,
+        use_alibi=False,
         use_flash_attn=False,
         checkpointing=False,
         sequence_parallel=True,
@@ -707,6 +747,18 @@ class ParallelMHA(nn.Module):
         self.head_dim = self.embed_dim // num_heads
         qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)

+        if use_alibi:
+            assert use_flash_attn, "ALiBi code path requires flash_attn"
+            num_heads_local = math.ceil(self.num_heads / self.world_size)
+            alibi_slopes = torch.tensor(
+                get_alibi_slopes(num_heads)[
+                    self.local_rank * num_heads_local : (self.local_rank + 1) * num_heads_local
+                ],
+                device=device,
+            )
+        else:
+            alibi_slopes = None
+
         if self.rotary_emb_dim > 0:
             assert RotaryEmbedding is not None, "rotary_emb is not installed"
             self.rotary_emb = RotaryEmbedding(
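In ParallelMHA the slopes must be sharded the same way the heads are: each tensor-parallel rank computes attention only for its local heads, so it receives only the matching slice of the slope list. The indexing can be checked without any distributed setup; the sketch below assumes 40 heads split over 2 ranks (Baichuan-13B-style numbers) and imports get_alibi_slopes from the patched module.

import math
from flash_attn.modules.mha import get_alibi_slopes

num_heads, world_size = 40, 2
slopes = get_alibi_slopes(num_heads)
num_heads_local = math.ceil(num_heads / world_size)  # 20 heads per rank

# Same slicing as in ParallelMHA.__init__: rank r owns the slopes of its local heads.
shards = [
    slopes[rank * num_heads_local : (rank + 1) * num_heads_local]
    for rank in range(world_size)
]
assert [len(s) for s in shards] == [20, 20]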
@@ -728,8 +780,16 @@ class ParallelMHA(nn.Module):
             multiple_of=self.head_dim * (self.num_heads // self.num_heads_kv + 2),
             **factory_kwargs,
         )
-        inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
-        inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
+        inner_attn_cls = (
+            partial(FlashSelfAttention, alibi_slopes=alibi_slopes)
+            if use_flash_attn
+            else SelfAttention
+        )
+        inner_cross_attn_cls = (
+            partial(FlashCrossAttention, alibi_slopes=alibi_slopes)
+            if use_flash_attn
+            else CrossAttention
+        )
         self.inner_attn = inner_attn_cls(
             causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
         )
@@ -105,7 +105,7 @@ class GatedMlp(nn.Module):
         activation=F.sigmoid,
         bias1=True,
         bias2=True,
-        multiple_of=256,
+        multiple_of=128,
         return_residual=False,
         device=None,
         dtype=None,
@@ -148,7 +148,7 @@ class ParallelGatedMlp(nn.Module):
         activation=F.sigmoid,
         bias1=True,
         bias2=True,
-        multiple_of=256,
+        multiple_of=128,
         sequence_parallel=True,
         device=None,
         dtype=None,
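The multiple_of change looks tied to Baichuan-13B's FFN width: GatedMlp rounds its hidden dimension up to a multiple of multiple_of, and (assuming the published intermediate_size of 13696 for Baichuan-13B) that width is a multiple of 128 but not of 256, so the old default would inflate the layer and break weight loading. A quick check of the arithmetic; the rounding helper here is an illustration, not the library's code.

def round_up(x: int, multiple_of: int) -> int:
    # Round x up to the nearest multiple of `multiple_of`, as GatedMlp does for
    # its hidden width.
    return ((x + multiple_of - 1) // multiple_of) * multiple_of

intermediate_size = 13696  # assumed Baichuan-13B FFN width
assert round_up(intermediate_size, 128) == 13696  # matches the checkpoint shapes
assert round_up(intermediate_size, 256) == 13824  # would not match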
@@ -2,8 +2,6 @@ import os
 import time
 from pathlib import Path

-current_dir = Path(__file__).parent.absolute()
-
 import torch
 import pytest
@@ -20,16 +18,12 @@ from flash_attn.models.baichuan import (
     remap_state_dict_hf_baichuan,
     baichuan_config_to_gpt2_config,
 )
-from flash_attn.models.baichuan import (
-    config_from_checkpoint,
-    state_dicts_from_checkpoint,
-)
 from flash_attn.utils.distributed import all_gather_raw
 from flash_attn.utils.pretrained import state_dict_from_pretrained
 from flash_attn.utils.generation import update_graph_cache


-@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B"])
+@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B", "baichuan-inc/Baichuan-13B-Base"])
 def test_baichuan_state_dict(model_name):
     config = baichuan_config_to_gpt2_config(
         AutoConfig.from_pretrained(model_name, trust_remote_code=True)
@@ -45,7 +39,7 @@ def test_baichuan_state_dict(model_name):
         assert state_dict[k].shape == pretrained_state_dict[k].shape


-@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B"])
+@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B", "baichuan-inc/Baichuan-13B-Base"])
 def test_baichuan_optimized(model_name):
     """Check that our implementation of Baichuan (with all optimizations enabled) matches the
     HF implementation: the output of our forward pass in fp16 should be around the same as the HF
@@ -122,7 +116,7 @@ def test_baichuan_optimized(model_name):

 # torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_baichuan.py -k "test_baichuan_parallel_forward"
 @pytest.mark.parametrize("world_size", [2])
-@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B"])
+@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B", "baichuan-inc/Baichuan-13B-Base"])
 def test_baichuan_parallel_forward(model_name, world_size):
     """Check that our implementation of Baichuan (with all optimizations enabled) matches the
     HF implementation: the output of our forward pass in fp16 should be around the same as the HF
@@ -217,7 +211,7 @@ def test_baichuan_parallel_forward(model_name, world_size):
     ).abs().max().item()


-@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B"])
+@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B", "baichuan-inc/Baichuan-13B-Base"])
 def test_baichuan_generation(model_name):
     dtype = torch.float16
     device = "cuda"