diff --git a/flash_attn/models/gpt.py b/flash_attn/models/gpt.py
index 5a3eb85..3436745 100644
--- a/flash_attn/models/gpt.py
+++ b/flash_attn/models/gpt.py
@@ -623,22 +623,24 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
 
     def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0):
         """
+        input_ids: (batch, seqlen) int tensor
         inference_params: for generation. Adapted from Megatron-LM (and Apex)
         https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
         num_last_tokens: if > 0, only return the logits for the last n tokens
         """
+        assert input_ids.ndim == 2, f"Expected `input_ids` to have shape [b, slen], but got shape {input_ids.shape}"
+        b, slen = input_ids.shape
         hidden_states = self.transformer(
             input_ids, position_ids=position_ids, inference_params=inference_params
         )
-        if num_last_tokens > 0:
-            hidden_states = hidden_states[:, -num_last_tokens:]
+        assert hidden_states.ndim == 3, "sequence_parallel is not supported in generation mode"
         if self.project_out is not None:
             hidden_states = self.project_out(hidden_states)
         lm_logits = self.lm_head(hidden_states)
         # During inference, we want the full logit for sampling
         if isinstance(self.lm_head, ColumnParallelLinear) and inference_params is not None:
             lm_logits, _ = all_gather_raw(lm_logits, self.lm_head.process_group)
-            lm_logits = rearrange(lm_logits, "(n b) ... d -> b ... (n d)", b=hidden_states.shape[0])
+            lm_logits = rearrange(lm_logits, "(n b) ... d -> b ... (n d)", b=b)
         CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
         return CausalLMOutput(logits=lm_logits)
 
@@ -802,6 +804,8 @@ def combine_state_dicts_tp(state_dicts: list[dict[str, torch.Tensor]], config: G
     assert config.hidden_size % world_size == 0
     inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
     assert inner_dim % world_size == 0
+    assert config.hidden_size % config.n_head == 0
+    headdim = config.hidden_size // config.n_head
 
     # Sometimes the word embeddings are sharded on the 0th dim, sometimes on the 1st dim.
     # vocab_size // world_size coordinates are nonzero.
@@ -823,14 +827,6 @@
             ]
             state_dict[key] = rearrange(torch.cat(xs, dim=1), "three d ... -> (three d) ...")
         else:
-            xs = [
-                rearrange(
-                    s[key],
-                    "(nheadqkv headdim) ... -> nheadqkv headdim ...",
-                    nheadqkv=n_head + 2 * n_head_kv,
-                )
-                for s in state_dicts
-            ]
             n_head_each_rank = [
                 get_dim_for_local_rank(n_head, world_size, local_rank)
                 for local_rank in range(world_size)
@@ -839,32 +835,41 @@
                 get_dim_for_local_rank(n_head_kv, world_size, local_rank)
                 for local_rank in range(world_size)
             ]
+            xs = [
+                rearrange(
+                    s[key],
+                    "(nheadqkv headdim) ... -> nheadqkv headdim ...",
+                    nheadqkv=rank_n_head + 2 * rank_n_head_kv,
+                    headdim=headdim,
+                )
+                for s, rank_n_head, rank_n_head_kv in zip(state_dicts, n_head_each_rank, n_head_kv_each_rank)
+            ]
+            wq = torch.cat(
+                [x[: n_head_each_rank[rank]] for rank, x in enumerate(xs)], dim=0
+            )
+            wk = torch.cat(
+                [
+                    x[
+                        n_head_each_rank[rank] : n_head_each_rank[rank]
+                        + n_head_kv_each_rank[rank]
+                    ]
+                    for rank, x in enumerate(xs)
+                ],
+                dim=0,
+            )
+            wv = torch.cat(
+                [
+                    x[n_head_each_rank[rank] + n_head_kv_each_rank[rank] :]
+                    for rank, x in enumerate(xs)
+                ],
+                dim=0,
+            )
+            wqkv = torch.cat(
+                [wq, wk, wv],
+                dim=0,
+            )
             state_dict[key] = rearrange(
-                torch.cat(
-                    [
-                        torch.cat(
-                            [x[: n_head_each_rank[rank]] for rank, x in enumerate(xs)], dim=0
-                        ),
-                        torch.cat(
-                            [
-                                x[
-                                    n_head_each_rank[rank] : n_head_each_rank[rank]
-                                    + n_head_kv_each_rank[rank]
-                                ]
-                                for rank, x in enumerate(xs)
-                            ],
-                            dim=0,
-                        ),
-                        torch.cat(
-                            [
-                                x[n_head_each_rank[rank] + n_head_kv_each_rank[rank] :]
-                                for rank, x in enumerate(xs)
-                            ],
-                            dim=0,
-                        ),
-                    ],
-                    dim=0,
-                ),
+                wqkv,
                 "nheadqkv headdim ... -> (nheadqkv headdim) ...",
             )
 
diff --git a/flash_attn/modules/mha.py b/flash_attn/modules/mha.py
index d818ce4..3009fd1 100644
--- a/flash_attn/modules/mha.py
+++ b/flash_attn/modules/mha.py
@@ -735,7 +735,7 @@ class ParallelMHA(nn.Module):
             self.num_heads, self.world_size, self.local_rank
         )
         self.num_heads_kv_per_rank = get_dim_for_local_rank(
-            self.num_heads, self.world_size, self.local_rank
+            self.num_heads_kv, self.world_size, self.local_rank
         )
         self.head_dim = self.embed_dim // num_heads
         qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
@@ -758,7 +758,7 @@
             process_group,
             bias=qkv_proj_bias,
             sequence_parallel=sequence_parallel,
-            multiple_of=self.head_dim * 3,
+            multiple_of=self.head_dim * (self.num_heads_per_rank + 2 * self.num_heads_kv_per_rank),
             **factory_kwargs,
         )
         inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
diff --git a/tests/models/test_gpt.py b/tests/models/test_gpt.py
index c69e9bd..93811df 100644
--- a/tests/models/test_gpt.py
+++ b/tests/models/test_gpt.py
@@ -3,7 +3,7 @@ import re
 import pytest
 import torch
 from einops import rearrange
-from flash_attn.models.gpt import GPTLMHeadModel, remap_state_dict_hf_gpt2
+from flash_attn.models.gpt import GPTLMHeadModel, remap_state_dict_hf_gpt2, shard_state_dict_tp, combine_state_dicts_tp
 from flash_attn.utils.generation import InferenceParams
 from flash_attn.utils.pretrained import state_dict_from_pretrained
 from transformers import GPT2Config, GPT2Tokenizer
@@ -444,3 +444,29 @@ def test_gpt2_speculative_decoding(model_name, optimized, fused_ft_kernel, cg):
         return_dict_in_generate=True,
     )
     print(tokenizer.batch_decode(out_og.sequences))
+
+
+@pytest.mark.parametrize("n_heads_q_kv", [
+    (8, 8),  # Regular attention
+    (8, 4),  # GQA
+    (8, 2),  # MQA
+])
+def test_gpt2_shard_unshard(n_heads_q_kv):
+    world_size = 2
+
+    config = GPT2Config.from_pretrained("gpt2")
+    config.vocab_size = 1024
+    config.n_head, config.n_head_kv = n_heads_q_kv
+    model = GPTLMHeadModel(config, device="cuda", dtype=torch.float16)
+    state_dict = model.state_dict()
+    shards = [
+        # NOTE: Shallow copy as `state_dict` is modified in-place
+        shard_state_dict_tp(dict(state_dict), config, world_size, rank)
+        for rank in range(world_size)
+    ]
+    state_dict2 = combine_state_dicts_tp(shards, config)
+    assert state_dict2.keys() == state_dict.keys()
+    for k in state_dict.keys():
+        ref = state_dict[k]
+        new = state_dict2[k]
+        assert torch.allclose(ref, new, atol=0.0, rtol=0.0)
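
Note: the per-rank Wqkv regrouping introduced above can be exercised in isolation. Below is a minimal, self-contained sketch of the same idea with hypothetical values (2 ranks, 8 query heads, 4 KV heads, head_dim 4, random tensors instead of real model weights); it is an illustration, not part of the patch:

# Standalone sketch of the GQA/MQA Wqkv recombination, using made-up shapes.
import torch
from einops import rearrange

world_size, n_head, n_head_kv, headdim, hidden = 2, 8, 4, 4, 32
n_head_rank = n_head // world_size        # query heads held by each rank
n_head_kv_rank = n_head_kv // world_size  # KV heads held by each rank

# Each rank stores its slice of Wqkv packed as [q_rank | k_rank | v_rank] along dim 0.
shards = [
    torch.randn((n_head_rank + 2 * n_head_kv_rank) * headdim, hidden)
    for _ in range(world_size)
]

# Split each shard into per-head blocks, then regroup all q, all k, all v across ranks.
xs = [
    rearrange(
        s,
        "(nheadqkv headdim) d -> nheadqkv headdim d",
        nheadqkv=n_head_rank + 2 * n_head_kv_rank,
        headdim=headdim,
    )
    for s in shards
]
wq = torch.cat([x[:n_head_rank] for x in xs], dim=0)
wk = torch.cat([x[n_head_rank : n_head_rank + n_head_kv_rank] for x in xs], dim=0)
wv = torch.cat([x[n_head_rank + n_head_kv_rank :] for x in xs], dim=0)
wqkv = rearrange(
    torch.cat([wq, wk, wv], dim=0), "nheadqkv headdim d -> (nheadqkv headdim) d"
)
assert wqkv.shape == ((n_head + 2 * n_head_kv) * headdim, hidden)

The patch performs the same regrouping but builds the per-rank head counts with get_dim_for_local_rank, so head counts that do not split evenly across ranks are also handled.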