Support MPT with GQA (#1938)
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
parent 7e1b21daac
commit 6428f1d051
@@ -138,7 +138,8 @@ class PagedAttention(nn.Module):
                 input_metadata.attn_bias = attn_bias
             else:
                 input_metadata.attn_bias = _make_alibi_bias(
-                    self.alibi_slopes, batch_size, seq_len, query.dtype)
+                    self.alibi_slopes, self.num_kv_heads, batch_size,
+                    seq_len, query.dtype)
 
         # TODO(woosuk): Too many view operations. Let's try to reduce them
         # in the future for code readability.
@@ -180,31 +181,34 @@ class PagedAttention(nn.Module):
 
 def _make_alibi_bias(
     alibi_slopes: torch.Tensor,
+    num_kv_heads: int,
     batch_size: int,
     seq_len: int,
     dtype: torch.dtype,
 ) -> LowerTriangularMaskWithTensorBias:
-    bias = torch.arange(seq_len, dtype=dtype)
+    bias = torch.arange(seq_len, dtype=dtype, device="cuda")
     # NOTE(zhuohan): HF uses
     #     `bias = bias[None, :].repeat(prompt_len, 1)`
     # here. We find that both biases give the same results, but
     # the bias below more accurately follows the original ALiBi
     # paper.
     bias = bias[None, :] - bias[:, None]
-    bias = bias.to(alibi_slopes.device)
 
     # When using custom attention bias, xformers requires the bias to
     # be sliced from a tensor whose length is a multiple of 8.
     padded_len = (seq_len + 7) // 8 * 8
+    num_heads = alibi_slopes.shape[0]
     bias = torch.empty(
         batch_size,
-        alibi_slopes.shape[0],
+        num_heads,
         seq_len,
         padded_len,
         device=alibi_slopes.device,
         dtype=dtype,
     )[:, :, :, :seq_len].copy_(bias)
     bias.mul_(alibi_slopes[:, None, None])
+    if num_heads != num_kv_heads:
+        bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
     attn_bias = LowerTriangularMaskWithTensorBias(bias)
     return attn_bias
 
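For reference, a minimal standalone sketch of the shape bookkeeping in the new _make_alibi_bias: with GQA, the per-head ALiBi bias is unflattened so its head axis becomes (kv_heads, query_heads_per_kv_head), which is the 5-D layout xformers expects for grouped queries. This is not the committed code; sizes are illustrative and the tensors stay on CPU for simplicity.

# Sketch only: toy sizes, CPU tensors, random slopes.
import torch

batch_size, seq_len = 2, 5
num_heads, num_kv_heads = 8, 2           # assumed GQA layout: 4 query heads per KV head
padded_len = (seq_len + 7) // 8 * 8      # xformers wants the last dim padded to a multiple of 8

alibi_slopes = torch.rand(num_heads)
bias = torch.arange(seq_len, dtype=torch.float32)
bias = bias[None, :] - bias[:, None]     # relative-position matrix, shape [seq, seq]

bias = torch.empty(
    batch_size, num_heads, seq_len, padded_len
)[:, :, :, :seq_len].copy_(bias)
bias.mul_(alibi_slopes[:, None, None])   # scale each head by its ALiBi slope

if num_heads != num_kv_heads:
    # [batch, kv_heads, heads_per_kv, seq, seq]: the grouped-query layout.
    bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))

print(bias.shape)  # torch.Size([2, 2, 4, 5, 5])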
@@ -50,9 +50,14 @@ class MPTAttention(nn.Module):
         super().__init__()
         self.d_model = config.d_model
         self.total_num_heads = config.n_heads
+        self.head_dim = self.d_model // self.total_num_heads
         self.clip_qkv = config.attn_config["clip_qkv"]
         self.qk_ln = config.attn_config["qk_ln"]
         self.alibi_bias_max = config.attn_config["alibi_bias_max"]
+        if "kv_n_heads" in config.attn_config:
+            self.total_num_kv_heads = config.attn_config['kv_n_heads']
+        else:
+            self.total_num_kv_heads = self.total_num_heads
         assert not config.attn_config["prefix_lm"]
         assert config.attn_config["alibi"]
 
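A toy illustration of the kv_n_heads fallback introduced above: GQA checkpoints expose "kv_n_heads" in attn_config, while older MHA checkpoints omit it and keep one KV pair per head. resolve_kv_heads is a hypothetical helper for illustration, not part of the commit.

# Sketch only: toy attn_config dicts, not real checkpoint values.
def resolve_kv_heads(attn_config: dict, n_heads: int) -> int:
    # GQA checkpoints carry an explicit KV head count.
    if "kv_n_heads" in attn_config:
        return attn_config["kv_n_heads"]
    # Plain MHA: every query head has its own KV head.
    return n_heads

print(resolve_kv_heads({"kv_n_heads": 8}, n_heads=32))  # 8  (GQA)
print(resolve_kv_heads({}, n_heads=32))                 # 32 (MHA)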
@@ -61,6 +66,7 @@ class MPTAttention(nn.Module):
             self.d_model,
             self.d_model // self.total_num_heads,
             self.total_num_heads,
+            self.total_num_kv_heads,
             bias=not config.no_bias,
             linear_method=linear_method,
         )
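Rough arithmetic for why the fused QKV projection now needs total_num_kv_heads: with GQA its output width is (n_heads + 2 * kv_n_heads) * head_dim rather than 3 * d_model. The sizes below are made up for illustration.

# Sketch only: assumed model sizes.
d_model, n_heads, kv_n_heads = 4096, 32, 8
head_dim = d_model // n_heads            # 128

q_width = n_heads * head_dim             # 4096
kv_width = kv_n_heads * head_dim         # 1024 each for K and V
print(q_width + 2 * kv_width)            # 6144: the fused Wqkv output width, not 3 * 4096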
@@ -78,6 +84,17 @@ class MPTAttention(nn.Module):
         assert self.total_num_heads % tp_world_size == 0
         self.num_heads = self.total_num_heads // tp_world_size
 
+        if self.total_num_kv_heads >= tp_world_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_world_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_world_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
         # Create the alibi slopes and slice them.
         tp_rank = get_tensor_model_parallel_rank()
         head_start = tp_rank * self.num_heads
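The new tensor-parallel handling above either partitions the KV heads across ranks or replicates them when there are more ranks than KV heads. A small standalone sketch of that decision, with assumed head counts and TP sizes; kv_heads_per_rank is a hypothetical helper mirroring the logic above, not part of the commit.

# Sketch only: assumed KV head counts and TP world sizes.
def kv_heads_per_rank(total_num_kv_heads: int, tp_world_size: int) -> int:
    if total_num_kv_heads >= tp_world_size:
        # Partition: each rank owns a disjoint slice of the KV heads.
        assert total_num_kv_heads % tp_world_size == 0
    else:
        # Replicate: several ranks share a copy of the same KV head.
        assert tp_world_size % total_num_kv_heads == 0
    return max(1, total_num_kv_heads // tp_world_size)

print(kv_heads_per_rank(8, tp_world_size=2))   # 4: KV heads are partitioned
print(kv_heads_per_rank(8, tp_world_size=16))  # 1: KV heads are replicated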
@@ -91,7 +108,8 @@ class MPTAttention(nn.Module):
         self.attn = PagedAttention(self.num_heads,
                                    self.head_dim,
                                    scaling,
-                                   alibi_slopes=alibi_slopes)
+                                   alibi_slopes=alibi_slopes,
+                                   num_kv_heads=self.num_kv_heads)
 
     def forward(
         self,
@@ -105,7 +123,7 @@ class MPTAttention(nn.Module):
         qkv, _ = self.Wqkv(hidden_states)
         if self.clip_qkv is not None:
             qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
-        q, k, v = qkv.chunk(chunks=3, dim=-1)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         if self.qk_ln:
             q = self.q_ln(q)
             k = self.k_ln(k)
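A standalone sketch of why split replaces chunk(3): once Q and KV widths differ, three equal slices would cut through head boundaries and mix query and key features. Shapes below are illustrative only, not from a real MPT config.

# Sketch only: toy token count and head sizes.
import torch

num_tokens, head_dim = 4, 128
num_heads, num_kv_heads = 32, 8
q_size = num_heads * head_dim            # 4096
kv_size = num_kv_heads * head_dim        # 1024

qkv = torch.randn(num_tokens, q_size + 2 * kv_size)   # stand-in for the fused Wqkv output
q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
print(q.shape, k.shape, v.shape)
# torch.Size([4, 4096]) torch.Size([4, 1024]) torch.Size([4, 1024])
# qkv.chunk(3, dim=-1) would instead return three equal 2048-wide slices.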