"""Training script for LLaMA model.

torchrun --nproc_per_node 1 --master_addr localhost --master_port 25500 train.py --use_wandb
torchrun --nproc_per_node 2 --master_addr localhost --master_port 25500 train.py --dp_size 2 --use_wandb
torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --tp_size 2 --pp_size 2 --use_wandb
torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --tp_size 2 --pp_size 2 --load_path ckpt/150
torchrun --nproc_per_node 8 --master_addr localhost --master_port 25500 train.py --tp_size 2 --dp_size 2 --pp_size 2 --use_wandb
CUDA_DEVICE_MAX_CONNECTIONS=1 debugpy-run -p 5678 -m torch.distributed.run -- --nproc_per_node=4 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 train.py --tp_size 2 --pp_size 2
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=4 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 --max_restarts=0 --tee=3 train.py
#VERBOSE=0 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --pp_size 2 --dp_size 2
"""
import os
import time
import argparse
import numpy as np
import torch.nn.functional as F
import torch, torch.distributed as dist
from torch.optim import AdamW
from transformers import AutoConfig
from transformers import AutoTokenizer
from torch.utils.data import DataLoader, DistributedSampler
from datasets import load_dataset, Features, Sequence, Value
from functools import partial

from src.parallel.tensor_parallel.tensor_parallel import TensorParallel
import src.distributed.process_group_manager as pgm
from utils import set_all_seed, print, to_readable_format, save_checkpoint, load_checkpoint
from src.distributed.process_group_manager import setup_process_group_manager
from src.parallel.pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
from src.parallel.data_parallel.data_parallel_bucket import DataParallel
from src.parallel.context_parallel import ContextParallel
from model import Llama
import wandb
from src.distributed.distributed_primtives import all_reduce_loss_across_dp_ranks

class MicroBatchDataLoader(DataLoader):
    def __init__(self, local_batch_size, micro_batch_size, seq_length, dataset_name, tokenizer_name, num_workers, num_proc, grad_acc=1, split="train", num_samples=None):
        self.global_batch_size = local_batch_size * pgm.process_group_manager.dp_world_size
        self.micro_batch_size = micro_batch_size
        self.seq_length = seq_length
        self.local_batch_size = local_batch_size
        self.num_local_micro_batches = self.local_batch_size // self.micro_batch_size
        self.num_global_micro_batches = self.global_batch_size // self.micro_batch_size
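        # Bookkeeping: with this script's defaults (LOCAL_BATCH_SIZE=64, MICRO_BATCH_SIZE=32),
        # each data-parallel rank processes 64 // 32 = 2 micro-batches per step, and with
        # dp_world_size=2 the global batch covers 128 samples, i.e. 4 micro-batches across ranks.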
        self.grad_acc = grad_acc

        self.seq_length_per_gpu = seq_length // pgm.process_group_manager.cp_world_size

        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.dataset = load_dataset(dataset_name, split=split)
        if num_samples:
            self.dataset = self.dataset.select(range(min(num_samples, len(self.dataset))))

        dist.barrier()

        # Tokenize and chunk the dataset
        self.tokenized_dataset = self.tokenize_dataset(self.dataset, "text", self.seq_length, num_proc)

        self.sampler = DistributedSampler(
            self.tokenized_dataset,
            num_replicas=pgm.process_group_manager.dp_world_size,
            rank=pgm.process_group_manager.dp_rank,
            shuffle=False
        )

        super().__init__(
            self.tokenized_dataset,
            batch_size=micro_batch_size if pgm.process_group_manager.pp_world_size > 1 else self.local_batch_size,  # in PP we split a single batch into multiple micro-batches
            collate_fn=self.collate_batch,
            pin_memory=True,
            num_workers=num_workers,
            sampler=self.sampler,
            shuffle=False
        )

    @staticmethod
    def tokenizer_group_text(examples, tokenizer, sequence_length):
        """Tokenize a list of texts and group them in chunks of sequence_length + 1"""
        tokenized_text_batch = tokenizer.batch_encode_plus(
            examples,
            return_attention_mask=False,
            return_token_type_ids=False,
            return_tensors='np'
        )
        concatenated_tokens = {'input_ids': np.concatenate(tokenized_text_batch['input_ids'])}
        total_length = len(concatenated_tokens['input_ids'])
        if total_length >= sequence_length + 1:
            total_length = ((total_length - 1) // sequence_length) * sequence_length + 1
        result = {
            'input_ids': [
                concatenated_tokens['input_ids'][i : i + sequence_length + 1]
                for i in range(0, total_length - sequence_length, sequence_length)
            ]
        }
        return result

    def tokenize_dataset(self, dataset, text_column_name, sequence_length, num_proc):
        """Tokenize the dataset and group texts in chunks of sequence_length + 1"""
        # Create a partial function with fixed arguments
        tokenizer_func = partial(
            self.tokenizer_group_text,
            tokenizer=self.tokenizer,
            sequence_length=sequence_length
        )

        tokenized_dataset = dataset.map(
            tokenizer_func,
            input_columns=text_column_name,
            remove_columns=dataset.column_names,
            features=Features({
                "input_ids": Sequence(feature=Value(dtype="int64"), length=sequence_length + 1)
            }),
            batched=True,
            num_proc=num_proc,
            load_from_cache_file=True,
            desc=f"Grouping texts in chunks of {sequence_length + 1}",
        )

        return tokenized_dataset

    def collate_batch(self, batch):
        batch_input_ids = torch.stack([torch.tensor(item['input_ids']) for item in batch])
        batch_size = batch_input_ids.size(0)
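        # Context parallelism: each CP rank keeps only its contiguous slice of the sequence,
        # e.g. with seq_length=1024 and cp_world_size=2, rank 0 sees positions 0-511 and rank 1
        # sees 512-1023; targets are the same slice shifted by one token.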
        start_idx = pgm.process_group_manager.cp_rank * self.seq_length_per_gpu
        end_idx = start_idx + self.seq_length_per_gpu
        input_ids = batch_input_ids[:, start_idx:end_idx].contiguous()
        target_ids = batch_input_ids[:, start_idx + 1:end_idx + 1].contiguous()
        position_ids = torch.arange(start_idx, end_idx, dtype=torch.long).unsqueeze(0).expand(batch_size, -1).contiguous()
        local_attn_mask = torch.tril(torch.ones((self.seq_length_per_gpu, self.seq_length_per_gpu), dtype=torch.bool))
        attn_mask = local_attn_mask.unsqueeze(0).expand(batch_size, -1, -1).contiguous()

        return {
            "input_ids": input_ids,
            "target_ids": target_ids,
            "position_ids": position_ids,
            "attn_mask": attn_mask,
            "hidden_states": None
        }

    def __iter__(self):
        if self._iterator is None:
            self._iterator = super().__iter__()
        return self

    def __next__(self):
        # Keep a persistent iterator so successive train steps keep consuming the same pass
        # over the data; it is only rebuilt once the underlying DataLoader is exhausted.
        if self._iterator is None:
            self._iterator = super().__iter__()
        try:
            batch = next(self._iterator)
        except StopIteration:
            self._iterator = None
            raise StopIteration
        return batch

def train_step(model, data_loader, device):
    acc_loss = 0.0

    ddp = pgm.process_group_manager.dp_world_size > 1

    for i in range(data_loader.grad_acc):
        # get the next batch
        batch = next(data_loader)
        input_ids = batch["input_ids"].to(device)
        target_ids = batch["target_ids"].to(device)

        # disable gradient synchronization for all but the last micro-batch
        if ddp:
            model.require_backward_grad_sync = (i == data_loader.grad_acc - 1)

        outputs = model(input_ids=input_ids)

        # compute the loss
        batch_size, seq_len = input_ids.shape
        target_ids = target_ids.reshape(-1)
        outputs = outputs.view(seq_len * batch_size, -1)
        loss = F.cross_entropy(outputs, target_ids, reduction='mean') / data_loader.grad_acc
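        # Scaling each micro-batch loss by 1/grad_acc makes the gradients accumulated over
        # grad_acc micro-batches match the mean-loss gradient of one large batch; with grad_acc=2
        # each backward pass contributes half of the step's gradient.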
        loss.backward()
        acc_loss += loss.item()

    return acc_loss

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--tp_size", type=int, default=1)
    parser.add_argument("--cp_size", type=int, default=1)
    parser.add_argument("--pp_size", type=int, default=1)
    parser.add_argument("--dp_size", type=int, default=1)
    parser.add_argument("--use_wandb", action="store_true", default=False)
    parser.add_argument("--use_cpu", action="store_true", default=False)
    parser.add_argument("--master_addr", type=str, default="localhost")
    parser.add_argument("--master_port", type=int, default=29500)
    parser.add_argument("--load_path", type=str, default="", help="Path to load the model from")
    parser.add_argument("--ckpt_dir", type=str, default="ckpt", help="Directory to save checkpoints")
    parser.add_argument("--ckpt_freq", type=int, default=300, help="Frequency (in steps) at which to save checkpoints")
    args = parser.parse_args()

    os.environ["OMP_NUM_THREADS"] = "1"
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float32
    os.environ["DTYPE"] = "bfloat16" if dtype == torch.bfloat16 else "float32"
    os.environ["FLASH_ATTEN"] = "1"  # Use CUDA kernels from the flash attention repo to accelerate training. Model dtype should be torch.bfloat16!
    assert (dtype == torch.bfloat16 and os.getenv("FLASH_ATTEN") == "1") or os.getenv("FLASH_ATTEN") != "1", "Kernel operations require dtype=torch.bfloat16"
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    host = os.environ["MASTER_ADDR"]
    port = int(os.environ["MASTER_PORT"])

    ## hyperparameters
    SEQ_LEN, LOCAL_BATCH_SIZE, MICRO_BATCH_SIZE, LEARNING_RATE, NUM_SAMPLES, MAX_TOKENS, SEED = 1024, 64, 32, 3e-4, 400000, None, 42
    total_train_steps = 200
    grad_acc = 2

    assert SEQ_LEN % args.cp_size == 0, "SEQ_LEN must be divisible by cp_size for Context Parallelism"

    backend = "gloo" if args.use_cpu else "nccl"
    if backend == "nccl":
        torch.cuda.set_device(local_rank)
        device = torch.device("cuda", local_rank)
    else:
        device = torch.device("cpu")

    dist.init_process_group(rank=local_rank, world_size=world_size, backend=backend, init_method=f"tcp://{host}:{port}")

    setup_process_group_manager(tp_size=args.tp_size, cp_size=args.cp_size, pp_size=args.pp_size, dp_size=args.dp_size)

    # if pgm.process_group_manager.global_rank == 0:
    #     display_4D_parallelism_grid()

    tokens_per_step = LOCAL_BATCH_SIZE * SEQ_LEN * grad_acc * args.dp_size
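    # e.g. with the defaults above (LOCAL_BATCH_SIZE=64, SEQ_LEN=1024, grad_acc=2) and --dp_size 2,
    # tokens_per_step = 64 * 1024 * 2 * 2 = 262,144 tokens per optimizer step.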
    if pgm.process_group_manager.global_rank == 0:
        print("Tokens per step:", to_readable_format(tokens_per_step))

    set_all_seed(SEED)

    dataset_name = "roneneldan/TinyStories"
    model_name = "HuggingFaceTB/SmolLM-360M-Instruct"
    # model_name = "meta-llama/Llama-2-7b-hf"

    config = AutoConfig.from_pretrained(model_name)
    config.num_hidden_layers = 16
    config.num_attention_heads = 16
    config.num_key_value_heads = 4

    start_time = time.time()
    model = Llama(config=config)
    print("init model time:", time.time() - start_time)

    wandb_rank = pgm.process_group_manager.tp_rank == 0 and pgm.process_group_manager.dp_rank == 0 and pgm.process_group_manager.pp_is_last_stage
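    # Only one rank logs: tp_rank 0 / dp_rank 0 on the last pipeline stage, i.e. the stage that
    # actually holds the loss.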
    if wandb_rank and args.use_wandb:
        wandb.init(
            project="picotron",
            name=f"test_convergence_GBS_{tokens_per_step}_{pgm.process_group_manager}",
            config={
                "tensor_parallel_size": pgm.process_group_manager.tp_size,
                "pipeline_parallel_size": pgm.process_group_manager.pp_size,
                "data_parallel_size": pgm.process_group_manager.dp_size,
                "model": model_name,
                "dataset": dataset_name,
                "max_tokens": MAX_TOKENS,
                "learning_rate": LEARNING_RATE,
                "seed": SEED,
                "micro_batch_size": MICRO_BATCH_SIZE,
                "global_batch_size": LOCAL_BATCH_SIZE * args.dp_size,
            },
        )

    start_time = time.time()

    if pgm.process_group_manager.tp_world_size > 1:
        TensorParallel(model)

    # if pgm.process_group_manager.cp_size > 1:
    #     #TODO: do this at the very end, once the convergence issue is fixed
    #     model = ContextParallel(model, config)

    if pgm.process_group_manager.pp_world_size > 1:
        model = PipelineParallel(model, config)

    if pgm.process_group_manager.dp_world_size > 1:
        model = DataParallel(model)

    print("init parallel time:", time.time() - start_time)

    start_time = time.time()
    model.to(dtype).to(device)
    model.train()
    print("model to device time:", time.time() - start_time)

    start_time = time.time()
    data_loader = MicroBatchDataLoader(local_batch_size=LOCAL_BATCH_SIZE, micro_batch_size=MICRO_BATCH_SIZE, seq_length=SEQ_LEN, dataset_name=dataset_name, tokenizer_name=model_name, grad_acc=grad_acc, num_workers=4, num_proc=4, num_samples=NUM_SAMPLES)
    print("init dataloader time:", time.time() - start_time)

    tensor_shapes = (data_loader.micro_batch_size, data_loader.seq_length_per_gpu, config.hidden_size)
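    # Presumably the activation shape exchanged between pipeline stages: one micro-batch of
    # hidden states of size (micro_batch_size, seq_length_per_gpu, hidden_size).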
    optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)

    trained_tokens, step = 0, 0
    if args.load_path:
        step, trained_tokens = load_checkpoint(model, optimizer, args.load_path)

    checkpoint_dir = args.ckpt_dir
    checkpoint_freq = args.ckpt_freq

    dist.barrier()

    #TODO: Double-check consumed tokens after each step (for example, MICRO_BATCH_SIZE=2 with only dp_size=4 gives num_local_micro_batches=0 => division by 0)
    #TODO: Check convergence
    #TODO: Try multi-node training
    #TODO: Add activation checkpointing
    #TODO: Add gradient accumulation
    while MAX_TOKENS is None or trained_tokens < MAX_TOKENS:
        #TODO: Add epoch support
        # data_loader.set_epoch(step)
        step_start_time = time.time()
        optimizer.zero_grad()

        if pgm.process_group_manager.pp_world_size > 1:
            loss = train_step_pipeline_afab(model, data_loader, tensor_shapes, device)
        else:
            loss = train_step(model, data_loader, device)

        loss = all_reduce_loss_across_dp_ranks(loss, device)

        optimizer.step()
        trained_tokens += tokens_per_step
        step += 1

        # In the DDP implementation we need to reset the gradient buffers after each step
        if hasattr(model, 'reset'):
            model.reset()

        step_duration = time.time() - step_start_time

        if wandb_rank:
            print(f"[rank {pgm.process_group_manager.global_rank}] Step: {step}, Loss: {loss:.4f}, "
                  f"Global batch size: {to_readable_format(tokens_per_step)}, "
                  f"Tokens/s: {to_readable_format(tokens_per_step / step_duration)}, "
                  f"Tokens/s/GPU: {to_readable_format(tokens_per_step / step_duration / world_size)}, "
                  f"Tokens: {to_readable_format(trained_tokens)}{(' / ' + to_readable_format(MAX_TOKENS)) if MAX_TOKENS else ''}, "
                  f"Memory usage: {torch.cuda.memory_reserved() / 1e9:.2f}GB"
            )

            if args.use_wandb:
                wandb.log({"loss": loss, "tokens_per_step": tokens_per_step, "tokens_per_second": tokens_per_step / step_duration,
                           "memory_usage": torch.cuda.memory_reserved() / 1e9, "trained_tokens": trained_tokens})

        if step % checkpoint_freq == 0:
            save_checkpoint(model, optimizer, step, trained_tokens, checkpoint_dir + f"/{step}")

        if step >= total_train_steps:
            break

    if wandb_rank and args.use_wandb:
        wandb.finish()

    dist.destroy_process_group()