Merge pull request #6 from huggingface/pr1

various fixes
Ferdinand Mom 2024-11-04 15:49:01 +01:00 committed by GitHub
commit cce11da2cb
7 changed files with 49 additions and 25 deletions

View File

@@ -15,7 +15,7 @@ def create_single_config(
tp: int,
cp: int,
pp: int,
dp: int,
pp_engine: str,
model_name: str,
num_hidden_layers: Optional[int],
num_attention_heads: Optional[int],
@@ -24,7 +24,8 @@ def create_single_config(
mbs: int,
seq_len: int,
exp_name: str,
use_wandb: bool = False
use_wandb: bool = False,
use_fused_adam: bool = False
):
run_path = os.path.join(out_dir, exp_name)
@@ -44,12 +45,13 @@ def create_single_config(
config_content["model"]["num_hidden_layers"] = tmp_model_config.num_hidden_layers if num_hidden_layers is None else num_hidden_layers
config_content["model"]["num_attention_heads"] = tmp_model_config.num_attention_heads if num_attention_heads is None else num_attention_heads
config_content["model"]["num_key_value_heads"] = tmp_model_config.num_key_value_heads if num_key_value_heads is None else num_key_value_heads
config_content["model"]["use_fused_adam"] = use_fused_adam
del tmp_model_config
config_content['distributed']['tp_size'] = tp
config_content['distributed']['cp_size'] = cp
config_content['distributed']['pp_size'] = pp
config_content['distributed']['dp_size'] = dp
config_content['distributed']['pp_engine'] = pp_engine
config_content['logging']['use_wandb'] = use_wandb
config_content['logging']['run_name'] = exp_name
@@ -75,7 +77,7 @@ if __name__ == "__main__":
parser.add_argument("--tp", type=int, help="number of tensor parallelism", default=1)
parser.add_argument("--cp", type=int, help="number of context parallelism", default=1)
parser.add_argument("--pp", type=int, help="number of pipeline parallelism", default=1)
parser.add_argument("--dp", type=int, help="number of data parallelism", default=1)
parser.add_argument("--pp_engine", type=str, help="pipeline parallel engine", default="afab")
parser.add_argument("--model_name", type=str, help="Model name to create configs for", default="HuggingFaceTB/SmolLM-360M-Instruct")
parser.add_argument("--num_hidden_layers", type=int, help="Number of hidden layers", default=None)
parser.add_argument("--num_attention_heads", type=int, help="Number of attention heads", default=None)
@@ -85,6 +87,7 @@ if __name__ == "__main__":
parser.add_argument("--seq_len", type=int, help="Sequence length", default=1024)
parser.add_argument("--exp_name", type=str, help="Experiment name", default="dummy_exp")
parser.add_argument("--use_wandb", action="store_true", help="Use wandb for logging")
parser.add_argument("--use_fused_adam", action="store_true", help="Use fused adam")
args=parser.parse_args()
@@ -94,6 +97,7 @@ if __name__ == "__main__":
cp=args.cp,
dp=args.dp,
pp=args.pp,
pp_engine=args.pp_engine,
model_name=args.model_name,
num_hidden_layers=args.num_hidden_layers,
num_attention_heads=args.num_attention_heads,
@@ -103,4 +107,5 @@ if __name__ == "__main__":
seq_len=args.seq_len,
exp_name=args.exp_name,
use_wandb=args.use_wandb,
use_fused_adam=args.use_fused_adam
)
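
The two options added here, pp_engine and use_fused_adam, flow from the CLI flags into the generated JSON config. A rough sketch of where they land, with illustrative values (key names are taken from this diff; the surrounding config is abbreviated):

import json

# Abbreviated, hypothetical view of config_content after create_single_config runs.
config_content = {
    "distributed": {"tp_size": 1, "cp_size": 1, "pp_size": 2, "dp_size": 2,
                    "pp_engine": "afab"},    # from --pp_engine ("afab" or "1f1b")
    "model": {"use_fused_adam": True},       # from --use_fused_adam
    "logging": {"use_wandb": False, "run_name": "dummy_exp"},
}
print(json.dumps(config_content, indent=2))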

View File

@@ -1,7 +1,7 @@
torch==2.1.2
torch==2.1.0
triton==2.1.0
numpy==1.26.4
datasets==2.19.1
transformers==4.41.1
debugpy-run
flash-attn==2.5.0
wandb

View File

@@ -58,9 +58,9 @@ class DataParallel(nn.Module):
# Expand so we get access to grad_fn.
param_tmp = param.expand_as(param)
# Get the gradient accumulator function.
grad_acc = param_tmp.grad_fn.next_functions[0][0]
grad_acc.register_hook(self._make_param_hook(param, self.bucket_manager))
self.grad_accs.append(grad_acc)
grad_acc_fn = param_tmp.grad_fn.next_functions[0][0]
grad_acc_fn.register_hook(self._make_param_hook(param, self.bucket_manager))
self.grad_accs.append(grad_acc_fn)
def _make_param_hook(self, param: torch.nn.Parameter,bucket_manager: BucketManager):
"""

View File

@@ -146,11 +146,12 @@ class Scheduler:
print(f"{'-'*10}-|-{'-'*6}")
print(f"{'Total':<10} | {total:<6}")
def submit_jobs(inp_dir, qos, nb_slurm_array, only: str = None):
def submit_jobs(inp_dir, qos, hf_token, nb_slurm_array, only: str = None):
scheduler = Scheduler(inp_dir, qos)
#TODO: batch into job arrays
env_vars = os.environ.copy()
env_vars["HUGGINGFACE_TOKEN"] = hf_token
total_jobs = len(scheduler.job_lists)
if only == "fail":
@@ -212,7 +213,8 @@ if __name__ == "__main__":
parser.add_argument('--qos', type=str, help='QOS of the jobs')
parser.add_argument('--nb_slurm_array', type=int, default=0, help='Number of slurm arrays')
parser.add_argument('--only', type=str, default=None, help='Filter the jobs to submit')
parser.add_argument('--hf_token', type=str, required=True, help='Huggingface token')
args = parser.parse_args()
submit_jobs(args.inp_dir, args.qos, args.nb_slurm_array, only=args.only)
submit_jobs(args.inp_dir, args.qos, args.hf_token, args.nb_slurm_array, only=args.only)
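
The token is handed to the submitted jobs through the environment rather than the command line; the slurm template consumes it later via huggingface-cli login --token $HUGGINGFACE_TOKEN. A hedged sketch of the pattern, assuming the scheduler eventually shells out to sbatch (the actual submission code is outside this hunk; the script name is a placeholder):

import os
import subprocess

env_vars = os.environ.copy()
env_vars["HUGGINGFACE_TOKEN"] = "hf_xxx"  # placeholder token
# sbatch exports the caller's environment by default, so the job script sees $HUGGINGFACE_TOKEN.
subprocess.run(["sbatch", "job.slurm"], env=env_vars, check=True)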

View File

@@ -4,8 +4,7 @@
"cp_size": 1,
"pp_size": 1,
"dp_size": 1,
"master_addr": "localhost",
"master_port": 29500,
"pp_engine": "afab",
"backend": "nccl",
"use_cpu": false
},
@@ -15,7 +14,8 @@
"num_attention_heads": 16,
"num_key_value_heads": 4,
"dtype": "bfloat16",
"use_flash_attention": true
"use_flash_attention": true,
"use_fused_adam": true
},
"training": {
"seed": 42,
@@ -29,7 +29,7 @@
},
"dataset": {
"name": "roneneldan/TinyStories",
"num_workers": 4,
"num_workers": 1,
"num_proc": 4
},
"checkpoint": {

View File

@@ -51,13 +51,16 @@ export HF_HOME=/fsx/$USER/.cache/huggingface
export WANDB_DIR=/fsx/$USER/.cache/wandb
export CUBLAS_WORKSPACE_CONFIG=":4096:8"
export CUDA_DEVICE_MAX_CONNECTIONS="1"
export FI_PROVIDER="efa"
module load cuda/12.1
GIT_REPO="/fsx/ferdinandmom/ferdinand-hf/picotron/"
CMD="$GIT_REPO/train.py --config {{ config }}"
LAUNCHER="torchrun --nproc_per_node={{ n_proc_per_node }} --nnode={{ nodes }} --node_rank=$SLURM_NODEID --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT"
huggingface-cli login --token $HUGGINGFACE_TOKEN
LAUNCHER="torchrun --nproc_per_node={{ n_proc_per_node }} --nnode={{ nodes }} --node_rank=$SLURM_NODEID --rdzv_endpoint ${MASTER_ADDR}:${MASTER_PORT} --rdzv_backend c10d --max_restarts 0 --tee 3"
# Checkout the bench_cluster branch
cd $GIT_REPO

View File

@@ -8,8 +8,9 @@ CUDA_DEVICE_MAX_CONNECTIONS=1 debugpy-run -p 5678 -m torch.distributed.run -- --
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=4 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 --max_restarts=0 --tee=3 train.py
#VERBOSE=0 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --pp_size 2 --dp_size 2
"""
import os
import inspect
import datetime
import json
import time
import argparse
@@ -88,16 +89,15 @@ if __name__ == "__main__":
USE_WANDB = config["logging"]["use_wandb"]
TP_SIZE = config["distributed"]["tp_size"]
PP_SIZE = config["distributed"]["pp_size"]
DP_SIZE = config["distributed"]["dp_size"]
CP_SIZE = config["distributed"]["cp_size"]
PP_ENGINE = config["distributed"]["pp_engine"]
LOAD_PATH = config["checkpoint"]["load_path"]
CHECKPOINT_DIR = config["checkpoint"]["save_dir"]
CHECKPOINT_FREQ = config["checkpoint"]["save_frequency"]
local_rank = int(os.environ["LOCAL_RANK"])
global_rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
host = os.environ["MASTER_ADDR"]
port = int(os.environ["MASTER_PORT"])
backend = "gloo" if config["distributed"]["use_cpu"] else "nccl"
assert SEQ_LEN % CP_SIZE == 0, "SEQ_LEN must be divisible by cp_size for Context Parallelism"
@@ -109,10 +109,12 @@ if __name__ == "__main__":
else:
device = torch.device("cpu")
dist.init_process_group(rank=local_rank, world_size=world_size, backend=backend, init_method=f"tcp://{host}:{port}")
dist.init_process_group(rank=global_rank, world_size=world_size, backend=backend, init_method=f"env://", timeout=datetime.timedelta(minutes=3))
setup_process_group_manager(tp_size=TP_SIZE, cp_size=CP_SIZE, pp_size=PP_SIZE, dp_size=DP_SIZE)
is_wandb_rank = pgm.process_group_manager.tp_rank == 0 and pgm.process_group_manager.dp_rank == 0 and pgm.process_group_manager.cp_rank == 0 and pgm.process_group_manager.pp_is_last_stage
dist.barrier()
set_all_seed(SEED)
model_config = AutoConfig.from_pretrained(MODEL_NAME)
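
With init_method="env://", torchrun's rendezvous supplies MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE, so the script no longer reads the master host and port itself; the process group is also created with the global rank and a 3-minute timeout. A minimal standalone sketch of the same initialization (runnable under torchrun; uses only the standard torch.distributed API):

import datetime
import os
import torch.distributed as dist

global_rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
# "env://" picks up MASTER_ADDR / MASTER_PORT set by torchrun's rendezvous.
dist.init_process_group(
    backend="gloo",  # "nccl" on GPU nodes, matching the config's backend field
    rank=global_rank,
    world_size=world_size,
    init_method="env://",
    timeout=datetime.timedelta(minutes=3),
)
dist.barrier()
dist.destroy_process_group()
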
@@ -183,8 +185,15 @@ if __name__ == "__main__":
print("model to device time:", time.time()-start_time, is_print_rank=is_wandb_rank)
tensor_shapes = (data_loader.micro_batch_size, data_loader.seq_length_per_gpu, model_config.hidden_size)
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
extra_args = dict()
if config["model"]["use_fused_adam"]:
fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
use_fused = fused_available and device == 'cuda'
extra_args = dict(fused=True) if use_fused else dict()
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, **extra_args)
trained_tokens, step = 0, 0
if LOAD_PATH:
step, trained_tokens = load_checkpoint(model, optimizer, LOAD_PATH)
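
The fused AdamW path is taken only when the installed torch exposes a fused parameter and the run is on CUDA. One caveat worth flagging: device here is a torch.device object, so comparing it against the string 'cuda' may not evaluate as intended; checking device.type is the unambiguous form. A hedged, standalone sketch of that variant (the model and learning rate are illustrative):

import inspect
import torch
from torch.optim import AdamW

model = torch.nn.Linear(8, 8)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

fused_available = "fused" in inspect.signature(AdamW).parameters
use_fused = fused_available and device.type == "cuda"  # device.type compares cleanly against "cuda"
extra_args = dict(fused=True) if use_fused else dict()
optimizer = AdamW(model.parameters(), lr=3e-4, **extra_args)  # lr is illustrative
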
@@ -204,7 +213,12 @@ if __name__ == "__main__":
optimizer.zero_grad()
if pgm.process_group_manager.pp_world_size > 1:
loss = train_step_pipeline_afab(model, data_loader, tensor_shapes, device, dtype)
if PP_ENGINE == "afab":
loss = train_step_pipeline_afab(model, data_loader, tensor_shapes, device, dtype)
elif PP_ENGINE == "1f1b":
loss = train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device, dtype)
else:
raise ValueError(f"Invalid pipeline parallel engine: {PP_ENGINE}")
else:
loss = train_step(model, data_loader, device)
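
For reference, "afab" (all-forward-all-backward) runs every microbatch's forward pass before any backward pass, while "1f1b" interleaves one backward per forward after a short warm-up, which lowers peak activation memory without changing the pipeline bubble. A purely illustrative sketch of the two orderings as seen by the last pipeline stage (not picotron's scheduler):

num_microbatches = 4

# AFAB: all forwards first, then all backwards.
afab = [("fwd", m) for m in range(num_microbatches)] + \
       [("bwd", m) for m in range(num_microbatches)]

# 1F1B: on the last stage the steady state is a strict fwd/bwd alternation;
# earlier stages prepend a warm-up of extra forwards.
one_f_one_b = [op for m in range(num_microbatches) for op in (("fwd", m), ("bwd", m))]

print("afab:", afab)
print("1f1b:", one_f_one_b)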