Support MPT with GQA (#1938)

Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Megha Agarwal 2023-12-12 10:16:05 -08:00 committed by GitHub
parent 7e1b21daac
commit 6428f1d051
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 28 additions and 6 deletions

View File

@ -138,7 +138,8 @@ class PagedAttention(nn.Module):
input_metadata.attn_bias = attn_bias
else:
input_metadata.attn_bias = _make_alibi_bias(
self.alibi_slopes, batch_size, seq_len, query.dtype)
self.alibi_slopes, self.num_kv_heads, batch_size,
seq_len, query.dtype)
# TODO(woosuk): Too many view operations. Let's try to reduce them
# in the future for code readability.
@ -180,31 +181,34 @@ class PagedAttention(nn.Module):
def _make_alibi_bias(
alibi_slopes: torch.Tensor,
num_kv_heads: int,
batch_size: int,
seq_len: int,
dtype: torch.dtype,
) -> LowerTriangularMaskWithTensorBias:
bias = torch.arange(seq_len, dtype=dtype)
bias = torch.arange(seq_len, dtype=dtype, device="cuda")
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(prompt_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
bias = bias[None, :] - bias[:, None]
bias = bias.to(alibi_slopes.device)
# When using custom attention bias, xformers requires the bias to
# be sliced from a tensor whose length is a multiple of 8.
padded_len = (seq_len + 7) // 8 * 8
num_heads = alibi_slopes.shape[0]
bias = torch.empty(
batch_size,
alibi_slopes.shape[0],
num_heads,
seq_len,
padded_len,
device=alibi_slopes.device,
dtype=dtype,
)[:, :, :, :seq_len].copy_(bias)
bias.mul_(alibi_slopes[:, None, None])
if num_heads != num_kv_heads:
bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
attn_bias = LowerTriangularMaskWithTensorBias(bias)
return attn_bias

View File

@ -50,9 +50,14 @@ class MPTAttention(nn.Module):
super().__init__()
self.d_model = config.d_model
self.total_num_heads = config.n_heads
self.head_dim = self.d_model // self.total_num_heads
self.clip_qkv = config.attn_config["clip_qkv"]
self.qk_ln = config.attn_config["qk_ln"]
self.alibi_bias_max = config.attn_config["alibi_bias_max"]
if "kv_n_heads" in config.attn_config:
self.total_num_kv_heads = config.attn_config['kv_n_heads']
else:
self.total_num_kv_heads = self.total_num_heads
assert not config.attn_config["prefix_lm"]
assert config.attn_config["alibi"]
@ -61,6 +66,7 @@ class MPTAttention(nn.Module):
self.d_model,
self.d_model // self.total_num_heads,
self.total_num_heads,
self.total_num_kv_heads,
bias=not config.no_bias,
linear_method=linear_method,
)
@ -78,6 +84,17 @@ class MPTAttention(nn.Module):
assert self.total_num_heads % tp_world_size == 0
self.num_heads = self.total_num_heads // tp_world_size
if self.total_num_kv_heads >= tp_world_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_world_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_world_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
# Create the alibi slopes and slice them.
tp_rank = get_tensor_model_parallel_rank()
head_start = tp_rank * self.num_heads
@ -91,7 +108,8 @@ class MPTAttention(nn.Module):
self.attn = PagedAttention(self.num_heads,
self.head_dim,
scaling,
alibi_slopes=alibi_slopes)
alibi_slopes=alibi_slopes,
num_kv_heads=self.num_kv_heads)
def forward(
self,
@ -105,7 +123,7 @@ class MPTAttention(nn.Module):
qkv, _ = self.Wqkv(hidden_states)
if self.clip_qkv is not None:
qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if self.qk_ln:
q = self.q_ln(q)
k = self.k_ln(k)