from typing import List, Optional, Sequence
from pathlib import Path

import hydra
from omegaconf import OmegaConf, DictConfig
from pytorch_lightning import (
    Callback,
    LightningDataModule,
    LightningModule,
    Trainer,
    seed_everything,
)
from pytorch_lightning.loggers import LightningLoggerBase

from src.utils import utils

log = utils.get_logger(__name__)

def last_modification_time(path):
    """Return the most recent modification time of `path`.

    For a directory, files and directories one level below the path are considered.
    """
    path = Path(path)
    if path.is_file():
        return path.stat().st_mtime
    elif path.is_dir():
        return max(child.stat().st_mtime for child in path.iterdir())
    else:
        return None

def train(config: DictConfig) -> Optional[float]:
    """Contains the training pipeline.

    Instantiates all PyTorch Lightning objects from the Hydra config.

    Args:
        config (DictConfig): Configuration composed by Hydra.

    Returns:
        Optional[float]: Metric score for hyperparameter optimization.
    """

    # Set seed for random number generators in pytorch, numpy and python.random
    if config.get("seed"):
        seed_everything(config.seed, workers=True)

    # We want to add fields to config, so we need to call OmegaConf.set_struct
    OmegaConf.set_struct(config, False)

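    # The task below is instantiated non-recursively (`_recursive_=False`), so it
    # receives the raw config and handles instantiation of its own sub-configs itself;
    # the datamodule it builds is read back from `model._datamodule`.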
    # Init lightning model
    model: LightningModule = hydra.utils.instantiate(config.task, cfg=config, _recursive_=False)
    datamodule: LightningDataModule = model._datamodule

    # Init lightning callbacks
    callbacks: List[Callback] = []
    if "callbacks" in config:
        for _, cb_conf in config.callbacks.items():
            if cb_conf is not None and "_target_" in cb_conf:
                log.info(f"Instantiating callback <{cb_conf._target_}>")
                callbacks.append(hydra.utils.instantiate(cb_conf))

    # Init lightning loggers
    logger: List[LightningLoggerBase] = []
    if "logger" in config:
        for _, lg_conf in config.logger.items():
            if lg_conf is not None and "_target_" in lg_conf:
                log.info(f"Instantiating logger <{lg_conf._target_}>")
                logger.append(hydra.utils.instantiate(lg_conf))

    # Figure out which checkpoint (if any) to resume from
    ckpt_cfg = {}
    if config.get('resume'):
        try:
            checkpoint_path = Path(config.callbacks.model_checkpoint.dirpath)
            if checkpoint_path.is_dir():
                last_ckpt = checkpoint_path / 'last.ckpt'
                autosave_ckpt = checkpoint_path / '.pl_auto_save.ckpt'
                if not (last_ckpt.exists() or autosave_ckpt.exists()):
                    raise FileNotFoundError("Resume requires either last.ckpt or .pl_auto_save.ckpt")
                # Prefer the autosave checkpoint if it is the only one present, or if it is newer than last.ckpt
                if ((not last_ckpt.exists())
                        or (autosave_ckpt.exists()
                            and last_modification_time(autosave_ckpt) > last_modification_time(last_ckpt))):
                    # autosave_ckpt = autosave_ckpt.replace(autosave_ckpt.with_name('.pl_auto_save_loaded.ckpt'))
                    checkpoint_path = autosave_ckpt
                else:
                    checkpoint_path = last_ckpt
            # DeepSpeed's checkpoint is a directory, not a file
            if checkpoint_path.is_file() or checkpoint_path.is_dir():
                ckpt_cfg = {'ckpt_path': str(checkpoint_path)}
            else:
                log.info(f'Checkpoint file {str(checkpoint_path)} not found. Will start training from scratch')
        except (KeyError, FileNotFoundError):
            pass

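    # If a checkpoint was found, `ckpt_cfg` holds {'ckpt_path': ...}; passing it to
    # trainer.fit below lets Lightning restore model, optimizer and trainer state
    # before resuming training.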
    # Configure ddp automatically
    n_devices = config.trainer.get('devices', 1)
    if isinstance(n_devices, Sequence):  # trainer.devices could be [1, 3] for example
        n_devices = len(n_devices)
    if n_devices > 1 and config.trainer.get('strategy', None) is None:
        config.trainer.strategy = dict(
            _target_='pytorch_lightning.strategies.DDPStrategy',
            find_unused_parameters=False,
            gradient_as_bucket_view=True,  # https://pytorch-lightning.readthedocs.io/en/stable/advanced/advanced_gpu.html#ddp-optimizations
        )

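    # The strategy is left as a dict with a `_target_`: hydra.utils.instantiate is
    # recursive by default, so it becomes an actual DDPStrategy object when the
    # trainer is instantiated below.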
    # Init lightning trainer
    log.info(f"Instantiating trainer <{config.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(
        config.trainer, callbacks=callbacks, logger=logger)

    # Train the model
    log.info("Starting training!")
    trainer.fit(model=model, datamodule=datamodule, **ckpt_cfg)

    # Evaluate the model on the test set after training
    if config.get("test_after_training") and not config.trainer.get("fast_dev_run"):
        log.info("Starting testing!")
        trainer.test(model=model, datamodule=datamodule)

    # Make sure everything closed properly
    log.info("Finalizing!")
    utils.finish(
        config=config,
        model=model,
        datamodule=datamodule,
        trainer=trainer,
        callbacks=callbacks,
        logger=logger,
    )

    # Print path to best checkpoint
    if not config.trainer.get("fast_dev_run"):
        log.info(f"Best model ckpt: {trainer.checkpoint_callback.best_model_path}")

    # Return metric score for hyperparameter optimization
    optimized_metric = config.get("optimized_metric")
    if optimized_metric:
        return trainer.callback_metrics[optimized_metric]
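

# Typical entry point (illustrative sketch only; the real entry-point script, config path
# and config name live elsewhere in the project and are assumptions here):
#
#     @hydra.main(config_path="configs/", config_name="config")
#     def main(config: DictConfig) -> Optional[float]:
#         return train(config)
#
#     if __name__ == "__main__":
#         main()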