diff --git a/flash_attn/models/baichuan.py b/flash_attn/models/baichuan.py
index be7320b..97d0307 100644
--- a/flash_attn/models/baichuan.py
+++ b/flash_attn/models/baichuan.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2023, GGGGGGXY.
+# Copyright (c) 2023, GGGGGGXY, Tri Dao.
 
 import math
 import json
@@ -14,7 +14,6 @@ from einops import rearrange
 from transformers import GPT2Config, AutoConfig, PretrainedConfig
 
 
-# only support Baichuan-7B now
 def remap_state_dict_hf_baichuan(state_dict, config):
     def key_mapping_layers(key):
         return re.sub(r"^model.", "transformer.", key)
diff --git a/flash_attn/modules/mha.py b/flash_attn/modules/mha.py
index f4c2edc..16c245c 100644
--- a/flash_attn/modules/mha.py
+++ b/flash_attn/modules/mha.py
@@ -501,6 +501,7 @@ class MHA(nn.Module):
             if inference_params.lengths_per_sample is not None
             else inference_params.seqlen_offset
         )
+        alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
         context = flash_attn_with_kvcache(
             q,
             kv_cache[:, :, 0],
@@ -513,6 +514,7 @@ class MHA(nn.Module):
             softmax_scale=self.inner_cross_attn.softmax_scale,
             causal=self.inner_cross_attn.causal,
             rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
+            alibi_slopes=alibi_slopes,
         )
         return context
 
@@ -534,6 +536,7 @@ class MHA(nn.Module):
             if inference_params.lengths_per_sample is not None
             else inference_params.seqlen_offset
         )
+        alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
         return flash_attn_with_kvcache(
             q,
             kv_cache[:, :, 0],
@@ -543,6 +546,7 @@ class MHA(nn.Module):
             cache_seqlens=cache_seqlens,
             softmax_scale=self.inner_cross_attn.softmax_scale,
             causal=self.inner_cross_attn.causal,
+            alibi_slopes=alibi_slopes,
         )
 
     def forward(
@@ -847,6 +851,7 @@ class ParallelMHA(nn.Module):
             if inference_params.lengths_per_sample is not None
             else inference_params.seqlen_offset
         )
+        alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
         context = flash_attn_with_kvcache(
             q,
             kv_cache[:, :, 0],
@@ -859,6 +864,7 @@ class ParallelMHA(nn.Module):
             softmax_scale=self.inner_cross_attn.softmax_scale,
             causal=self.inner_cross_attn.causal,
             rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
+            alibi_slopes=alibi_slopes,
         )
         return context
 
@@ -876,6 +882,7 @@ class ParallelMHA(nn.Module):
             if inference_params.lengths_per_sample is not None
             else inference_params.seqlen_offset
         )
+        alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
         context = flash_attn_with_kvcache(
             q,
             kv_cache[:, :, 0],
@@ -885,6 +892,7 @@ class ParallelMHA(nn.Module):
             cache_seqlens=cache_seqlens,
             softmax_scale=self.inner_cross_attn.softmax_scale,
             causal=self.inner_cross_attn.causal,
+            alibi_slopes=alibi_slopes,
        )
        return context
 
diff --git a/tests/models/test_baichuan.py b/tests/models/test_baichuan.py
index 1fc550a..1d2964b 100644
--- a/tests/models/test_baichuan.py
+++ b/tests/models/test_baichuan.py
@@ -1,3 +1,4 @@
+# Copyright (c) 2023, Tri Dao.
 import os
 import time
 from pathlib import Path
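
All of the mha.py hunks make the same change: the decode-time fast paths that call flash_attn_with_kvcache now read the ALiBi slopes off the inner cross-attention module (None when the model does not use ALiBi) and forward them to the kernel, so ALiBi models (e.g. Baichuan-13B) keep their positional bias during KV-cache inference. Below is a minimal standalone sketch of that call pattern, not code from this diff; the shapes, the fp16/CUDA requirement, and the standard ALiBi slope formula 2^(-8i/n) are assumptions taken from the flash-attn docs and the ALiBi paper.

# Sketch: passing ALiBi slopes through flash_attn_with_kvcache during decoding.
# Requires a CUDA device and fp16/bf16 tensors; slopes are fp32, one per head.
import torch
from flash_attn import flash_attn_with_kvcache

batch, seqlen_cache, nheads, headdim = 2, 128, 8, 64
device, dtype = "cuda", torch.float16

# Standard ALiBi slopes for a power-of-two head count: 2^(-8*1/n), ..., 2^(-8*n/n).
slopes = torch.tensor(
    [2 ** (-8 * (i + 1) / nheads) for i in range(nheads)],
    device=device, dtype=torch.float32,
)

q = torch.randn(batch, 1, nheads, headdim, device=device, dtype=dtype)  # one new token
k_cache = torch.randn(batch, seqlen_cache, nheads, headdim, device=device, dtype=dtype)
v_cache = torch.randn(batch, seqlen_cache, nheads, headdim, device=device, dtype=dtype)
k_new = torch.randn(batch, 1, nheads, headdim, device=device, dtype=dtype)
v_new = torch.randn(batch, 1, nheads, headdim, device=device, dtype=dtype)
cache_seqlens = torch.full((batch,), 100, dtype=torch.int32, device=device)  # tokens already cached

out = flash_attn_with_kvcache(
    q, k_cache, v_cache, k_new, v_new,
    cache_seqlens=cache_seqlens,
    causal=True,
    alibi_slopes=slopes,  # the keyword the diff threads through MHA / ParallelMHA
)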