handle uneven heads across ranks when combining state_dicts; resolves #467 (#468)

* q

* add comment.
Xuechen Li 2023-08-20 14:57:34 -07:00 committed by GitHub
parent d431f16751
commit 25d6b1dbcb

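Note: a minimal sketch (not part of the commit) of the behavior the fix below relies on. It assumes get_dim_for_local_rank from flash_attn.utils.distributed assigns the leftover heads to the lower-numbered ranks when the head count does not divide evenly; heads_for_rank is a hypothetical stand-in for illustration, not the library function.

def heads_for_rank(n_head: int, world_size: int, rank: int) -> int:
    # Hypothetical helper mirroring the assumed behavior of get_dim_for_local_rank:
    # every rank gets n_head // world_size heads, and the first n_head % world_size
    # ranks each get one extra head.
    base, remainder = divmod(n_head, world_size)
    return base + (1 if rank < remainder else 0)

# Example: 14 query heads over 4 ranks -> [4, 4, 3, 3]. The code before this
# commit asserted n_head % world_size == 0 and could not combine such shards.
print([heads_for_rank(14, 4, r) for r in range(4)])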

@@ -20,16 +20,12 @@ from flash_attn.models.opt import remap_state_dict_hf_opt
 from flash_attn.modules.block import Block, ParallelBlock
 from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
 from flash_attn.modules.mha import MHA, ParallelMHA
-from flash_attn.modules.mlp import (
-    FusedMLP,
-    GatedMlp,
-    Mlp,
-    ParallelFusedMLP,
-    ParallelGatedMlp,
-    ParallelMLP,
-)
+from flash_attn.modules.mlp import (FusedMLP, GatedMlp, Mlp, ParallelFusedMLP,
+                                    ParallelGatedMlp, ParallelMLP)
 from flash_attn.ops.activations import sqrelu_fwd
-from flash_attn.utils.distributed import all_gather_raw, get_dim_for_local_rank, sync_shared_params
+from flash_attn.utils.distributed import (all_gather_raw,
+                                          get_dim_for_local_rank,
+                                          sync_shared_params)
 from flash_attn.utils.generation import GenerationMixin
 from flash_attn.utils.pretrained import state_dict_from_pretrained
@@ -44,7 +40,8 @@ except ImportError:
     dropout_add_layer_norm = None

 try:
-    from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual
+    from flash_attn.ops.layer_norm import \
+        dropout_add_layer_norm_parallel_residual
 except ImportError:
     dropout_add_layer_norm_parallel_residual = None
@@ -673,6 +670,8 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
 def shard_state_dict_tp(state_dict, config, world_size, rank):
     """Convert the state_dict of a standard GPT model to the state_dict of a GPT model
     with tensor parallel.
+
+    This function modifies state_dict in place.
     """
     pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
     vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
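Since the docstring now states that shard_state_dict_tp modifies state_dict in place, here is a hedged usage sketch: give each rank its own copy of the full state_dict before sharding. shard_for_all_ranks is a hypothetical wrapper; shard_state_dict_tp and its signature come from the diff above.

import copy

def shard_for_all_ranks(full_state_dict, config, world_size):
    # shard_state_dict_tp mutates its input, so deep-copy the unsharded
    # state_dict once per rank before producing that rank's shard.
    return [
        shard_state_dict_tp(copy.deepcopy(full_state_dict), config, world_size, rank)
        for rank in range(world_size)
    ]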
@@ -784,11 +783,14 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
     return state_dict


-def combine_state_dicts_tp(state_dicts, config):
-    """Convert the state_dict of a GPT model with tensor parallel to the state_dict of a
-    standard GPT model.
+def combine_state_dicts_tp(state_dicts: list[dict[str, torch.Tensor]], config: GPT2Config):
+    """Convert the list of sharded state_dict of a GPT model with tensor parallel to
+    the state_dict of a standard GPT model.

     This function is meant to be the "reverse" of shard_state_dict_tp.
+
+    Precondition:
+        - state_dicts should be ordered in the same way as the shards were created.
     """
     world_size = len(state_dicts)
     keys = state_dicts[0].keys()
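A hedged sketch of the added precondition: the shards passed to combine_state_dicts_tp must be in rank order (rank 0 first, rank world_size - 1 last). load_and_combine and the checkpoint paths are assumptions for illustration only; combine_state_dicts_tp and its signature come from the diff above.

import torch

def load_and_combine(checkpoint_paths, config):
    # checkpoint_paths must be ordered by rank, matching the order in which
    # the shards were created.
    state_dicts = [torch.load(path, map_location="cpu") for path in checkpoint_paths]
    return combine_state_dicts_tp(state_dicts, config)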
@@ -812,9 +814,6 @@ def combine_state_dicts_tp(state_dicts, config):
     def combine_qkv_headdim(state_dicts, state_dict, key):
         n_head = config.n_head
         n_head_kv = getattr(config, "n_head_kv", n_head)
-        assert n_head % world_size == 0 and n_head_kv % world_size == 0
-        n_head_per_rank = n_head // world_size
-        n_head_kv_per_rank = n_head_kv // world_size
         if key in state_dict:
             if n_head_kv == n_head:
                 xs = [
@@ -830,18 +829,37 @@
                     )
                     for s in state_dicts
                 ]
+                n_head_each_rank = [
+                    get_dim_for_local_rank(n_head, world_size, local_rank)
+                    for local_rank in range(world_size)
+                ]
+                n_head_kv_each_rank = [
+                    get_dim_for_local_rank(n_head_kv, world_size, local_rank)
+                    for local_rank in range(world_size)
+                ]
                 state_dict[key] = rearrange(
                     torch.cat(
                         [
-                            torch.cat([x[:n_head_per_rank] for x in xs], dim=0),
+                            torch.cat(
+                                [x[: n_head_each_rank[rank]] for rank, x in enumerate(xs)], dim=0
+                            ),
                             torch.cat(
                                 [
-                                    x[n_head_per_rank : n_head_per_rank + n_head_kv_per_rank]
-                                    for x in xs
+                                    x[
+                                        n_head_each_rank[rank] : n_head_each_rank[rank]
+                                        + n_head_kv_each_rank[rank]
+                                    ]
+                                    for rank, x in enumerate(xs)
                                 ],
                                 dim=0,
                             ),
-                            torch.cat([x[-n_head_kv_per_rank:] for x in xs], dim=0),
+                            torch.cat(
+                                [
+                                    x[n_head_each_rank[rank] + n_head_kv_each_rank[rank] :]
+                                    for rank, x in enumerate(xs)
+                                ],
+                                dim=0,
+                            ),
                         ],
                         dim=0,
                     ),
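The new slicing above recombines the packed Wqkv weights when ranks hold different numbers of heads. A standalone toy sketch of the same idea (not the library code), assuming each rank's Wqkv rows are packed as [query heads | key heads | value heads] and that lower ranks hold the extra heads; the sizes are made up for illustration:

import torch
from einops import rearrange

headdim, hidden = 4, 8
# e.g. 5 query heads and 3 kv heads over 2 ranks -> [3, 2] and [2, 1] heads per rank.
n_head_each_rank, n_head_kv_each_rank = [3, 2], [2, 1]

# Fake per-rank Wqkv shards: (q_heads + 2 * kv_heads) * headdim rows each.
shards = [
    torch.randn((nh + 2 * nkv) * headdim, hidden)
    for nh, nkv in zip(n_head_each_rank, n_head_kv_each_rank)
]
xs = [rearrange(s, "(nheads d) h -> nheads d h", d=headdim) for s in shards]

# Gather the q, k, v head blocks from every rank, respecting each rank's own head counts.
q = torch.cat([x[:nh] for x, nh in zip(xs, n_head_each_rank)], dim=0)
k = torch.cat(
    [x[nh : nh + nkv] for x, nh, nkv in zip(xs, n_head_each_rank, n_head_kv_each_rank)], dim=0
)
v = torch.cat(
    [x[nh + nkv :] for x, nh, nkv in zip(xs, n_head_each_rank, n_head_kv_each_rank)], dim=0
)

combined_Wqkv = rearrange(torch.cat([q, k, v], dim=0), "nheads d h -> (nheads d) h")
assert combined_Wqkv.shape == ((5 + 2 * 3) * headdim, hidden)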