Support MQA + MP for decoding (#490)
Co-authored-by: danthe3rd <danthe3rd>
parent 0cb595ad94
commit 011ec323d6
@@ -623,22 +623,24 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
    def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0):
        """
        input_ids: (batch, seqlen) int tensor
        inference_params: for generation. Adapted from Megatron-LM (and Apex)
        https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
        num_last_tokens: if > 0, only return the logits for the last n tokens
        """
        assert input_ids.ndim == 2, f"Expected `input_ids` to have shape [b, slen], but got shape {input_ids.shape}"
        b, slen = input_ids.shape
        hidden_states = self.transformer(
            input_ids, position_ids=position_ids, inference_params=inference_params
        )
        if num_last_tokens > 0:
            hidden_states = hidden_states[:, -num_last_tokens:]
        assert hidden_states.ndim == 3, "sequence_parallel is not supported in generation mode"
        if self.project_out is not None:
            hidden_states = self.project_out(hidden_states)
        lm_logits = self.lm_head(hidden_states)
        # During inference, we want the full logit for sampling
        if isinstance(self.lm_head, ColumnParallelLinear) and inference_params is not None:
            lm_logits, _ = all_gather_raw(lm_logits, self.lm_head.process_group)
-            lm_logits = rearrange(lm_logits, "(n b) ... d -> b ... (n d)", b=hidden_states.shape[0])
+            lm_logits = rearrange(lm_logits, "(n b) ... d -> b ... (n d)", b=b)
        CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
        return CausalLMOutput(logits=lm_logits)
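For context on the changed rearrange: when lm_head is a ColumnParallelLinear, each rank produces a slice of the vocabulary, the all-gather stacks those slices along the batch dimension, and the rearrange folds them back into a full-vocab logit tensor. A toy sketch of that reshuffle (single process; torch.cat stands in for all_gather_raw, and the dim-0 concatenation order is an assumption read off the einops pattern):

import torch
from einops import rearrange

b, seqlen, vocab, world_size = 2, 5, 12, 4
full_logits = torch.randn(b, seqlen, vocab)

# Each rank holds one slice of the vocab dimension, as a ColumnParallelLinear output would.
per_rank = torch.chunk(full_logits, world_size, dim=-1)   # world_size tensors of (b, seqlen, vocab // world_size)
gathered = torch.cat(per_rank, dim=0)                      # (world_size * b, seqlen, vocab // world_size)

# Same pattern as in the diff: peel the rank axis off the batch dim and fold it into the vocab dim.
reassembled = rearrange(gathered, "(n b) ... d -> b ... (n d)", b=b)
assert torch.equal(reassembled, full_logits)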
@@ -802,6 +804,8 @@ def combine_state_dicts_tp(state_dicts: list[dict[str, torch.Tensor]], config: G
    assert config.hidden_size % world_size == 0
    inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
    assert inner_dim % world_size == 0
+    assert config.hidden_size % config.n_head == 0
+    headdim = config.hidden_size // config.n_head

    # Sometimes the word embeddings are sharded on the 0th dim, sometimes on the 1st dim.
    # vocab_size // world_size coordinates are nonzero.
@@ -823,14 +827,6 @@ def combine_state_dicts_tp(state_dicts: list[dict[str, torch.Tensor]], config: G
            ]
            state_dict[key] = rearrange(torch.cat(xs, dim=1), "three d ... -> (three d) ...")
        else:
-            xs = [
-                rearrange(
-                    s[key],
-                    "(nheadqkv headdim) ... -> nheadqkv headdim ...",
-                    nheadqkv=n_head + 2 * n_head_kv,
-                )
-                for s in state_dicts
-            ]
            n_head_each_rank = [
                get_dim_for_local_rank(n_head, world_size, local_rank)
                for local_rank in range(world_size)
@@ -839,32 +835,41 @@ def combine_state_dicts_tp(state_dicts: list[dict[str, torch.Tensor]], config: G
                get_dim_for_local_rank(n_head_kv, world_size, local_rank)
                for local_rank in range(world_size)
            ]
+            xs = [
+                rearrange(
+                    s[key],
+                    "(nheadqkv headdim) ... -> nheadqkv headdim ...",
+                    nheadqkv=rank_n_head + 2 * rank_n_head_kv,
+                    headdim=headdim,
+                )
+                for s, rank_n_head, rank_n_head_kv in zip(state_dicts, n_head_each_rank, n_head_kv_each_rank)
+            ]
+            wq = torch.cat(
+                [x[: n_head_each_rank[rank]] for rank, x in enumerate(xs)], dim=0
+            )
+            wk = torch.cat(
+                [
+                    x[
+                        n_head_each_rank[rank] : n_head_each_rank[rank]
+                        + n_head_kv_each_rank[rank]
+                    ]
+                    for rank, x in enumerate(xs)
+                ],
+                dim=0,
+            )
+            wv = torch.cat(
+                [
+                    x[n_head_each_rank[rank] + n_head_kv_each_rank[rank] :]
+                    for rank, x in enumerate(xs)
+                ],
+                dim=0,
+            )
+            wqkv = torch.cat(
+                [wq, wk, wv],
+                dim=0,
+            )
            state_dict[key] = rearrange(
-                torch.cat(
-                    [
-                        torch.cat(
-                            [x[: n_head_each_rank[rank]] for rank, x in enumerate(xs)], dim=0
-                        ),
-                        torch.cat(
-                            [
-                                x[
-                                    n_head_each_rank[rank] : n_head_each_rank[rank]
-                                    + n_head_kv_each_rank[rank]
-                                ]
-                                for rank, x in enumerate(xs)
-                            ],
-                            dim=0,
-                        ),
-                        torch.cat(
-                            [
-                                x[n_head_each_rank[rank] + n_head_kv_each_rank[rank] :]
-                                for rank, x in enumerate(xs)
-                            ],
-                            dim=0,
-                        ),
-                    ],
-                    dim=0,
-                ),
+                wqkv,
                "nheadqkv headdim ... -> (nheadqkv headdim) ...",
            )
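The new code path above recombines per-rank packed QKV weights whose rank-local layout is [this rank's Q heads | K heads | V heads]. A self-contained toy version of the same round trip (single process, even head split; split_heads_evenly is a hypothetical stand-in for get_dim_for_local_rank, whose exact behavior is not shown in this diff):

import torch
from einops import rearrange

n_head, n_head_kv, headdim, world_size, out_dim = 4, 2, 3, 2, 8

def split_heads_evenly(num_heads, world_size, rank):
    # Hypothetical stand-in for get_dim_for_local_rank: heads owned by `rank`,
    # assuming the head count divides evenly across ranks.
    assert num_heads % world_size == 0
    return num_heads // world_size

# Reference (unsharded) packed QKV weight: [all Q heads | all K heads | all V heads] * headdim rows.
wqkv_ref = torch.randn((n_head + 2 * n_head_kv) * headdim, out_dim)
ref = rearrange(wqkv_ref, "(h headdim) d -> h headdim d", headdim=headdim)
q, k, v = ref[:n_head], ref[n_head : n_head + n_head_kv], ref[n_head + n_head_kv :]

# Build per-rank shards the way each tensor-parallel rank would store them:
# [its Q heads | its K heads | its V heads].
shards = []
for rank in range(world_size):
    hq = split_heads_evenly(n_head, world_size, rank)
    hkv = split_heads_evenly(n_head_kv, world_size, rank)
    shard = torch.cat(
        [q[rank * hq : (rank + 1) * hq], k[rank * hkv : (rank + 1) * hkv], v[rank * hkv : (rank + 1) * hkv]],
        dim=0,
    )
    shards.append(rearrange(shard, "h headdim d -> (h headdim) d"))

# Recombine, mirroring the new combine_state_dicts_tp logic: slice each shard back
# into Q / K / V head blocks, concatenate each block across ranks, then repack.
n_head_each_rank = [split_heads_evenly(n_head, world_size, r) for r in range(world_size)]
n_head_kv_each_rank = [split_heads_evenly(n_head_kv, world_size, r) for r in range(world_size)]
xs = [rearrange(s, "(h headdim) d -> h headdim d", headdim=headdim) for s in shards]
wq = torch.cat([x[: n_head_each_rank[r]] for r, x in enumerate(xs)], dim=0)
wk = torch.cat([x[n_head_each_rank[r] : n_head_each_rank[r] + n_head_kv_each_rank[r]] for r, x in enumerate(xs)], dim=0)
wv = torch.cat([x[n_head_each_rank[r] + n_head_kv_each_rank[r] :] for r, x in enumerate(xs)], dim=0)
wqkv = rearrange(torch.cat([wq, wk, wv], dim=0), "h headdim d -> (h headdim) d")
assert torch.equal(wqkv, wqkv_ref)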
@@ -735,7 +735,7 @@ class ParallelMHA(nn.Module):
            self.num_heads, self.world_size, self.local_rank
        )
        self.num_heads_kv_per_rank = get_dim_for_local_rank(
-            self.num_heads, self.world_size, self.local_rank
+            self.num_heads_kv, self.world_size, self.local_rank
        )
        self.head_dim = self.embed_dim // num_heads
        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
@@ -758,7 +758,7 @@ class ParallelMHA(nn.Module):
            process_group,
            bias=qkv_proj_bias,
            sequence_parallel=sequence_parallel,
-            multiple_of=self.head_dim * 3,
+            multiple_of=self.head_dim * (self.num_heads_per_rank + 2 * self.num_heads_kv_per_rank),
            **factory_kwargs,
        )
        inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
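The two ParallelMHA changes above are about per-rank head bookkeeping once Q and KV head counts differ: num_heads_kv_per_rank must be derived from num_heads_kv rather than num_heads, and one rank's packed QKV width is no longer three equal head_dim * num_heads_per_rank chunks. A quick arithmetic check, assuming an even head split (the real code delegates splitting to get_dim_for_local_rank, not reproduced here):

head_dim, world_size = 64, 2
num_heads, num_heads_kv = 8, 4                        # GQA: 8 query heads sharing 4 KV heads

num_heads_per_rank = num_heads // world_size          # 4
num_heads_kv_per_rank = num_heads_kv // world_size    # 2 (the old code derived this from num_heads and got 4)

# Width of one rank's slice of the packed QKV projection.
qkv_dim_per_rank = head_dim * (num_heads_per_rank + 2 * num_heads_kv_per_rank)
print(qkv_dim_per_rank)                               # 512
print(qkv_dim_per_rank % (head_dim * 3))              # 128: the old `multiple_of` unit does not divide a GQA shard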
@@ -3,7 +3,7 @@ import re
import pytest
import torch
from einops import rearrange
-from flash_attn.models.gpt import GPTLMHeadModel, remap_state_dict_hf_gpt2
+from flash_attn.models.gpt import GPTLMHeadModel, remap_state_dict_hf_gpt2, shard_state_dict_tp, combine_state_dicts_tp
from flash_attn.utils.generation import InferenceParams
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import GPT2Config, GPT2Tokenizer
@@ -444,3 +444,29 @@ def test_gpt2_speculative_decoding(model_name, optimized, fused_ft_kernel, cg):
        return_dict_in_generate=True,
    )
    print(tokenizer.batch_decode(out_og.sequences))
+
+
+@pytest.mark.parametrize("n_heads_q_kv", [
+    (8, 8),  # Regular attention
+    (8, 4),  # GQA
+    (8, 2),  # MQA
+])
+def test_gpt2_shard_unshard(n_heads_q_kv):
+    world_size = 2
+
+    config = GPT2Config.from_pretrained("gpt2")
+    config.vocab_size = 1024
+    config.n_head, config.n_head_kv = n_heads_q_kv
+    model = GPTLMHeadModel(config, device="cuda", dtype=torch.float16)
+    state_dict = model.state_dict()
+    shards = [
+        # NOTE: Shallow copy as `state_dict` is modified in-place
+        shard_state_dict_tp(dict(state_dict), config, world_size, rank)
+        for rank in range(world_size)
+    ]
+    state_dict2 = combine_state_dicts_tp(shards, config)
+    assert state_dict2.keys() == state_dict.keys()
+    for k in state_dict.keys():
+        ref = state_dict[k]
+        new = state_dict2[k]
+        assert torch.allclose(ref, new, atol=0.0, rtol=0.0)
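A note on the parametrization in the new test: with world_size = 2, the (8, 2) case leaves a single KV head per rank, i.e. the multi-query layout the commit title refers to. The per-rank split, assuming even division as in the sketches above:

world_size = 2
for n_head, n_head_kv in [(8, 8), (8, 4), (8, 2)]:
    # Even split per rank (a stand-in for get_dim_for_local_rank).
    q_per_rank, kv_per_rank = n_head // world_size, n_head_kv // world_size
    print(n_head, n_head_kv, "->", q_per_rank, "Q and", kv_per_rank, "KV heads per rank")
# 8 8 -> 4 Q and 4 KV heads per rank   (regular MHA)
# 8 4 -> 4 Q and 2 KV heads per rank   (GQA)
# 8 2 -> 4 Q and 1 KV heads per rank   (MQA on each rank)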