Pass alibi slopes to flash_attn_with_kvcache during generation
parent f844852485
commit 3f7d5786ba
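In short: each generation path now reads the `alibi_slopes` buffer from the inner cross-attention module (it is `None` when the model was built without ALiBi) and forwards it to `flash_attn_with_kvcache`, so the ALiBi bias is also applied when decoding against the KV cache. A minimal sketch of the pattern; `decode_step` and its tensor arguments are illustrative stand-ins, not code from this repository:

```python
import torch
from flash_attn import flash_attn_with_kvcache


def decode_step(q, k_cache, v_cache, cache_seqlens, inner_cross_attn):
    """One decoding step, mirroring the pattern this commit applies.

    q:               (batch, 1, nheads, headdim) query for the new token(s)
    k_cache/v_cache: (batch, seqlen_max, nheads_k, headdim) KV cache tensors
    cache_seqlens:   (batch,) int32 number of tokens already in the cache
    """
    # None when the attention module was built without ALiBi, in which case
    # the kernel simply skips the bias.
    alibi_slopes = getattr(inner_cross_attn, "alibi_slopes", None)
    return flash_attn_with_kvcache(
        q,
        k_cache,
        v_cache,
        cache_seqlens=cache_seqlens,
        softmax_scale=inner_cross_attn.softmax_scale,
        causal=inner_cross_attn.causal,
        alibi_slopes=alibi_slopes,
    )
```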
@@ -1,4 +1,4 @@
-# Copyright (c) 2023, GGGGGGXY.
+# Copyright (c) 2023, GGGGGGXY, Tri Dao.
 
 import math
 import json
@@ -14,7 +14,6 @@ from einops import rearrange
 from transformers import GPT2Config, AutoConfig, PretrainedConfig
 
-
 # only support Baichuan-7B now
 def remap_state_dict_hf_baichuan(state_dict, config):
     def key_mapping_layers(key):
         return re.sub(r"^model.", "transformer.", key)
@@ -501,6 +501,7 @@ class MHA(nn.Module):
             if inference_params.lengths_per_sample is not None
             else inference_params.seqlen_offset
         )
+        alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
         context = flash_attn_with_kvcache(
             q,
             kv_cache[:, :, 0],
@@ -513,6 +514,7 @@ class MHA(nn.Module):
             softmax_scale=self.inner_cross_attn.softmax_scale,
             causal=self.inner_cross_attn.causal,
             rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
+            alibi_slopes=alibi_slopes,
         )
         return context
 
@@ -534,6 +536,7 @@ class MHA(nn.Module):
             if inference_params.lengths_per_sample is not None
             else inference_params.seqlen_offset
         )
+        alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
         return flash_attn_with_kvcache(
             q,
             kv_cache[:, :, 0],
@@ -543,6 +546,7 @@ class MHA(nn.Module):
             cache_seqlens=cache_seqlens,
             softmax_scale=self.inner_cross_attn.softmax_scale,
             causal=self.inner_cross_attn.causal,
+            alibi_slopes=alibi_slopes,
         )
 
     def forward(
@@ -847,6 +851,7 @@ class ParallelMHA(nn.Module):
             if inference_params.lengths_per_sample is not None
             else inference_params.seqlen_offset
         )
+        alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
         context = flash_attn_with_kvcache(
             q,
             kv_cache[:, :, 0],
@@ -859,6 +864,7 @@ class ParallelMHA(nn.Module):
             softmax_scale=self.inner_cross_attn.softmax_scale,
             causal=self.inner_cross_attn.causal,
             rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
+            alibi_slopes=alibi_slopes,
         )
         return context
 
@@ -876,6 +882,7 @@ class ParallelMHA(nn.Module):
             if inference_params.lengths_per_sample is not None
             else inference_params.seqlen_offset
         )
+        alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
         context = flash_attn_with_kvcache(
             q,
             kv_cache[:, :, 0],
@@ -885,6 +892,7 @@ class ParallelMHA(nn.Module):
             cache_seqlens=cache_seqlens,
             softmax_scale=self.inner_cross_attn.softmax_scale,
             causal=self.inner_cross_attn.causal,
+            alibi_slopes=alibi_slopes,
         )
         return context
 
@@ -1,3 +1,4 @@
 # Copyright (c) 2023, Tri Dao.
 import os
 import time
 from pathlib import Path
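For reference (not part of this diff): `alibi_slopes` holds one slope per attention head. Below is a minimal sketch of the standard slope schedule from the ALiBi paper, assuming only the head count; how the modules in this repository actually build and register the tensor is outside the scope of this commit:

```python
import math

import torch


def alibi_slopes_for(nheads: int) -> torch.Tensor:
    """Standard ALiBi slope schedule: one slope per head, kept in float32."""

    def power_of_two_slopes(n):
        # Geometric sequence 2^(-8/n), 2^(-16/n), ... for n a power of two.
        start = 2.0 ** (-(2.0 ** -(math.log2(n) - 3)))
        return [start * (start ** i) for i in range(n)]

    if math.log2(nheads).is_integer():
        slopes = power_of_two_slopes(nheads)
    else:
        # Non-power-of-two head counts interpolate over the next power of two,
        # as in the original ALiBi reference implementation.
        closest = 2 ** math.floor(math.log2(nheads))
        slopes = (
            power_of_two_slopes(closest)
            + power_of_two_slopes(2 * closest)[0::2][: nheads - closest]
        )
    return torch.tensor(slopes, dtype=torch.float32)
```

With slopes like these stored on the attention module (e.g. as its `alibi_slopes` attribute), the `getattr` calls added above pick them up automatically during generation.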