add slurm support
commit 2d198659e2 (parent fdf2df8344)
@@ -2,7 +2,7 @@ from enum import Enum
 import os
 from jinja2 import Template
 import subprocess
-import yaml
+import json
 from typing import List
 
 class Status(Enum):
@@ -19,7 +19,7 @@ class Job:
     def __init__(self, root_path: str, qos: str) -> None:
         self.root_path = root_path
         self.name = os.path.basename(root_path)
-        self.config = os.path.join(root_path, "config.yaml")
+        self.config = os.path.join(root_path, "config.json")
         self.qos = qos
 
         # Check if the status.txt file exists
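For reference, the Job now points at config.json, and the scheduler hunk below reads the parallelism degrees from a "distributed" block in that file. A minimal sketch of writing such a config (key names taken from this diff; the sizes are made-up examples, not values from the commit):

import json

# Keys match what the scheduler reads below; the degrees are illustrative.
config = {
    "distributed": {
        "tp_size": 2,   # tensor parallel degree
        "cp_size": 1,   # context parallel degree
        "pp_size": 2,   # pipeline parallel degree
        "dp_size": 4,   # data parallel degree
    }
}

with open("config.json", "w") as f:
    json.dump(config, f, indent=2)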
@@ -69,36 +69,31 @@ class Scheduler:
         # Submit job to the cluster (edit jinja)
-        # load yaml config.yaml
         with open(job.config, 'r') as file:
-            config = yaml.load(file, Loader=yaml.FullLoader)
+            config = json.load(file)
 
         if cluster == "hf":
             max_nodes = 8
         elif cluster == "swiss-ai":
             max_nodes = 4
         else:
             raise ValueError("Invalid cluster")
 
         # Pick the right number of nodes and n_proc_per_node
-        world_size = config['parallelism']['pp'] * config['parallelism']['dp'] * config['parallelism']['tp']
+        world_size = config["distributed"]["tp_size"] * config["distributed"]["cp_size"] * config["distributed"]["pp_size"] * config["distributed"]["dp_size"]
         assert world_size <= max_nodes or world_size % max_nodes == 0
         nodes = max(1, world_size // max_nodes)
         n_proc_per_node = min(8, world_size // nodes)
         assert nodes * n_proc_per_node == world_size
 
         target_path_hf_hub = os.path.join(os.path.basename(os.path.dirname(os.path.dirname(job.root_path))), os.path.basename(os.path.dirname(job.root_path)), os.path.basename(job.root_path))
 
         context_bench = {
             'nodes': nodes,
             'n_proc_per_node': n_proc_per_node,
             'root_path': job.root_path,
             'target_path_hf_hub': target_path_hf_hub,
             "config": job.config,
             "qos": job.qos,
         }
 
         #TODO: don't hardcode the base_bench.slurm path. Should be $HOME/bench_cluster/template/base_bench.slurm
         if cluster == "hf":
-            base_path = "/fsx/ferdinandmom/ferdinand-hf/nanotron/debug/template/base_bench.slurm"
+            base_path = "/fsx/ferdinandmom/ferdinand-hf/picotron/bench/template/base_bench.slurm"
         else:
             raise ValueError("Invalid cluster")
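To see the node-packing logic concretely, here is a standalone sketch of the arithmetic above, assuming the "hf" cluster (max_nodes = 8, which effectively acts as the per-node GPU budget since n_proc_per_node is capped at 8); the parallelism degrees are illustrative:

# Worked example of the scheduling arithmetic above (values illustrative).
max_nodes = 8                              # "hf" cluster
tp, cp, pp, dp = 2, 1, 2, 4                # example degrees from config.json
world_size = tp * cp * pp * dp             # 2 * 1 * 2 * 4 = 16 GPUs total

assert world_size <= max_nodes or world_size % max_nodes == 0  # 16 % 8 == 0
nodes = max(1, world_size // max_nodes)        # 16 // 8 = 2 nodes
n_proc_per_node = min(8, world_size // nodes)  # min(8, 16 // 2) = 8 procs/node
assert nodes * n_proc_per_node == world_size   # 2 * 8 == 16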
@@ -53,12 +53,11 @@ export CUBLAS_WORKSPACE_CONFIG=":4096:8"
 export CUDA_DEVICE_MAX_CONNECTIONS="1"
 
-module load cuda/12.1
 huggingface-cli login --token $HUGGINGFACE_TOKEN
 
 GIT_REPO="/fsx/ferdinandmom/ferdinand-hf/picotron/"
-CMD="$GIT_REPO/run_train.py --config-path {{ config }} --logs-path {{ root_path }} --run output --slurm --nodes {{ nodes }}"
+CMD="$GIT_REPO/train.py --config {{ config }}"
 
-LAUNCHER="python"
+LAUNCHER="torchrun --nproc_per_node={{ n_proc_per_node }} --nnode={{ nodes }} --node_rank=$SLURM_NODEID --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT"
 
 # Checkout the bench_cluster branch
 cd $GIT_REPO
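The scheduler fills this template with jinja2 (already imported in the first hunk) using the context_bench dict built above. The diff does not show the render call itself, so this is only a hedged sketch of that step; paths and values are illustrative:

from jinja2 import Template

# context_bench as built by the scheduler above; these values are made up.
context_bench = {
    "nodes": 2,
    "n_proc_per_node": 8,
    "root_path": "/tmp/jobs/my_job",
    "target_path_hf_hub": "bench/llama/my_job",
    "config": "/tmp/jobs/my_job/config.json",
    "qos": "high",
}

base_path = "template/base_bench.slurm"  # illustrative; chosen per cluster above
with open(base_path, "r") as f:
    rendered = Template(f.read()).render(context_bench)
# In the rendered script the launcher line becomes e.g.:
# LAUNCHER="torchrun --nproc_per_node=8 --nnode=2 --node_rank=$SLURM_NODEID ..."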