* q
* add comment.
This commit is contained in:
parent d431f16751
commit 25d6b1dbcb
@@ -20,16 +20,12 @@ from flash_attn.models.opt import remap_state_dict_hf_opt
 from flash_attn.modules.block import Block, ParallelBlock
 from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
 from flash_attn.modules.mha import MHA, ParallelMHA
-from flash_attn.modules.mlp import (
-    FusedMLP,
-    GatedMlp,
-    Mlp,
-    ParallelFusedMLP,
-    ParallelGatedMlp,
-    ParallelMLP,
-)
+from flash_attn.modules.mlp import (FusedMLP, GatedMlp, Mlp, ParallelFusedMLP,
+                                    ParallelGatedMlp, ParallelMLP)
 from flash_attn.ops.activations import sqrelu_fwd
-from flash_attn.utils.distributed import all_gather_raw, get_dim_for_local_rank, sync_shared_params
+from flash_attn.utils.distributed import (all_gather_raw,
+                                          get_dim_for_local_rank,
+                                          sync_shared_params)
 from flash_attn.utils.generation import GenerationMixin
 from flash_attn.utils.pretrained import state_dict_from_pretrained
 
@@ -44,7 +40,8 @@ except ImportError:
     dropout_add_layer_norm = None
 
 try:
-    from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual
+    from flash_attn.ops.layer_norm import \
+        dropout_add_layer_norm_parallel_residual
 except ImportError:
     dropout_add_layer_norm_parallel_residual = None
 
@@ -673,6 +670,8 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
 def shard_state_dict_tp(state_dict, config, world_size, rank):
     """Convert the state_dict of a standard GPT model to the state_dict of a GPT model
     with tensor parallel.
+
+    This function modifies state_dict in place.
     """
     pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
     vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
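For context on the last line above: vocab_size is rounded up to the nearest multiple of pad_vocab_size_multiple. A quick check with illustrative numbers (GPT-2's 50257-token vocabulary and a multiple of 8, neither taken from this diff):

import math

vocab_size = 50257              # illustrative; in practice this comes from the model config
pad_vocab_size_multiple = 8     # illustrative
padded = math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
print(padded)  # 50264 -- the smallest multiple of 8 that is >= 50257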
@@ -784,11 +783,14 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
     return state_dict
 
 
-def combine_state_dicts_tp(state_dicts, config):
-    """Convert the state_dict of a GPT model with tensor parallel to the state_dict of a
-    standard GPT model.
+def combine_state_dicts_tp(state_dicts: list[dict[str, torch.Tensor]], config: GPT2Config):
+    """Convert the list of sharded state_dict of a GPT model with tensor parallel to
+    the state_dict of a standard GPT model.
 
     This function is meant to be the "reverse" of shard_state_dict_tp.
+
+    Precondition:
+        - state_dicts should be ordered in the same way as the shards were created.
     """
     world_size = len(state_dicts)
     keys = state_dicts[0].keys()
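A minimal round-trip sketch of how a caller might use the two functions documented above. The deepcopy and the model/config objects are assumptions; only the signatures, the in-place note, and the ordering precondition come from the diff:

import copy

world_size = 2
full_sd = model.state_dict()  # assumed: some GPTLMHeadModel and its GPT2Config `config`
# shard_state_dict_tp modifies its input in place, so shard from fresh copies;
# building the list in rank order satisfies the precondition above.
shards = [
    shard_state_dict_tp(copy.deepcopy(full_sd), config, world_size, rank)
    for rank in range(world_size)
]
recombined = combine_state_dicts_tp(shards, config)
assert recombined.keys() == full_sd.keys()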
@@ -812,9 +814,6 @@ def combine_state_dicts_tp(state_dicts, config):
     def combine_qkv_headdim(state_dicts, state_dict, key):
         n_head = config.n_head
         n_head_kv = getattr(config, "n_head_kv", n_head)
-        assert n_head % world_size == 0 and n_head_kv % world_size == 0
-        n_head_per_rank = n_head // world_size
-        n_head_kv_per_rank = n_head_kv // world_size
         if key in state_dict:
             if n_head_kv == n_head:
                 xs = [
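The assert removed above required the query and KV head counts to divide evenly across tensor-parallel ranks; the next hunk replaces the uniform n_head_per_rank / n_head_kv_per_rank values with per-rank counts from get_dim_for_local_rank. A rough sketch of the idea, assuming the helper splits a dimension as evenly as possible with the remainder going to the lowest ranks (an assumption -- the actual policy lives in flash_attn.utils.distributed):

def dim_for_local_rank_sketch(dim: int, world_size: int, local_rank: int) -> int:
    # Hypothetical stand-in for get_dim_for_local_rank, for illustration only.
    base, rem = divmod(dim, world_size)
    return base + (1 if local_rank < rem else 0)

# 10 query heads over 4 ranks -> [3, 3, 2, 2]; the removed assert would have
# rejected this split because 10 % 4 != 0.
print([dim_for_local_rank_sketch(10, 4, r) for r in range(4)])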
@@ -830,18 +829,37 @@ def combine_state_dicts_tp(state_dicts, config):
                     )
                     for s in state_dicts
                 ]
+                n_head_each_rank = [
+                    get_dim_for_local_rank(n_head, world_size, local_rank)
+                    for local_rank in range(world_size)
+                ]
+                n_head_kv_each_rank = [
+                    get_dim_for_local_rank(n_head_kv, world_size, local_rank)
+                    for local_rank in range(world_size)
+                ]
                 state_dict[key] = rearrange(
                     torch.cat(
                         [
-                            torch.cat([x[:n_head_per_rank] for x in xs], dim=0),
+                            torch.cat(
+                                [x[: n_head_each_rank[rank]] for rank, x in enumerate(xs)], dim=0
+                            ),
                             torch.cat(
                                 [
-                                    x[n_head_per_rank : n_head_per_rank + n_head_kv_per_rank]
-                                    for x in xs
+                                    x[
+                                        n_head_each_rank[rank] : n_head_each_rank[rank]
+                                        + n_head_kv_each_rank[rank]
+                                    ]
+                                    for rank, x in enumerate(xs)
                                 ],
                                 dim=0,
                             ),
+                            torch.cat(
+                                [
+                                    x[n_head_each_rank[rank] + n_head_kv_each_rank[rank] :]
+                                    for rank, x in enumerate(xs)
+                                ],
+                                dim=0,
+                            ),
-                            torch.cat([x[-n_head_kv_per_rank:] for x in xs], dim=0),
                         ],
                         dim=0,
                     ),
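To make the three torch.cat slices above concrete, here is a toy reconstruction with fake shards (head counts and values are illustrative, and head_dim is collapsed to one row per head). Each rank's packed tensor is laid out as [q_local | k_local | v_local]; the slices regroup all query heads first, then all key heads, then all value heads:

import torch

world_size = 2
n_head_each_rank = [2, 1]     # 3 query heads split 2 + 1 (illustrative)
n_head_kv_each_rank = [1, 1]  # 2 kv heads split 1 + 1 (illustrative)

# Fake per-rank shards: values 0.. mark q heads, 100.. k heads, 200.. v heads.
xs = []
q_off = kv_off = 0
for rank in range(world_size):
    q = torch.arange(q_off, q_off + n_head_each_rank[rank]).float()
    k = 100 + torch.arange(kv_off, kv_off + n_head_kv_each_rank[rank]).float()
    v = 200 + torch.arange(kv_off, kv_off + n_head_kv_each_rank[rank]).float()
    xs.append(torch.cat([q, k, v]))  # packed [q_local | k_local | v_local]
    q_off += n_head_each_rank[rank]
    kv_off += n_head_kv_each_rank[rank]

combined = torch.cat(
    [
        # all query heads, in rank order
        torch.cat([x[: n_head_each_rank[rank]] for rank, x in enumerate(xs)], dim=0),
        # all key heads
        torch.cat(
            [
                x[n_head_each_rank[rank] : n_head_each_rank[rank] + n_head_kv_each_rank[rank]]
                for rank, x in enumerate(xs)
            ],
            dim=0,
        ),
        # all value heads
        torch.cat(
            [x[n_head_each_rank[rank] + n_head_kv_each_rank[rank] :] for rank, x in enumerate(xs)],
            dim=0,
        ),
    ],
    dim=0,
)
print(combined)  # tensor([0., 1., 2., 100., 101., 200., 201.])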