add slurm support

ferdinand.mom 2024-10-30 14:25:18 +00:00
parent fdf2df8344
commit 2d198659e2
2 changed files with 7 additions and 13 deletions

View File

@@ -2,7 +2,7 @@ from enum import Enum
import os
from jinja2 import Template
import subprocess
-import yaml
+import json
from typing import List
class Status(Enum):
@@ -19,7 +19,7 @@ class Job:
    def __init__(self, root_path: str, qos: str) -> None:
        self.root_path = root_path
        self.name = os.path.basename(root_path)
-        self.config = os.path.join(root_path, "config.yaml")
+        self.config = os.path.join(root_path, "config.json")
        self.qos = qos

        # Check if the status.txt file exists
@@ -69,36 +69,31 @@ class Scheduler:
        # Submit job to the cluster (edit jinja)
        # load yaml config.yaml
        with open(job.config, 'r') as file:
-            config = yaml.load(file, Loader=yaml.FullLoader)
+            config = json.load(file)
        if cluster == "hf":
            max_nodes = 8
        elif cluster == "swiss-ai":
            max_nodes = 4
        else:
            raise ValueError("Invalid cluster")
        # Pick the right number of nodes and n_proc_per_node
-        world_size = config['parallelism']['pp'] * config['parallelism']['dp'] * config['parallelism']['tp']
+        world_size = config["distributed"]["tp_size"] * config["distributed"]["cp_size"] * config["distributed"]["pp_size"] * config["distributed"]["dp_size"]
        assert world_size <= max_nodes or world_size % max_nodes == 0
        nodes = max(1, world_size // max_nodes)
        n_proc_per_node = min(8, world_size // nodes)
        assert nodes * n_proc_per_node == world_size
        target_path_hf_hub = os.path.join(os.path.basename(os.path.dirname(os.path.dirname(job.root_path))), os.path.basename(os.path.dirname(job.root_path)), os.path.basename(job.root_path))
        context_bench = {
            'nodes': nodes,
            'n_proc_per_node': n_proc_per_node,
            'root_path': job.root_path,
            'target_path_hf_hub': target_path_hf_hub,
            "config": job.config,
            "qos": job.qos,
        }
        #TODO: don't hardcode the base_bench.slurm path. Should be #HOME/bench_cluster/template/base_bench.slurm
        if cluster == "hf":
-            base_path = "/fsx/ferdinandmom/ferdinand-hf/nanotron/debug/template/base_bench.slurm"
+            base_path = "/fsx/ferdinandmom/ferdinand-hf/picotron/bench/template/base_bench.slurm"
        else:
            raise ValueError("Invalid cluster")

View File

@ -53,12 +53,11 @@ export CUBLAS_WORKSPACE_CONFIG=":4096:8"
export CUDA_DEVICE_MAX_CONNECTIONS="1"
module load cuda/12.1
huggingface-cli login --token $HUGGINGFACE_TOKEN
GIT_REPO="/fsx/ferdinandmom/ferdinand-hf/picotron/"
CMD="$GIT_REPO/run_train.py --config-path {{ config }} --logs-path {{ root_path }} --run output --slurm --nodes {{ nodes }}"
CMD="$GIT_REPO/train.py --config {{ config }}"
LAUNCHER="python"
LAUNCHER="torchrun --nproc_per_node={{ n_proc_per_node }} --nnode={{ nodes }} --node_rank=$SLURM_NODEID --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT"
# Checkout the bench_cluster branch
cd $GIT_REPO
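
To make the template change concrete, here is a small sketch (not from the repository) of how the two rewritten lines render once the scheduler passes its context through Jinja; the config path and counts are hypothetical.

from jinja2 import Template

# Trimmed excerpt of the two lines changed in base_bench.slurm.
template = Template(
    'CMD="$GIT_REPO/train.py --config {{ config }}"\n'
    'LAUNCHER="torchrun --nproc_per_node={{ n_proc_per_node }} --nnode={{ nodes }} '
    '--node_rank=$SLURM_NODEID --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT"'
)

# Hypothetical values for one benchmark job (2 nodes x 8 GPUs).
print(template.render(config="/path/to/job/config.json", nodes=2, n_proc_per_node=8))
# CMD="$GIT_REPO/train.py --config /path/to/job/config.json"
# LAUNCHER="torchrun --nproc_per_node=8 --nnode=2 --node_rank=$SLURM_NODEID --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT"

The $SLURM_* and $MASTER_* variables are left for bash to expand at job runtime, which is what gives each node its rank and rendezvous address.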