zzhhjjj 2024-12-19 07:05:16 +00:00
parent d855aead9e
commit 7ef2344cd4
2 changed files with 5 additions and 4 deletions


@@ -52,7 +52,7 @@ def init_model_with_materialized_weights(model, model_config, save_dir):
     initialization_manager = InitializationManager(model, model_config)
     layer_names = initialization_manager.get_layer_names_in_sft_format()
-    print(f"Rank {pgm.process_group_manager.global_rank} responsible for {len(layer_names)} layers")
+    # print(f"Rank {pgm.process_group_manager.global_rank} responsible for {len(layer_names)} layers")
     if len(layer_names) == 0:
         raise Exception("Some ranks has no layers. There are too many ranks and not enough layers to distribute.")


@@ -143,9 +143,10 @@ if __name__ == "__main__":
     if pgm.process_group_manager.global_rank == 0:
         print(f"rank {pgm.process_group_manager.global_rank}: Creating model config")
         model_config = AutoConfig.from_pretrained(config["model"]["name"])
-        model_config.num_hidden_layers = config["model"]["num_hidden_layers"]
-        model_config.num_attention_heads = config["model"]["num_attention_heads"]
-        model_config.num_key_value_heads = config["model"]["num_key_value_heads"]
+        # twist the model structure if specified in the config file
+        model_config.num_hidden_layers = model_config.num_hidden_layers if "num_hidden_layers" not in config["model"] else config["model"]["num_hidden_layers"]
+        model_config.num_attention_heads = model_config.num_attention_heads if "num_attention_heads" not in config["model"] else config["model"]["num_attention_heads"]
+        model_config.num_key_value_heads = model_config.num_key_value_heads if "num_key_value_heads" not in config["model"] else config["model"]["num_key_value_heads"]
         model_config.max_position_embeddings = config["training"]["seq_length"]
         objects = [model_config]
     else:
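Note: the added lines keep the pretrained config's values and only override them when the corresponding key appears in the training config. A minimal sketch of the same fallback pattern, assuming a plain dict config; the model name, dict layout, and chosen keys here are placeholders, not taken from this commit:

from transformers import AutoConfig

# Hypothetical training config; only the keys that are present override the defaults.
config = {"model": {"name": "HuggingFaceTB/SmolLM-135M", "num_hidden_layers": 4}}

model_config = AutoConfig.from_pretrained(config["model"]["name"])
for key in ("num_hidden_layers", "num_attention_heads", "num_key_value_heads"):
    # Fall back to the value already on model_config when the key is absent.
    setattr(model_config, key, config["model"].get(key, getattr(model_config, key)))

print(model_config.num_hidden_layers)    # 4 (overridden by the training config)
print(model_config.num_attention_heads)  # pretrained default, unchanged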