Support MQA + MP for decoding (#490)

Co-authored-by: danthe3rd <danthe3rd>
dan_the_3rd 2023-08-30 19:29:54 +02:00 committed by GitHub
parent 0cb595ad94
commit 011ec323d6
3 changed files with 70 additions and 39 deletions
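The diff below extends tensor-parallel (MP) weight sharding/combining and ParallelMHA to handle multi-query (MQA) and grouped-query (GQA) attention, where the number of KV heads is smaller than the number of query heads. As a minimal sketch (not part of the commit; example values only, not the library API), the fused QKV weight these changes operate on is laid out like this:

import torch

hidden_size = 768
n_head = 8      # query heads (example value)
n_head_kv = 2   # KV heads; n_head_kv < n_head is GQA, n_head_kv == 1 is MQA
head_dim = hidden_size // n_head

# The fused Wqkv weight stacks the Q, K and V projections along dim 0,
# giving (n_head + 2 * n_head_kv) * head_dim output rows.
wqkv = torch.randn((n_head + 2 * n_head_kv) * head_dim, hidden_size)
wq = wqkv[: n_head * head_dim]
wk = wqkv[n_head * head_dim : (n_head + n_head_kv) * head_dim]
wv = wqkv[(n_head + n_head_kv) * head_dim :]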


@@ -623,22 +623,24 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0):
"""
input_ids: (batch, seqlen) int tensor
inference_params: for generation. Adapted from Megatron-LM (and Apex)
https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
num_last_tokens: if > 0, only return the logits for the last n tokens
"""
assert input_ids.ndim == 2, f"Expected `input_ids` to have shape [b, slen], but got shape {input_ids.shape}"
b, slen = input_ids.shape
hidden_states = self.transformer(
input_ids, position_ids=position_ids, inference_params=inference_params
)
if num_last_tokens > 0:
hidden_states = hidden_states[:, -num_last_tokens:]
assert hidden_states.ndim == 3, "sequence_parallel is not supported in generation mode"
if self.project_out is not None:
hidden_states = self.project_out(hidden_states)
lm_logits = self.lm_head(hidden_states)
# During inference, we want the full logit for sampling
if isinstance(self.lm_head, ColumnParallelLinear) and inference_params is not None:
lm_logits, _ = all_gather_raw(lm_logits, self.lm_head.process_group)
lm_logits = rearrange(lm_logits, "(n b) ... d -> b ... (n d)", b=hidden_states.shape[0])
lm_logits = rearrange(lm_logits, "(n b) ... d -> b ... (n d)", b=b)
CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
return CausalLMOutput(logits=lm_logits)
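Not part of the diff: a small sketch of the rearrange applied to the all-gathered logits above, assuming all_gather_raw concatenates each rank's shard along dim 0 (which is what the "(n b) ... d" pattern implies); example shapes only.

import torch
from einops import rearrange

world_size, b, slen, vocab_per_rank = 2, 3, 5, 7
# Gathered logits: each rank's (b, slen, vocab_per_rank) shard stacked on dim 0.
gathered = torch.randn(world_size * b, slen, vocab_per_rank)
# Un-interleave the rank dimension back into the vocab dimension, using the
# batch size b taken from input_ids, as in the updated line above.
full = rearrange(gathered, "(n b) ... d -> b ... (n d)", b=b)
assert full.shape == (b, slen, world_size * vocab_per_rank)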
@@ -802,6 +804,8 @@ def combine_state_dicts_tp(state_dicts: list[dict[str, torch.Tensor]], config: G
assert config.hidden_size % world_size == 0
inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
assert inner_dim % world_size == 0
assert config.hidden_size % config.n_head == 0
headdim = config.hidden_size // config.n_head
# Sometimes the word embeddings are sharded on the 0th dim, sometimes on the 1st dim.
# vocab_size // world_size coordinates are nonzero.
@@ -823,14 +827,6 @@ def combine_state_dicts_tp(state_dicts: list[dict[str, torch.Tensor]], config: G
]
state_dict[key] = rearrange(torch.cat(xs, dim=1), "three d ... -> (three d) ...")
else:
xs = [
rearrange(
s[key],
"(nheadqkv headdim) ... -> nheadqkv headdim ...",
nheadqkv=n_head + 2 * n_head_kv,
)
for s in state_dicts
]
n_head_each_rank = [
get_dim_for_local_rank(n_head, world_size, local_rank)
for local_rank in range(world_size)
@@ -839,32 +835,41 @@ def combine_state_dicts_tp(state_dicts: list[dict[str, torch.Tensor]], config: G
get_dim_for_local_rank(n_head_kv, world_size, local_rank)
for local_rank in range(world_size)
]
xs = [
rearrange(
s[key],
"(nheadqkv headdim) ... -> nheadqkv headdim ...",
nheadqkv=rank_n_head + 2 * rank_n_head_kv,
headdim=headdim,
)
for s, rank_n_head, rank_n_head_kv in zip(state_dicts, n_head_each_rank, n_head_kv_each_rank)
]
wq = torch.cat(
[x[: n_head_each_rank[rank]] for rank, x in enumerate(xs)], dim=0
)
wk = torch.cat(
[
x[
n_head_each_rank[rank] : n_head_each_rank[rank]
+ n_head_kv_each_rank[rank]
]
for rank, x in enumerate(xs)
],
dim=0,
)
wv = torch.cat(
[
x[n_head_each_rank[rank] + n_head_kv_each_rank[rank] :]
for rank, x in enumerate(xs)
],
dim=0,
)
wqkv = torch.cat(
[wq, wk, wv],
dim=0,
)
state_dict[key] = rearrange(
torch.cat(
[
torch.cat(
[x[: n_head_each_rank[rank]] for rank, x in enumerate(xs)], dim=0
),
torch.cat(
[
x[
n_head_each_rank[rank] : n_head_each_rank[rank]
+ n_head_kv_each_rank[rank]
]
for rank, x in enumerate(xs)
],
dim=0,
),
torch.cat(
[
x[n_head_each_rank[rank] + n_head_kv_each_rank[rank] :]
for rank, x in enumerate(xs)
],
dim=0,
),
],
dim=0,
),
wqkv,
"nheadqkv headdim ... -> (nheadqkv headdim) ...",
)
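Not part of the diff: a hedged sketch of the per-rank head bookkeeping used above, with a simplified stand-in for get_dim_for_local_rank (assumed to split a dimension as evenly as possible across ranks); example values only.

def dim_for_local_rank(dim: int, world_size: int, local_rank: int) -> int:
    # Simplified stand-in for get_dim_for_local_rank: earlier ranks take the remainder.
    base, rem = divmod(dim, world_size)
    return base + (1 if local_rank < rem else 0)

n_head, n_head_kv, world_size = 8, 2, 2
n_head_each_rank = [dim_for_local_rank(n_head, world_size, r) for r in range(world_size)]        # [4, 4]
n_head_kv_each_rank = [dim_for_local_rank(n_head_kv, world_size, r) for r in range(world_size)]  # [1, 1]
# Each rank's Wqkv shard packs (4 + 2*1) heads rather than a fixed third of the
# full weight, so the combine step slices every shard into its own Q/K/V head
# blocks (using that rank's head counts) and concatenates per kind before re-fusing.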


@@ -735,7 +735,7 @@ class ParallelMHA(nn.Module):
self.num_heads, self.world_size, self.local_rank
)
self.num_heads_kv_per_rank = get_dim_for_local_rank(
self.num_heads, self.world_size, self.local_rank
self.num_heads_kv, self.world_size, self.local_rank
)
self.head_dim = self.embed_dim // num_heads
qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
@@ -758,7 +758,7 @@ class ParallelMHA(nn.Module):
process_group,
bias=qkv_proj_bias,
sequence_parallel=sequence_parallel,
multiple_of=self.head_dim * 3,
multiple_of=self.head_dim * (self.num_heads_per_rank + 2 * self.num_heads_kv_per_rank),
**factory_kwargs,
)
inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
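Not part of the diff: a sketch of the per-rank fused QKV width implied by the corrected num_heads_kv_per_rank and the new multiple_of value (example values, assuming an even split across ranks).

embed_dim, num_heads, num_heads_kv, world_size = 768, 8, 2, 2
head_dim = embed_dim // num_heads                      # 96
num_heads_per_rank = num_heads // world_size           # 4 (even split assumed)
num_heads_kv_per_rank = num_heads_kv // world_size     # 1

# Each rank's slice of the column-parallel Wqkv has this many output features;
# passing it as multiple_of keeps the shard boundary on a whole per-rank QKV block,
# unlike the old head_dim * 3, which only holds when num_heads == num_heads_kv.
qkv_dim_per_rank = head_dim * (num_heads_per_rank + 2 * num_heads_kv_per_rank)  # 576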


@@ -3,7 +3,7 @@ import re
import pytest
import torch
from einops import rearrange
from flash_attn.models.gpt import GPTLMHeadModel, remap_state_dict_hf_gpt2
from flash_attn.models.gpt import GPTLMHeadModel, remap_state_dict_hf_gpt2, shard_state_dict_tp, combine_state_dicts_tp
from flash_attn.utils.generation import InferenceParams
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import GPT2Config, GPT2Tokenizer
@@ -444,3 +444,29 @@ def test_gpt2_speculative_decoding(model_name, optimized, fused_ft_kernel, cg):
return_dict_in_generate=True,
)
print(tokenizer.batch_decode(out_og.sequences))
@pytest.mark.parametrize("n_heads_q_kv", [
(8, 8), # Regular attention
(8, 4), # GQA
(8, 2), # MQA
])
def test_gpt2_shard_unshard(n_heads_q_kv):
world_size = 2
config = GPT2Config.from_pretrained("gpt2")
config.vocab_size = 1024
config.n_head, config.n_head_kv = n_heads_q_kv
model = GPTLMHeadModel(config, device="cuda", dtype=torch.float16)
state_dict = model.state_dict()
shards = [
# NOTE: Shallow copy as `state_dict` is modified in-place
shard_state_dict_tp(dict(state_dict), config, world_size, rank)
for rank in range(world_size)
]
state_dict2 = combine_state_dicts_tp(shards, config)
assert state_dict2.keys() == state_dict.keys()
for k in state_dict.keys():
ref = state_dict[k]
new = state_dict2[k]
assert torch.allclose(ref, new, atol=0.0, rtol=0.0)
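The round trip is expected to be bit-exact (atol=0.0, rtol=0.0) because sharding and recombining only slice and concatenate the original tensors; no arithmetic touches the weight values.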