Pass alibi slopes to flash_attn_with_kvcache during generation

Tri Dao 2023-12-24 20:31:59 -08:00
parent f844852485
commit 3f7d5786ba
3 changed files with 10 additions and 2 deletions
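
The diff below applies the same two-line pattern to all four decode paths in mha.py (MHA and ParallelMHA, each with and without rotary embeddings). Condensed, and paraphrasing rather than quoting the diff: the slopes are looked up on the inner attention module with getattr, so a model built without ALiBi falls back to None and flash_attn_with_kvcache skips the bias.

# Condensed sketch of the pattern added in each generation path (a paraphrase,
# not the verbatim diff): `inner_cross_attn` only carries `alibi_slopes` when
# the model was constructed with ALiBi, so getattr falls back to None and the
# kernel then applies no positional bias.
alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
context = flash_attn_with_kvcache(
    q,
    kv_cache[:, :, 0],  # cached keys
    kv_cache[:, :, 1],  # cached values
    cache_seqlens=cache_seqlens,
    softmax_scale=self.inner_cross_attn.softmax_scale,
    causal=self.inner_cross_attn.causal,
    alibi_slopes=alibi_slopes,  # the argument this commit threads through
)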

View File

@@ -1,4 +1,4 @@
-# Copyright (c) 2023, GGGGGGXY.
+# Copyright (c) 2023, GGGGGGXY, Tri Dao.
 
 import math
 import json
@@ -14,7 +14,6 @@ from einops import rearrange
 from transformers import GPT2Config, AutoConfig, PretrainedConfig
 
 
-# only support Baichuan-7B now
 def remap_state_dict_hf_baichuan(state_dict, config):
     def key_mapping_layers(key):
         return re.sub(r"^model.", "transformer.", key)

View File

@@ -501,6 +501,7 @@ class MHA(nn.Module):
             if inference_params.lengths_per_sample is not None
             else inference_params.seqlen_offset
         )
+        alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
         context = flash_attn_with_kvcache(
             q,
             kv_cache[:, :, 0],
@@ -513,6 +514,7 @@ class MHA(nn.Module):
             softmax_scale=self.inner_cross_attn.softmax_scale,
             causal=self.inner_cross_attn.causal,
             rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
+            alibi_slopes=alibi_slopes,
         )
         return context
 
@@ -534,6 +536,7 @@ class MHA(nn.Module):
                 if inference_params.lengths_per_sample is not None
                 else inference_params.seqlen_offset
             )
+            alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
             return flash_attn_with_kvcache(
                 q,
                 kv_cache[:, :, 0],
@@ -543,6 +546,7 @@ class MHA(nn.Module):
                 cache_seqlens=cache_seqlens,
                 softmax_scale=self.inner_cross_attn.softmax_scale,
                 causal=self.inner_cross_attn.causal,
+                alibi_slopes=alibi_slopes,
             )
 
     def forward(
@@ -847,6 +851,7 @@ class ParallelMHA(nn.Module):
             if inference_params.lengths_per_sample is not None
             else inference_params.seqlen_offset
         )
+        alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
         context = flash_attn_with_kvcache(
             q,
             kv_cache[:, :, 0],
@@ -859,6 +864,7 @@ class ParallelMHA(nn.Module):
             softmax_scale=self.inner_cross_attn.softmax_scale,
             causal=self.inner_cross_attn.causal,
             rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
+            alibi_slopes=alibi_slopes,
         )
         return context
 
@@ -876,6 +882,7 @@ class ParallelMHA(nn.Module):
                 if inference_params.lengths_per_sample is not None
                 else inference_params.seqlen_offset
             )
+            alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
             context = flash_attn_with_kvcache(
                 q,
                 kv_cache[:, :, 0],
@@ -885,6 +892,7 @@ class ParallelMHA(nn.Module):
                 cache_seqlens=cache_seqlens,
                 softmax_scale=self.inner_cross_attn.softmax_scale,
                 causal=self.inner_cross_attn.causal,
+                alibi_slopes=alibi_slopes,
             )
             return context
 

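For callers, the upshot is that ALiBi models now get their positional bias applied inside the fused kernel during incremental decoding; previously the slopes were simply not forwarded on these paths. A minimal decode-step sketch, assuming a flash-attn build whose flash_attn_with_kvcache accepts alibi_slopes (as this diff requires); the shapes, the cache fill level, and the standard geometric slope formula are illustrative:

import torch
from flash_attn import flash_attn_with_kvcache

def alibi_slopes_pow2(nheads: int) -> torch.Tensor:
    # Standard ALiBi slopes for a power-of-two head count: head i (0-indexed)
    # gets 2 ** (-8 * (i + 1) / nheads). The kernel expects fp32.
    return torch.tensor(
        [2.0 ** (-8.0 * (i + 1) / nheads) for i in range(nheads)],
        dtype=torch.float32,
        device="cuda",
    )

batch, nheads, headdim, max_seqlen = 2, 8, 64, 256
dtype, device = torch.float16, "cuda"

# Pre-allocated KV cache; pretend 17 tokens are already written per sequence.
k_cache = torch.zeros(batch, max_seqlen, nheads, headdim, dtype=dtype, device=device)
v_cache = torch.zeros_like(k_cache)
cache_seqlens = torch.full((batch,), 17, dtype=torch.int32, device=device)

# One decoding step: a single new query token plus its new key/value pair.
q = torch.randn(batch, 1, nheads, headdim, dtype=dtype, device=device)
k_new = torch.randn(batch, 1, nheads, headdim, dtype=dtype, device=device)
v_new = torch.randn_like(k_new)

# The kernel writes k_new/v_new into the cache at offset cache_seqlens and
# applies the ALiBi bias internally, matching what MHA now does above.
out = flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=k_new,
    v=v_new,
    cache_seqlens=cache_seqlens,
    causal=True,
    alibi_slopes=alibi_slopes_pow2(nheads),
)
print(out.shape)  # torch.Size([2, 1, 8, 64])

flash-attn also accepts a (batch, nheads) fp32 tensor for per-sequence slopes; passing None, as non-ALiBi models do via the getattr default, disables the bias entirely.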
View File

@@ -1,3 +1,4 @@
+# Copyright (c) 2023, Tri Dao.
 import os
 import time
 from pathlib import Path