from enum import Enum
import os
from jinja2 import Template
import subprocess
import json
from typing import List, Optional
class Status(Enum):
    # INIT -> PENDING -> [RUNNING | FAIL | TIMEOUT | OOM] -> COMPLETED
    INIT = "init"            # Job is created
    PENDING = "pending"      # Job is waiting for resources
    RUNNING = "running"      # Job is running
    FAIL = "fail"            # Job failed
    OOM = "oom"              # Job failed due to out of memory (expected behavior)
    TIMEOUT = "timeout"      # Job failed due to timeout
    COMPLETED = "completed"  # Job is completed

class Job:
    def __init__(self, root_path: str, qos: str) -> None:
        self.root_path = root_path
        self.name = os.path.basename(root_path)
        self.config = os.path.join(root_path, "config.json")
        self.qos = qos

        # Check if the status.txt file exists
        status_file_path = os.path.join(self.root_path, "status.txt")
        if not os.path.exists(status_file_path):
            # Create the status.txt file with INIT status
            with open(status_file_path, 'w') as f:
                f.write(Status.INIT.value)

        self.status = self.get_status()

    def get_status(self) -> Status:
        """
        Read the status of the job from `status.txt` and return it
        """
        is_existing = lambda value_to_check: any(value.value == value_to_check for value in Status.__members__.values())
        status_file_path = os.path.join(self.root_path, "status.txt")
        with open(status_file_path, 'r') as f:
            status = f.read()
        if not is_existing(status):
            raise ValueError("Invalid status")
        return Status(status)

    def set_status(self, status: Status) -> Status:
        """
        Update the status of the job in `status.txt` and return the new status
        """
        status_file_path = os.path.join(self.root_path, "status.txt")
        with open(status_file_path, 'w') as f:
            f.write(status.value)
        return status

class Scheduler:
    def __init__(self, inp_dir: str, qos: str) -> None:
        jobs_directory_paths = [os.path.abspath(root) for root, dirs, _ in os.walk(inp_dir) if not dirs]
        jobs_directory_paths = [job_path.replace("/profiler", "") if "profiler" in job_path else job_path for job_path in jobs_directory_paths]
        self.job_lists = [Job(job_path, qos) for job_path in jobs_directory_paths]

    def keep_only_jobs(self, status: Status) -> List[Job]:
        return [job for job in self.job_lists if job.status == status]

    def filter_out_jobs(self, status: Status) -> List[Job]:
        return [job for job in self.job_lists if job.status != status]

    def create_slurm_script(self, job: Job):
        # Render the Slurm submission script for this job from the Jinja template
        # Load the job's config.json
        with open(job.config, 'r') as file:
            config = json.load(file)

        max_gpu_per_node = 8

        # Pick the right number of nodes and n_proc_per_node
        world_size = config["distributed"]["tp_size"] * config["distributed"]["cp_size"] * config["distributed"]["pp_size"] * config["distributed"]["dp_size"]

        assert world_size <= max_gpu_per_node or world_size % max_gpu_per_node == 0
        nodes = max(1, world_size // max_gpu_per_node)
        n_proc_per_node = min(max_gpu_per_node, world_size // nodes)

        assert nodes * n_proc_per_node == world_size

        context_bench = {
            'nodes': nodes,
            'n_proc_per_node': n_proc_per_node,
            'root_path': job.root_path,
            "config": job.config,
            "qos": job.qos,
        }

        base_path = os.path.join(os.getcwd(), "template/base_job.slurm")
        with open(base_path, 'r') as file:
            base_job_file = file.read()

        base_job_template = Template(base_job_file)

        # Write the rendered script to a new file located at the job root_path
        output_file_path = os.path.join(job.root_path, "job.slurm")
        with open(output_file_path, 'w') as file:
            file.write(base_job_template.render(context_bench))

        print(f"Slurm script created at {output_file_path}")

    def launch_dependency(self, job_array: List[Job], env_vars):
        prev_job_id = None
        for job in job_array:
            if prev_job_id is None:
                result = subprocess.run(["sbatch", '--parsable', os.path.join(job.root_path, "job.slurm")], env=env_vars, capture_output=True, text=True)
            else:
                result = subprocess.run(["sbatch", '--parsable', '--dependency=afterany:' + prev_job_id, os.path.join(job.root_path, "job.slurm")], env=env_vars, capture_output=True, text=True)
            job.set_status(Status.PENDING)
            prev_job_id = result.stdout.strip()

    def check_status(self):
        # Find all status files from self.job_lists
        status_files = [os.path.join(job.root_path, "status.txt") for job in self.job_lists]

        status_counts = {
            "init": 0,
            "pending": 0,
            "running": 0,
            "fail": 0,
            "oom": 0,
            "timeout": 0,
            "completed": 0
        }

        for status_file in status_files:
            with open(status_file, 'r') as f:
                status = f.read().strip()
                if status in status_counts:
                    status_counts[status] += 1
                else:
                    raise ValueError(f"Invalid status: {status}")

        total = sum(status_counts.values())

        # Print the status counts in a formatted table
        print(f"{'Status':<10} | {'Count':<6}")
        print(f"{'-' * 10}-|-{'-' * 6}")
        for status, count in status_counts.items():
            print(f"{status.capitalize():<10} | {count:<6}")
        print(f"{'-' * 10}-|-{'-' * 6}")
        print(f"{'Total':<10} | {total:<6}")

def submit_jobs(inp_dir: str, qos: str, hf_token: str, nb_slurm_array: int, only: Optional[str] = None):
    scheduler = Scheduler(inp_dir, qos)
    #TODO: batch into job arrays
    env_vars = os.environ.copy()
    env_vars["HUGGINGFACE_TOKEN"] = hf_token

    total_jobs = len(scheduler.job_lists)
    if only == "fail":
        scheduler.job_lists = scheduler.keep_only_jobs(Status.FAIL)
    elif only == "pending":
        scheduler.job_lists = scheduler.keep_only_jobs(Status.PENDING)
    elif only == "timeout":
        scheduler.job_lists = scheduler.keep_only_jobs(Status.TIMEOUT)
    elif only == "running":
        scheduler.job_lists = scheduler.keep_only_jobs(Status.RUNNING)

    if only is not None:
        filtered_jobs = len(scheduler.job_lists)
        if filtered_jobs == 0:
            print(f"No '{only}' jobs to resubmit")
            return
        print(f"Only {filtered_jobs}/{total_jobs} jobs with status '{only}' will be resubmitted")

    scheduler.job_lists = scheduler.filter_out_jobs(Status.COMPLETED)

    if nb_slurm_array > 0:
        # Use job dependencies
        # Distribute the jobs into the arrays
        base_jobs_per_array = len(scheduler.job_lists) // nb_slurm_array
        extra_jobs = len(scheduler.job_lists) % nb_slurm_array
        distribution = [base_jobs_per_array] * nb_slurm_array
        for i in range(extra_jobs):
            distribution[i] += 1

        start = 0
        for i, nb_jobs in enumerate(distribution):
            end = start + nb_jobs
            job_array = scheduler.job_lists[start:end]
            print(f"Launching job dependency array {i + 1} with {nb_jobs} jobs")
            for job in job_array:
                scheduler.create_slurm_script(job)
            scheduler.launch_dependency(job_array, env_vars)
            start = end
    else:
        # Don't use job dependencies
        for job in scheduler.job_lists:
            scheduler.create_slurm_script(job)
            print(os.path.join(job.root_path, "job.slurm"))
            subprocess.run(["sbatch", os.path.join(job.root_path, "job.slurm")], env=env_vars)
            job.set_status(Status.PENDING)

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description='Submit jobs to the cluster')
    parser.add_argument('--inp_dir', type=str, help='Input directory containing the jobs')
    parser.add_argument('--qos', type=str, help='QOS of the jobs')
    parser.add_argument('--nb_slurm_array', type=int, default=0, help='Number of job dependency arrays to split the jobs into (0 = submit jobs independently)')
    parser.add_argument('--only', type=str, default=None, help='Only resubmit jobs with this status (fail, pending, timeout or running)')
    parser.add_argument('--hf_token', type=str, required=True, help='Hugging Face token')
    args = parser.parse_args()

    #TODO: add more options like "python slurm.py submit_jobs --...." or "python slurm.py update_jobs --...." or "python slurm.py cancel_jobs --...." or "python slurm.py check_status --...."
    submit_jobs(args.inp_dir, args.qos, args.hf_token, args.nb_slurm_array, only=args.only)
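
    # Example invocation, for illustration only: the input directory, QOS name, token
    # variable and array count below are assumptions, not defaults of this script.
    #   python slurm.py --inp_dir results/bench --qos normal --hf_token $HF_TOKEN --nb_slurm_array 4 --only fail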