From 6428f1d051d731a9eca192950591e7a5a2788cb2 Mon Sep 17 00:00:00 2001
From: Megha Agarwal <16129366+megha95@users.noreply.github.com>
Date: Tue, 12 Dec 2023 10:16:05 -0800
Subject: [PATCH] Support MPT with GQA (#1938)

Co-authored-by: Woosuk Kwon
---
 vllm/model_executor/layers/attention.py | 12 ++++++++----
 vllm/model_executor/models/mpt.py       | 22 ++++++++++++++++++++--
 2 files changed, 28 insertions(+), 6 deletions(-)

diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py
index d0f0f28c..dd94dda4 100644
--- a/vllm/model_executor/layers/attention.py
+++ b/vllm/model_executor/layers/attention.py
@@ -138,7 +138,8 @@ class PagedAttention(nn.Module):
                     input_metadata.attn_bias = attn_bias
                 else:
                     input_metadata.attn_bias = _make_alibi_bias(
-                        self.alibi_slopes, batch_size, seq_len, query.dtype)
+                        self.alibi_slopes, self.num_kv_heads, batch_size,
+                        seq_len, query.dtype)
 
             # TODO(woosuk): Too many view operations. Let's try to reduce them
             # in the future for code readability.
@@ -180,31 +181,34 @@ class PagedAttention(nn.Module):
 
 def _make_alibi_bias(
     alibi_slopes: torch.Tensor,
+    num_kv_heads: int,
     batch_size: int,
     seq_len: int,
     dtype: torch.dtype,
 ) -> LowerTriangularMaskWithTensorBias:
-    bias = torch.arange(seq_len, dtype=dtype)
+    bias = torch.arange(seq_len, dtype=dtype, device="cuda")
     # NOTE(zhuohan): HF uses
     #     `bias = bias[None, :].repeat(prompt_len, 1)`
     # here. We find that both biases give the same results, but
     # the bias below more accurately follows the original ALiBi
     # paper.
     bias = bias[None, :] - bias[:, None]
-    bias = bias.to(alibi_slopes.device)
 
     # When using custom attention bias, xformers requires the bias to
     # be sliced from a tensor whose length is a multiple of 8.
     padded_len = (seq_len + 7) // 8 * 8
+    num_heads = alibi_slopes.shape[0]
     bias = torch.empty(
         batch_size,
-        alibi_slopes.shape[0],
+        num_heads,
         seq_len,
         padded_len,
         device=alibi_slopes.device,
         dtype=dtype,
     )[:, :, :, :seq_len].copy_(bias)
     bias.mul_(alibi_slopes[:, None, None])
+    if num_heads != num_kv_heads:
+        bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
     attn_bias = LowerTriangularMaskWithTensorBias(bias)
     return attn_bias
 
diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py
index c7be7a92..22b2a5c6 100644
--- a/vllm/model_executor/models/mpt.py
+++ b/vllm/model_executor/models/mpt.py
@@ -50,9 +50,14 @@ class MPTAttention(nn.Module):
         super().__init__()
         self.d_model = config.d_model
         self.total_num_heads = config.n_heads
+        self.head_dim = self.d_model // self.total_num_heads
         self.clip_qkv = config.attn_config["clip_qkv"]
         self.qk_ln = config.attn_config["qk_ln"]
         self.alibi_bias_max = config.attn_config["alibi_bias_max"]
+        if "kv_n_heads" in config.attn_config:
+            self.total_num_kv_heads = config.attn_config['kv_n_heads']
+        else:
+            self.total_num_kv_heads = self.total_num_heads
         assert not config.attn_config["prefix_lm"]
         assert config.attn_config["alibi"]
 
@@ -61,6 +66,7 @@ class MPTAttention(nn.Module):
             self.d_model,
             self.d_model // self.total_num_heads,
             self.total_num_heads,
+            self.total_num_kv_heads,
             bias=not config.no_bias,
             linear_method=linear_method,
         )
@@ -78,6 +84,17 @@ class MPTAttention(nn.Module):
         assert self.total_num_heads % tp_world_size == 0
         self.num_heads = self.total_num_heads // tp_world_size
 
+        if self.total_num_kv_heads >= tp_world_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_world_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_world_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
         # Create the alibi slopes and slice them.
         tp_rank = get_tensor_model_parallel_rank()
         head_start = tp_rank * self.num_heads
@@ -91,7 +108,8 @@ class MPTAttention(nn.Module):
         self.attn = PagedAttention(self.num_heads,
                                    self.head_dim,
                                    scaling,
-                                   alibi_slopes=alibi_slopes)
+                                   alibi_slopes=alibi_slopes,
+                                   num_kv_heads=self.num_kv_heads)
 
     def forward(
         self,
@@ -105,7 +123,7 @@ class MPTAttention(nn.Module):
         qkv, _ = self.Wqkv(hidden_states)
         if self.clip_qkv is not None:
             qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
-        q, k, v = qkv.chunk(chunks=3, dim=-1)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         if self.qk_ln:
             q = self.q_ln(q)
             k = self.k_ln(k)
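
Reviewer note (not part of the patch): the mpt.py changes amount to per-rank head accounting for grouped-query attention under tensor parallelism. Query heads are always partitioned across ranks; KV heads are partitioned when there are at least as many KV heads as ranks, and replicated otherwise, which is also why the fused QKV output is no longer three equal chunks. The sketch below mirrors that arithmetic in isolation; the helper name and the example config numbers are illustrative assumptions, not vLLM code.

    import torch

    def shard_heads(total_num_heads: int, total_num_kv_heads: int,
                    tp_world_size: int, head_dim: int):
        """Per-rank head counts and fused-QKV slice sizes, following the logic
        added to MPTAttention.__init__ (standalone illustration)."""
        assert total_num_heads % tp_world_size == 0
        num_heads = total_num_heads // tp_world_size
        if total_num_kv_heads >= tp_world_size:
            # Enough KV heads: partition them across the TP ranks.
            assert total_num_kv_heads % tp_world_size == 0
        else:
            # Fewer KV heads than ranks: each rank keeps a replicated copy.
            assert tp_world_size % total_num_kv_heads == 0
        num_kv_heads = max(1, total_num_kv_heads // tp_world_size)
        return num_heads, num_kv_heads, num_heads * head_dim, num_kv_heads * head_dim

    # Hypothetical GQA config: 64 query heads, 8 KV heads, head_dim 128, 4-way TP.
    num_heads, num_kv_heads, q_size, kv_size = shard_heads(64, 8, 4, 128)

    # The per-rank fused QKV projection output has width q_size + 2 * kv_size,
    # so the patch replaces qkv.chunk(3, dim=-1) with an explicit split.
    qkv = torch.randn(2, 16, q_size + 2 * kv_size)
    q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
    print(q.shape, k.shape, v.shape)  # (2, 16, 2048) (2, 16, 256) (2, 16, 256)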
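
Reviewer note (not part of the patch): the attention.py change reshapes the ALiBi bias so its head dimension matches the grouped (num_kv_heads, num_queries_per_kv) layout used when the query and KV head counts differ. A minimal CPU-only sketch of that shape transformation follows; all sizes are assumptions chosen for demonstration, and the dense bias construction here stands in for the padded-buffer version in _make_alibi_bias.

    import torch

    batch_size, num_heads, num_kv_heads, seq_len = 2, 8, 2, 16
    alibi_slopes = torch.rand(num_heads)

    # Relative-position ALiBi bias, scaled per head (as in _make_alibi_bias).
    # clone() materializes the expanded view so the in-place mul_ is legal.
    pos = torch.arange(seq_len, dtype=torch.float32)
    bias = (pos[None, :] - pos[:, None]).expand(
        batch_size, num_heads, seq_len, seq_len).clone()
    bias.mul_(alibi_slopes[:, None, None])

    # For GQA, group query heads under their shared KV head so the bias becomes
    # 5-D: (batch, num_kv_heads, num_queries_per_kv, seq_len, seq_len).
    if num_heads != num_kv_heads:
        bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
    print(bias.shape)  # torch.Size([2, 2, 4, 16, 16])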