diff --git a/flash_attn/models/gpt.py b/flash_attn/models/gpt.py
index 8955f31..0cf9149 100644
--- a/flash_attn/models/gpt.py
+++ b/flash_attn/models/gpt.py
@@ -20,16 +20,12 @@ from flash_attn.models.opt import remap_state_dict_hf_opt
 from flash_attn.modules.block import Block, ParallelBlock
 from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
 from flash_attn.modules.mha import MHA, ParallelMHA
-from flash_attn.modules.mlp import (
-    FusedMLP,
-    GatedMlp,
-    Mlp,
-    ParallelFusedMLP,
-    ParallelGatedMlp,
-    ParallelMLP,
-)
+from flash_attn.modules.mlp import (FusedMLP, GatedMlp, Mlp, ParallelFusedMLP,
+                                    ParallelGatedMlp, ParallelMLP)
 from flash_attn.ops.activations import sqrelu_fwd
-from flash_attn.utils.distributed import all_gather_raw, get_dim_for_local_rank, sync_shared_params
+from flash_attn.utils.distributed import (all_gather_raw,
+                                          get_dim_for_local_rank,
+                                          sync_shared_params)
 from flash_attn.utils.generation import GenerationMixin
 from flash_attn.utils.pretrained import state_dict_from_pretrained
 
@@ -44,7 +40,8 @@ except ImportError:
     dropout_add_layer_norm = None
 
 try:
-    from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual
+    from flash_attn.ops.layer_norm import \
+        dropout_add_layer_norm_parallel_residual
 except ImportError:
     dropout_add_layer_norm_parallel_residual = None
 
@@ -673,6 +670,8 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
 def shard_state_dict_tp(state_dict, config, world_size, rank):
     """Convert the state_dict of a standard GPT model to the state_dict of a GPT model
     with tensor parallel.
+
+    This function modifies state_dict in place.
     """
     pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
     vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
@@ -784,11 +783,14 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
     return state_dict
 
 
-def combine_state_dicts_tp(state_dicts, config):
-    """Convert the state_dict of a GPT model with tensor parallel to the state_dict of a
-    standard GPT model.
+def combine_state_dicts_tp(state_dicts: list[dict[str, torch.Tensor]], config: GPT2Config):
+    """Convert the list of sharded state_dict of a GPT model with tensor parallel to
+    the state_dict of a standard GPT model.
 
     This function is meant to be the "reverse" of shard_state_dict_tp.
+
+    Precondition:
+        - state_dicts should be ordered in the same way as the shards were created.
     """
     world_size = len(state_dicts)
     keys = state_dicts[0].keys()
@@ -812,9 +814,6 @@ def combine_state_dicts_tp(state_dicts, config):
     def combine_qkv_headdim(state_dicts, state_dict, key):
         n_head = config.n_head
         n_head_kv = getattr(config, "n_head_kv", n_head)
-        assert n_head % world_size == 0 and n_head_kv % world_size == 0
-        n_head_per_rank = n_head // world_size
-        n_head_kv_per_rank = n_head_kv // world_size
         if key in state_dict:
             if n_head_kv == n_head:
                 xs = [
@@ -830,18 +829,37 @@ def combine_state_dicts_tp(state_dicts, config):
                     )
                     for s in state_dicts
                 ]
+                n_head_each_rank = [
+                    get_dim_for_local_rank(n_head, world_size, local_rank)
+                    for local_rank in range(world_size)
+                ]
+                n_head_kv_each_rank = [
+                    get_dim_for_local_rank(n_head_kv, world_size, local_rank)
+                    for local_rank in range(world_size)
+                ]
                 state_dict[key] = rearrange(
                     torch.cat(
                         [
-                            torch.cat([x[:n_head_per_rank] for x in xs], dim=0),
+                            torch.cat(
+                                [x[: n_head_each_rank[rank]] for rank, x in enumerate(xs)], dim=0
+                            ),
                             torch.cat(
                                 [
-                                    x[n_head_per_rank : n_head_per_rank + n_head_kv_per_rank]
-                                    for x in xs
+                                    x[
+                                        n_head_each_rank[rank] : n_head_each_rank[rank]
+                                        + n_head_kv_each_rank[rank]
+                                    ]
+                                    for rank, x in enumerate(xs)
+                                ],
+                                dim=0,
+                            ),
+                            torch.cat(
+                                [
+                                    x[n_head_each_rank[rank] + n_head_kv_each_rank[rank] :]
+                                    for rank, x in enumerate(xs)
                                 ],
                                 dim=0,
                             ),
-                            torch.cat([x[-n_head_kv_per_rank:] for x in xs], dim=0),
                         ],
                         dim=0,
                     ),