From c3b219665292c61a51153d0ded4473c494296382 Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Thu, 21 Dec 2023 22:49:55 -0800
Subject: [PATCH] Add Alibi to MHA, test with Baichuan-13B

---
 flash_attn/models/baichuan.py | 27 ++++--------
 flash_attn/models/gpt.py      |  2 +
 flash_attn/modules/mha.py     | 78 +++++++++++++++++++++++++++++++----
 flash_attn/modules/mlp.py     |  4 +-
 tests/models/test_baichuan.py | 14 ++-----
 5 files changed, 84 insertions(+), 41 deletions(-)

diff --git a/flash_attn/models/baichuan.py b/flash_attn/models/baichuan.py
index 7fb53e4..2ca9ac1 100644
--- a/flash_attn/models/baichuan.py
+++ b/flash_attn/models/baichuan.py
@@ -109,29 +109,14 @@ def remap_state_dict_hf_baichuan(state_dict, config):
     state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
     for l in range(config.n_layer):
         # pop rotary_emb.inv_freq from state dict
-        state_dict.pop(f"transformer.layers.{l}.self_attn.rotary_emb.inv_freq")
+        state_dict.pop(f"transformer.layers.{l}.self_attn.rotary_emb.inv_freq", None)
     return state_dict
 
 
-def config_from_checkpoint(checkpoint_path: str, model_name: str) -> PretrainedConfig:
-    """Load a BaiChuanConfig from a checkpoint path."""
-    config = AutoConfig.from_pretrained(
-        Path(checkpoint_path) / model_name, trust_remote_code=True
-    )
-    return config
-
-
-def state_dicts_from_checkpoint(checkpoint_path: str, model_name: str) -> dict:
-    # Need to sort, otherwise we mess up the ordering and the weights are wrong
-    return [
-        torch.load(path, map_location="cpu")
-        for path in sorted(
-            (Path(checkpoint_path) / model_name).glob("pytorch_model*.bin")
-        )
-    ]
-
-
 def baichuan_config_to_gpt2_config(baichuan_config: PretrainedConfig) -> GPT2Config:
+    # HACK: the config doesn't say whether it's rotary or alibi,
+    # so we have to infer from the hidden size (7B -> rotary, 13B -> alibi).
+ use_rotary = baichuan_config.hidden_size < 5000 return GPT2Config( vocab_size=baichuan_config.vocab_size, n_positions=0, # No absolute position embedding @@ -151,8 +136,10 @@ def baichuan_config_to_gpt2_config(baichuan_config: PretrainedConfig) -> GPT2Con # These are new arguments not in the original GPT2Config pad_token_id=baichuan_config.pad_token_id, # Idk if this does anything rms_norm=True, - rotary_emb_fraction=1.0, + rotary_emb_fraction=1.0 if use_rotary else 0.0, rotary_emb_interleaved=False, + use_alibi=not use_rotary, + use_flash_attn=not use_rotary, # Alibi code path requires flash_attn tie_word_embeddings=False, qkv_proj_bias=False, out_proj_bias=False, diff --git a/flash_attn/models/gpt.py b/flash_attn/models/gpt.py index b2403dc..97d555d 100644 --- a/flash_attn/models/gpt.py +++ b/flash_attn/models/gpt.py @@ -85,6 +85,7 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt rotary_emb_base = getattr(config, "rotary_emb_base", 10000.0) rotary_emb_scale_base = getattr(config, "rotary_emb_scale_base", None) rotary_emb_interleaved = getattr(config, "rotary_emb_interleaved", False) + use_alibi = getattr(config, "use_alibi", False) use_flash_attn = getattr(config, "use_flash_attn", False) fused_bias_fc = getattr(config, "fused_bias_fc", False) if not fused_bias_fc: @@ -116,6 +117,7 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt rotary_emb_base=rotary_emb_base, rotary_emb_scale_base=rotary_emb_scale_base, rotary_emb_interleaved=rotary_emb_interleaved, + use_alibi=use_alibi, use_flash_attn=use_flash_attn, **serial_kwargs, **parallel_kwargs, diff --git a/flash_attn/modules/mha.py b/flash_attn/modules/mha.py index 976bd3d..f4c2edc 100644 --- a/flash_attn/modules/mha.py +++ b/flash_attn/modules/mha.py @@ -33,6 +33,23 @@ except ImportError: RotaryEmbedding = None +# From https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742 +def get_alibi_slopes(nheads): + def get_slopes_power_of_2(nheads): + start = 2 ** (-(2 ** -(math.log2(nheads) - 3))) + ratio = start + return [start * ratio**i for i in range(nheads)] + + if math.log2(nheads).is_integer(): + return get_slopes_power_of_2(nheads) + else: + closest_power_of_2 = 2 ** math.floor(math.log2(nheads)) + return ( + get_slopes_power_of_2(closest_power_of_2) + + get_alibi_slopes(2 * closest_power_of_2)[0::2][: nheads - closest_power_of_2] + ) + + class FlashSelfAttention(nn.Module): """Implement the scaled dot product attention with softmax. Arguments @@ -44,13 +61,14 @@ class FlashSelfAttention(nn.Module): (default: 0.0) """ - def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0): + def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, alibi_slopes=None): super().__init__() assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed" assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed" self.causal = causal self.softmax_scale = softmax_scale self.drop = nn.Dropout(attention_dropout) + self.register_buffer("alibi_slopes", alibi_slopes, persistent=False) def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None): """Implements the multihead softmax attention. 
@@ -84,6 +102,7 @@ class FlashSelfAttention(nn.Module): self.drop.p if self.training else 0.0, softmax_scale=self.softmax_scale, causal=causal, + alibi_slopes=self.alibi_slopes, ) else: return flash_attn_qkvpacked_func( @@ -91,6 +110,7 @@ class FlashSelfAttention(nn.Module): self.drop.p if self.training else 0.0, softmax_scale=self.softmax_scale, causal=causal, + alibi_slopes=self.alibi_slopes, ) @@ -105,13 +125,14 @@ class FlashCrossAttention(nn.Module): (default: 0.0) """ - def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0): + def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, alibi_slopes=None): super().__init__() assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed" assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed" self.causal = causal self.softmax_scale = softmax_scale self.drop = nn.Dropout(attention_dropout) + self.register_buffer("alibi_slopes", alibi_slopes, persistent=False) def forward( self, @@ -158,6 +179,7 @@ class FlashCrossAttention(nn.Module): self.drop.p if self.training else 0.0, softmax_scale=self.softmax_scale, causal=causal, + alibi_slopes=self.alibi_slopes, ) else: batch_size, seqlen_q = q.shape[0], q.shape[1] @@ -169,6 +191,7 @@ class FlashCrossAttention(nn.Module): self.drop.p if self.training else 0.0, causal=causal, softmax_scale=self.softmax_scale, + alibi_slopes=self.alibi_slopes, ) @@ -315,8 +338,8 @@ def _update_kv_cache(kv, inference_params, layer_idx): batch_end = batch_start + kv.shape[0] sequence_start = inference_params.seqlen_offset sequence_end = sequence_start + kv.shape[1] - assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0]) - assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2]) + assert batch_end <= kv_cache.shape[0] + assert sequence_end <= kv_cache.shape[1] assert kv_cache is not None kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv return kv_cache[batch_start:batch_end, :sequence_end, ...] 
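For intuition about what the alibi_slopes buffer does once it reaches the flash-attn kernels: the get_alibi_slopes helper added above follows the standard ALiBi recipe, giving each head a geometrically decaying slope, and the kernel then biases the score of query i and key j by a slope-scaled distance. The sketch below is reference-only and not part of the patch (which passes the slopes straight to the fused kernels instead of materializing a bias); the helper name explicit_alibi_bias is made up for illustration, and the only patch symbol it relies on is get_alibi_slopes, assumed importable from flash_attn.modules.mha once the patch is applied.

import torch

from flash_attn.modules.mha import get_alibi_slopes  # added by this patch


def explicit_alibi_bias(nheads: int, seqlen: int) -> torch.Tensor:
    """Reference-only sketch: materialize the causal ALiBi bias that the fused kernel
    applies implicitly when `alibi_slopes` is passed; shape (nheads, seqlen, seqlen)."""
    slopes = torch.tensor(get_alibi_slopes(nheads), dtype=torch.float32)  # (nheads,)
    pos = torch.arange(seqlen)
    # distance[i, j] = i - j for keys at or before query i; future keys are masked by causality
    distance = (pos[:, None] - pos[None, :]).clamp(min=0)
    return -slopes[:, None, None] * distance


# Usage sketch against unfused attention scores of shape (batch, nheads, seqlen, seqlen):
#   scores = scores + explicit_alibi_bias(nheads, seqlen).to(scores.dtype)

Because the bias depends only on relative distance, no positional embedding table is needed, which is consistent with the Baichuan-13B config conversion above setting rotary_emb_fraction=0.0 and n_positions=0 alongside use_alibi=True.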
@@ -342,6 +365,7 @@ class MHA(nn.Module): rotary_emb_base=10000.0, rotary_emb_scale_base=None, rotary_emb_interleaved=False, + use_alibi=False, fused_bias_fc=False, use_flash_attn=False, return_residual=False, @@ -366,6 +390,11 @@ class MHA(nn.Module): self.use_flash_attn = use_flash_attn self.return_residual = return_residual self.checkpointing = checkpointing + if use_alibi: + assert use_flash_attn, "ALiBi code path requires flash_attn" + alibi_slopes = torch.tensor(get_alibi_slopes(num_heads), device=device) + else: + alibi_slopes = None self.num_heads = num_heads self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads @@ -395,8 +424,16 @@ class MHA(nn.Module): LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True) ) wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls - inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention - inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention + inner_attn_cls = ( + partial(FlashSelfAttention, alibi_slopes=alibi_slopes) + if use_flash_attn + else SelfAttention + ) + inner_cross_attn_cls = ( + partial(FlashCrossAttention, alibi_slopes=alibi_slopes) + if use_flash_attn + else CrossAttention + ) if not self.cross_attn: self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs) else: @@ -413,7 +450,9 @@ class MHA(nn.Module): ) self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim) self.inner_attn = inner_attn_cls( - causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout + causal=causal, + softmax_scale=softmax_scale, + attention_dropout=dropout, ) self.inner_cross_attn = inner_cross_attn_cls( causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout @@ -672,6 +711,7 @@ class ParallelMHA(nn.Module): rotary_emb_base=10000.0, rotary_emb_scale_base=None, rotary_emb_interleaved=False, + use_alibi=False, use_flash_attn=False, checkpointing=False, sequence_parallel=True, @@ -707,6 +747,18 @@ class ParallelMHA(nn.Module): self.head_dim = self.embed_dim // num_heads qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv) + if use_alibi: + assert use_flash_attn, "ALiBi code path requires flash_attn" + num_heads_local = math.ceil(self.num_heads / self.world_size) + alibi_slopes = torch.tensor( + get_alibi_slopes(num_heads)[ + self.local_rank * num_heads_local : (self.local_rank + 1) * num_heads_local + ], + device=device, + ) + else: + alibi_slopes = None + if self.rotary_emb_dim > 0: assert RotaryEmbedding is not None, "rotary_emb is not installed" self.rotary_emb = RotaryEmbedding( @@ -728,8 +780,16 @@ class ParallelMHA(nn.Module): multiple_of=self.head_dim * (self.num_heads // self.num_heads_kv + 2), **factory_kwargs, ) - inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention - inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention + inner_attn_cls = ( + partial(FlashSelfAttention, alibi_slopes=alibi_slopes) + if use_flash_attn + else SelfAttention + ) + inner_cross_attn_cls = ( + partial(FlashCrossAttention, alibi_slopes=alibi_slopes) + if use_flash_attn + else CrossAttention + ) self.inner_attn = inner_attn_cls( causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout ) diff --git a/flash_attn/modules/mlp.py b/flash_attn/modules/mlp.py index 8a65b22..23584d3 100644 --- a/flash_attn/modules/mlp.py +++ b/flash_attn/modules/mlp.py @@ -105,7 +105,7 @@ class GatedMlp(nn.Module): activation=F.sigmoid, bias1=True, 
bias2=True, - multiple_of=256, + multiple_of=128, return_residual=False, device=None, dtype=None, @@ -148,7 +148,7 @@ class ParallelGatedMlp(nn.Module): activation=F.sigmoid, bias1=True, bias2=True, - multiple_of=256, + multiple_of=128, sequence_parallel=True, device=None, dtype=None, diff --git a/tests/models/test_baichuan.py b/tests/models/test_baichuan.py index c2cc2ec..4f04c2c 100644 --- a/tests/models/test_baichuan.py +++ b/tests/models/test_baichuan.py @@ -2,8 +2,6 @@ import os import time from pathlib import Path -current_dir = Path(__file__).parent.absolute() - import torch import pytest @@ -20,16 +18,12 @@ from flash_attn.models.baichuan import ( remap_state_dict_hf_baichuan, baichuan_config_to_gpt2_config, ) -from flash_attn.models.baichuan import ( - config_from_checkpoint, - state_dicts_from_checkpoint, -) from flash_attn.utils.distributed import all_gather_raw from flash_attn.utils.pretrained import state_dict_from_pretrained from flash_attn.utils.generation import update_graph_cache -@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B"]) +@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B", "baichuan-inc/Baichuan-13B-Base"]) def test_baichuan_state_dict(model_name): config = baichuan_config_to_gpt2_config( AutoConfig.from_pretrained(model_name, trust_remote_code=True) @@ -45,7 +39,7 @@ def test_baichuan_state_dict(model_name): assert state_dict[k].shape == pretrained_state_dict[k].shape -@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B"]) +@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B", "baichuan-inc/Baichuan-13B-Base"]) def test_baichuan_optimized(model_name): """Check that our implementation of Baichuan (with all optimizations enabled) matches the HF implementation: the output of our forward pass in fp16 should be around the same as the HF @@ -122,7 +116,7 @@ def test_baichuan_optimized(model_name): # torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_baichuan.py -k "test_baichuan_parallel_forward" @pytest.mark.parametrize("world_size", [2]) -@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B"]) +@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B", "baichuan-inc/Baichuan-13B-Base"]) def test_baichuan_parallel_forward(model_name, world_size): """Check that our implementation of Baichuan (with all optimizations enabled) matches the HF implementation: the output of our forward pass in fp16 should be around the same as the HF @@ -217,7 +211,7 @@ def test_baichuan_parallel_forward(model_name, world_size): ).abs().max().item() -@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B"]) +@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B", "baichuan-inc/Baichuan-13B-Base"]) def test_baichuan_generation(model_name): dtype = torch.float16 device = "cuda"
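A closing note on the tensor-parallel path exercised by test_baichuan_parallel_forward: ParallelMHA above keeps only math.ceil(num_heads / world_size) slopes per rank, sliced by local_rank, so each rank biases exactly the heads it owns. Below is a minimal standalone sanity check of that slicing invariant, using assumed illustrative values (40 heads, as in Baichuan-13B, and world_size=2 to match the torchrun command in the test file); it mirrors the patch's indexing but is not part of the patch itself.

import math

from flash_attn.modules.mha import get_alibi_slopes  # added by this patch

# Assumed values for illustration: Baichuan-13B uses 40 attention heads,
# and the parallel test above runs with --nproc_per_node=2.
num_heads, world_size = 40, 2

full_slopes = get_alibi_slopes(num_heads)
num_heads_local = math.ceil(num_heads / world_size)
per_rank = [
    full_slopes[rank * num_heads_local : (rank + 1) * num_heads_local]
    for rank in range(world_size)
]

# Concatenating the per-rank slices must reproduce the full slope list,
# so every head keeps the same bias it would get in the unsharded model.
assert sum(per_rank, []) == full_slopes
assert all(len(s) <= num_heads_local for s in per_rank)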