From bcb45d906d5f378a30461d513728cad34fc647ce Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Wed, 4 Mar 2020 23:02:19 -0500
Subject: [PATCH] proper checkpoint implementation (#1043)

* enabled early stopping/checkpoint even without val step
* name formatting
* version
* testing
* add test
* fix test
* Update model_checkpoint.py
* doctests
* pylint
* tests
* debug
* fix MNIST download (#1044)
* fix MNIST download
* simple
* rebased 1041

Co-authored-by: Jirka Borovec
---
 CHANGELOG.md                                  |   1 +
 .../callbacks/model_checkpoint.py             | 172 ++++++++++--------
 pytorch_lightning/core/lightning.py           |  14 +-
 pytorch_lightning/loggers/base.py             |   4 +
 pytorch_lightning/trainer/callback_config.py  |   8 +-
 pytorch_lightning/trainer/evaluation_loop.py  |   6 -
 pytorch_lightning/trainer/trainer.py          |   6 +-
 pytorch_lightning/trainer/training_io.py      |  15 +-
 pytorch_lightning/trainer/training_loop.py    |  21 ++-
 tests/models/base.py                          |   1 +
 tests/models/utils.py                         |   1 +
 tests/test_gpu_models.py                      |  77 +++-----
 tests/trainer/test_callbacks.py               |   3 +
 tests/trainer/test_trainer.py                 |  71 ++++----
 14 files changed, 207 insertions(+), 193 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index f248b390c2..0adee16eff 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -25,6 +25,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Support for user defined callbacks ([#889](https://github.com/PyTorchLightning/pytorch-lightning/pull/889) and [#950](https://github.com/PyTorchLightning/pytorch-lightning/pull/950)) - Added support for multiple loggers to be passed to `Trainer` as an iterable (e.g. list, tuple, etc.) ([#903](https://github.com/PyTorchLightning/pytorch-lightning/pull/903)) - Added support for logging hparams as dict ([#1029](https://github.com/PyTorchLightning/pytorch-lightning/pull/1029)) +- Checkpoint and early stopping now work without val step ([#1041](https://github.com/PyTorchLightning/pytorch-lightning/pull/1041)) ### Changed diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 8f379ce972..2bff21cb78 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -1,18 +1,12 @@ -r""" -Model Checkpoint -============== -Save the model as often as requested. - -""" - import os -import glob +import shutil import logging as log import warnings +import re import numpy as np -from .base import Callback +from pytorch_lightning.callbacks.base import Callback class ModelCheckpoint(Callback): @@ -20,21 +14,23 @@ class ModelCheckpoint(Callback): Save the model after every epoch. Args: - dirpath: path to save the model file. + filepath: path to save the model file. Can contain named formatting options to be auto-filled. Example:: - # save epoch and val_loss in name - ModelCheckpoint(filepath='{epoch:02d}-{val_loss:.2f}.hdf5') + # no path + ModelCheckpoint() + # saves like /my/path/epoch_0.ckpt - # saves file like: /my/path/here/sample-mnist_epoch=02_val_loss=0.32.ckpt - # if model already exits, the file will be: /my/path/here/sample-mnist-v0_epoch=02_val_loss=0.32.ckpt + # save any arbitrary metrics like and val_loss, etc in name + ModelCheckpoint(filepath='/my/path/{epoch}-{val_loss:.2f}-{other_metric:.2f}') + # saves file like: /my/path/epoch=2-val_loss=0.2_other_metric=0.3.ckpt - monitor: quantity to monitor. - verbose: verbosity mode, False or True. - save_top_k: if `save_top_k == k`, + monitor (str): quantity to monitor. + verbose (bool): verbosity mode, False or True. + save_top_k (int): if `save_top_k == k`, the best k models according to the quantity monitored will be saved. if ``save_top_k == 0``, no models are saved. @@ -43,7 +39,7 @@ class ModelCheckpoint(Callback): if ``save_top_k >= 2`` and the callback is called multiple times inside an epoch, the name of the saved file will be appended with a version count starting with `v0`. - mode: one of {auto, min, max}. + mode (str): one of {auto, min, max}. If ``save_top_k != 0``, the decision to overwrite the current save file is made based on either the maximization or the @@ -51,46 +47,43 @@ class ModelCheckpoint(Callback): this should be `max`, for `val_loss` this should be `min`, etc. In `auto` mode, the direction is automatically inferred from the name of the monitored quantity. - save_weights_only: if True, then only the model's weights will be + save_weights_only (bool): if True, then only the model's weights will be saved (`model.save_weights(filepath)`), else the full model is saved (`model.save(filepath)`). - period: Interval (number of epochs) between checkpoints. - prefix: String name for particular model + period (int): Interval (number of epochs) between checkpoints. 
- Example: + Example:: from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint # saves checkpoints to my_path whenever 'val_loss' has a new min - checkpoint_callback = ModelCheckpoint('my_path') + checkpoint_callback = ModelCheckpoint(filepath='my_path') Trainer(checkpoint_callback=checkpoint_callback) - """ - #: checkpoint extension - EXTENSION = '.ckpt' - def __init__( - self, - dirpath: str, - monitor: str = 'val_loss', - verbose: bool = False, - save_top_k: int = 1, - save_weights_only: bool = False, - mode: str = 'auto', - period: int = 1, - prefix: str = '' - ): + # save epoch and val_loss in name + ModelCheckpoint(filepath='/my/path/here/sample-mnist_{epoch:02d}-{val_loss:.2f}') + # saves file like: /my/path/here/sample-mnist_epoch=02_val_loss=0.32.ckpt + """ + + def __init__(self, filepath, monitor: str = 'val_loss', verbose: bool = False, + save_top_k: int = 1, save_weights_only: bool = False, + mode: str = 'auto', period: int = 1, prefix: str = ''): super().__init__() - if save_top_k and os.path.isdir(dirpath) and len(os.listdir(dirpath)) > 0: + if save_top_k and os.path.isdir(filepath) and len(os.listdir(filepath)) > 0: warnings.warn( - f"Checkpoint directory {dirpath} exists and is not empty with save_top_k != 0." + f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0." "All files in this directory will be deleted when a checkpoint is saved!" ) self.monitor = monitor self.verbose = verbose - self.dirpath = dirpath - os.makedirs(dirpath, exist_ok=True) + if os.path.isdir(filepath): + self.dirpath, self.filename = filepath, '{epoch}' + else: + self.dirpath, self.filename = os.path.split(filepath) + + os.makedirs(self.dirpath, exist_ok=True) self.save_top_k = save_top_k self.save_weights_only = save_weights_only self.period = period @@ -102,14 +95,6 @@ class ModelCheckpoint(Callback): self.best = 0 self.save_function = None - # this create unique prefix if the give already exists - existing_checkpoints = sorted(glob.glob(os.path.join(self.dirpath, '*' + self.EXTENSION))) - existing_names = set(os.path.basename(ckpt).split('_epoch=')[0] for ckpt in existing_checkpoints) - version_cnt = 0 - while self.prefix in existing_names: - self.prefix = f'{prefix}-v{version_cnt}' - version_cnt += 1 - mode_dict = { 'min': (np.less, np.Inf, 'min'), 'max': (np.greater, -np.Inf, 'max'), @@ -125,39 +110,65 @@ class ModelCheckpoint(Callback): self.monitor_op, self.kth_value, self.mode = mode_dict[mode] - def _del_model(self, filepath: str) -> None: - # shutil.rmtree(filepath) + def _del_model(self, filepath): os.remove(filepath) - def _save_model(self, filepath: str) -> None: + def _save_model(self, filepath): # make paths - os.makedirs(self.dirpath, exist_ok=True) + os.makedirs(os.path.dirname(filepath), exist_ok=True) # delegate the saving to the model if self.save_function is not None: self.save_function(filepath) else: - raise ValueError("Method `.save_function()` not set") + raise ValueError(".save_function() not set") - def check_monitor_top_k(self, current: float) -> bool: + def check_monitor_top_k(self, current): less_than_k_models = len(self.best_k_models) < self.save_top_k if less_than_k_models: return True return self.monitor_op(current, self.best_k_models[self.kth_best_model]) - def _get_available_filepath(self, current: float, epoch: int) -> str: - current_str = f'{current:.2f}' if current else 'NaN' - fname = f'{self.prefix}_epoch={epoch}_{self.monitor}={current_str}' - filepath = os.path.join(self.dirpath, fname + 
self.EXTENSION) - assert not os.path.isfile(filepath) + def format_checkpoint_name(self, epoch, metrics, ver=None): + """Generate a filename according define template. + + Examples + -------- + >>> tmpdir = os.path.dirname(__file__) + >>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch}')) + >>> os.path.basename(ckpt.format_checkpoint_name(0, {})) + 'epoch=0.ckpt' + >>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch:03d}')) + >>> os.path.basename(ckpt.format_checkpoint_name(5, {})) + 'epoch=005.ckpt' + >>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch}-{val_loss:.2f}')) + >>> os.path.basename(ckpt.format_checkpoint_name(2, dict(val_loss=0.123456))) + 'epoch=2-val_loss=0.12.ckpt' + >>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{missing:d}')) + >>> os.path.basename(ckpt.format_checkpoint_name(0, {})) + 'missing=0.ckpt' + """ + # check if user passed in keys to the string + groups = re.findall(r'(\{.*?)[:\}]', self.filename) + + if len(groups) == 0: + # default name + filename = f'{self.prefix}_ckpt_epoch_{epoch}' + else: + metrics['epoch'] = epoch + filename = self.filename + for tmp in groups: + name = tmp[1:] + filename = filename.replace(tmp, name + '={' + name) + if name not in metrics: + metrics[name] = 0 + filename = filename.format(**metrics) + str_ver = f'_v{ver}' if ver is not None else '' + filepath = os.path.join(self.dirpath, self.prefix + filename + str_ver + '.ckpt') return filepath - def on_validation_end(self, trainer, pl_module) -> None: - # only run on main process - if trainer.proc_rank != 0: - return - - logs = trainer.callback_metrics + def on_validation_end(self, trainer, pl_module): + metrics = trainer.callback_metrics epoch = trainer.current_epoch self.epochs_since_last_check += 1 @@ -166,27 +177,36 @@ class ModelCheckpoint(Callback): return if self.epochs_since_last_check >= self.period: self.epochs_since_last_check = 0 - current = logs.get(self.monitor) - filepath = self._get_available_filepath(current, epoch) + + filepath = self.format_checkpoint_name(epoch, metrics) + version_cnt = 0 + while os.path.isfile(filepath): + filepath = self.format_checkpoint_name(epoch, metrics, ver=version_cnt) + # this epoch called before + version_cnt += 1 if self.save_top_k != -1: + current = metrics.get(self.monitor) if current is None: - warnings.warn(f'Can save best model only with {self.monitor} available,' - ' skipping.', RuntimeWarning) + warnings.warn( + f'Can save best model only with {self.monitor} available,' + ' skipping.', RuntimeWarning) else: if self.check_monitor_top_k(current): self._do_check_save(filepath, current, epoch) else: if self.verbose > 0: - log.info('Epoch %05d: %s was not in top %i', epoch, self.monitor, self.save_top_k) + log.info( + f'\nEpoch {epoch:05d}: {self.monitor}' + f' was not in top {self.save_top_k}') else: if self.verbose > 0: - log.info('Epoch %05d: saving model to %s', epoch, filepath) + log.info(f'\nEpoch {epoch:05d}: saving model to {filepath}') self._save_model(filepath) - def _do_check_save(self, filepath: str, current: float, epoch: int) -> None: + def _do_check_save(self, filepath, current, epoch): # remove kth if len(self.best_k_models) == self.save_top_k: delpath = self.kth_best_model @@ -205,6 +225,8 @@ class ModelCheckpoint(Callback): self.best = _op(self.best_k_models.values()) if self.verbose > 0: - log.info('Epoch {epoch:05d}: %s reached %0.5f (best %0.5f), saving model to %s as top %i', - epoch, self.monitor, current, self.best, filepath, self.save_top_k) + log.info( + f'\nEpoch {epoch:05d}: {self.monitor} 
reached' + f' {current:0.5f} (best {self.best:0.5f}), saving model to' + f' {filepath} as top {self.save_top_k}') self._save_model(filepath) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 349360d7f0..395f2f531e 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -68,19 +68,7 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): #: True if using amp self.use_amp = False - @property - def hparams(self) -> Namespace: - if not hasattr(self, '_hparams'): - return Namespace() - assert isinstance(self._hparams, dict) - return Namespace(**self._hparams) - - @hparams.setter - def hparams(self, params: Union[Dict[str, Any], Namespace]) -> None: - """Set the model hyper-parameters.""" - if isinstance(params, Namespace): - params = vars(params) - self._hparams = params + self.hparams = None def print(self, *args, **kwargs): r""" diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py index d6559e22dd..5295bee02f 100644 --- a/pytorch_lightning/loggers/base.py +++ b/pytorch_lightning/loggers/base.py @@ -46,6 +46,10 @@ class LightningLoggerBase(ABC): # in case converting from namespace if isinstance(params, Namespace): params = vars(params) + + if params is None: + params = {} + return params @abstractmethod diff --git a/pytorch_lightning/trainer/callback_config.py b/pytorch_lightning/trainer/callback_config.py index 4560e2bca4..6006702838 100644 --- a/pytorch_lightning/trainer/callback_config.py +++ b/pytorch_lightning/trainer/callback_config.py @@ -48,9 +48,15 @@ class TrainerCallbackConfigMixin(ABC): else: ckpt_path = os.path.join(self.default_save_path, "checkpoints") + # when no val step is defined, use 'loss' otherwise 'val_loss' + train_step_only = not self.is_overriden('validation_step') + monitor_key = 'loss' if train_step_only else 'val_loss' + self.ckpt_path = ckpt_path + os.makedirs(ckpt_path, exist_ok=True) self.checkpoint_callback = ModelCheckpoint( - dirpath=ckpt_path + filepath=ckpt_path, + monitor=monitor_key ) elif self.checkpoint_callback is False: self.checkpoint_callback = None diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 4d9e1df3d4..fbd1e05496 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -165,7 +165,6 @@ class TrainerEvaluationLoopMixin(ABC): process_output: ... training_tqdm_dict: ... proc_rank: int - checkpoint_callback: ... current_epoch: int callback_metrics: ... 
test_dataloaders: DataLoader @@ -377,11 +376,6 @@ class TrainerEvaluationLoopMixin(ABC): # Validation/Test end callbacks if test_mode: self.on_test_end() - else: - # model checkpointing - if self.checkpoint_callback is not None: - self.checkpoint_callback.on_validation_end(self, self.get_model()) - self.on_validation_end() def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test_mode: bool = False): # make dataloader_idx arg in validation_step optional diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 04e9010c7a..0558819f49 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1132,9 +1132,6 @@ class Trainer(TrainerIOMixin, # wait for all processes to catch up torch_xla.core.xla_model.rendezvous("pl.Trainer.run_pretrain_routine") - # set up checkpoint callback - self.configure_checkpoint_callback() - # register auto-resubmit when on SLURM self.register_slurm_signal_handlers() @@ -1151,6 +1148,9 @@ class Trainer(TrainerIOMixin, # if cluster resets state, the model will update with the saved weights self.model = model + # set up checkpoint callback + self.configure_checkpoint_callback() + # restore training and model before hpc call self.restore_weights(model) diff --git a/pytorch_lightning/trainer/training_io.py b/pytorch_lightning/trainer/training_io.py index ab63dadb19..3f0cc4ed92 100644 --- a/pytorch_lightning/trainer/training_io.py +++ b/pytorch_lightning/trainer/training_io.py @@ -165,14 +165,15 @@ class TrainerIOMixin(ABC): def save_checkpoint(self, filepath): checkpoint = self.dump_checkpoint() - # do the actual save - try: - self._atomic_save(checkpoint, filepath) - except AttributeError: - if 'hparams' in checkpoint: - del checkpoint['hparams'] + if self.proc_rank == 0: + # do the actual save + try: + self._atomic_save(checkpoint, filepath) + except AttributeError: + if 'hparams' in checkpoint: + del checkpoint['hparams'] - self._atomic_save(checkpoint, filepath) + self._atomic_save(checkpoint, filepath) def restore(self, checkpoint_path, on_gpu): """ diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 94c9de74d4..82b4ae14ca 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -203,6 +203,7 @@ class TrainerTrainLoopMixin(ABC): max_steps: int max_steps: int total_batch_idx: int + checkpoint_callback: ... 
# Callback system callbacks: List[Callback] @@ -212,6 +213,7 @@ class TrainerTrainLoopMixin(ABC): on_batch_end: Callable on_epoch_start: Callable on_epoch_end: Callable + on_validation_end: Callable @property def max_nb_epochs(self): @@ -454,9 +456,6 @@ class TrainerTrainLoopMixin(ABC): if self.fast_dev_run or should_check_val: self.run_evaluation(test_mode=self.testing) - if self.enable_early_stop: - self.early_stop_callback.check_metrics(self.callback_metrics) - # when logs should be saved should_save_log = (batch_idx + 1) % self.log_save_interval == 0 or early_stop_epoch if should_save_log or self.fast_dev_run: @@ -469,6 +468,17 @@ class TrainerTrainLoopMixin(ABC): # logs user requested information to logger self.log_metrics(batch_step_metrics, grad_norm_dic) + # --------------- + # CHECKPOINTING, EARLY STOPPING + # --------------- + # save checkpoint even when no test or val step are defined + train_step_only = not self.is_overriden('validation_step') + if self.fast_dev_run or should_check_val or train_step_only: + self.call_checkpoint_callback() + + if self.enable_early_stop: + self.early_stop_callback.check_metrics(self.callback_metrics) + # progress global step according to grads progress if (self.batch_idx + 1) % self.accumulate_grad_batches == 0: self.global_step += 1 @@ -705,3 +715,8 @@ class TrainerTrainLoopMixin(ABC): output = self.process_output(output, train=True) return output + + def call_checkpoint_callback(self): + if self.checkpoint_callback is not None: + self.checkpoint_callback.on_validation_end(self, self.get_model()) + self.on_validation_end() diff --git a/tests/models/base.py b/tests/models/base.py index dce2ef624e..29fc2177b1 100644 --- a/tests/models/base.py +++ b/tests/models/base.py @@ -46,6 +46,7 @@ class DictHparamsModel(LightningModule): def __init__(self, hparams: Dict): super(DictHparamsModel, self).__init__() + self.hparams = hparams self.l1 = torch.nn.Linear(hparams.get('in_features'), hparams['out_features']) def forward(self, x): diff --git a/tests/models/utils.py b/tests/models/utils.py index 9b5fad79bf..8d17984b94 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -239,5 +239,6 @@ def set_random_master_port(): def init_checkpoint_callback(logger, path_dir=None): exp_path = get_data_path(logger, path_dir=path_dir) ckpt_dir = os.path.join(exp_path, 'checkpoints') + os.mkdir(ckpt_dir) checkpoint = ModelCheckpoint(ckpt_dir) return checkpoint diff --git a/tests/test_gpu_models.py b/tests/test_gpu_models.py index 5fd37870ee..1cdc91fff7 100644 --- a/tests/test_gpu_models.py +++ b/tests/test_gpu_models.py @@ -256,66 +256,57 @@ def mocked_device_count_0(monkeypatch): monkeypatch.setattr(torch.cuda, 'device_count', device_count) -test_num_gpus_data = [ +@pytest.mark.gpus_param_tests +@pytest.mark.parametrize(["gpus", "expected_num_gpus", "distributed_backend"], [ pytest.param(None, 0, None, id="None - expect 0 gpu to use."), pytest.param(0, 0, None, id="Oth gpu, expect 1 gpu to use."), pytest.param(1, 1, None, id="1st gpu, expect 1 gpu to use."), pytest.param(-1, PRETEND_N_OF_GPUS, "ddp", id="-1 - use all gpus"), pytest.param('-1', PRETEND_N_OF_GPUS, "ddp", id="'-1' - use all gpus"), pytest.param(3, 3, "ddp", id="3rd gpu - 1 gpu to use (backend:ddp)") -] - - -@pytest.mark.gpus_param_tests -@pytest.mark.parametrize(["gpus", "expected_num_gpus", "distributed_backend"], test_num_gpus_data) +]) def test_trainer_gpu_parse(mocked_device_count, gpus, expected_num_gpus, distributed_backend): assert Trainer(gpus=gpus, 
distributed_backend=distributed_backend).num_gpus == expected_num_gpus -test_num_gpus_data_0 = [ +@pytest.mark.gpus_param_tests +@pytest.mark.parametrize(["gpus", "expected_num_gpus", "distributed_backend"], [ pytest.param(None, 0, None, id="None - expect 0 gpu to use."), pytest.param(None, 0, "ddp", id="None - expect 0 gpu to use."), -] - - -@pytest.mark.gpus_param_tests -@pytest.mark.parametrize(["gpus", "expected_num_gpus", "distributed_backend"], test_num_gpus_data_0) +]) def test_trainer_num_gpu_0(mocked_device_count_0, gpus, expected_num_gpus, distributed_backend): assert Trainer(gpus=gpus, distributed_backend=distributed_backend).num_gpus == expected_num_gpus -test_root_gpu_data = [ +@pytest.mark.gpus_param_tests +@pytest.mark.parametrize(['gpus', 'expected_root_gpu', "distributed_backend"], [ pytest.param(None, None, "ddp", id="None is None"), pytest.param(0, None, "ddp", id="O gpus, expect gpu root device to be None."), pytest.param(1, 0, "ddp", id="1 gpu, expect gpu root device to be 0."), pytest.param(-1, 0, "ddp", id="-1 - use all gpus, expect gpu root device to be 0."), pytest.param('-1', 0, "ddp", id="'-1' - use all gpus, expect gpu root device to be 0."), - pytest.param(3, 0, "ddp", id="3 gpus, expect gpu root device to be 0.(backend:ddp)")] - - -@pytest.mark.gpus_param_tests -@pytest.mark.parametrize(['gpus', 'expected_root_gpu', "distributed_backend"], test_root_gpu_data) + pytest.param(3, 0, "ddp", id="3 gpus, expect gpu root device to be 0.(backend:ddp)") +]) def test_root_gpu_property(mocked_device_count, gpus, expected_root_gpu, distributed_backend): assert Trainer(gpus=gpus, distributed_backend=distributed_backend).root_gpu == expected_root_gpu -test_root_gpu_data_for_0_devices_passing = [ +@pytest.mark.gpus_param_tests +@pytest.mark.parametrize([ + 'gpus', 'expected_root_gpu', "distributed_backend"], [ pytest.param(None, None, None, id="None is None"), pytest.param(None, None, "ddp", id="None is None"), pytest.param(0, None, "ddp", id="None is None"), -] - - -@pytest.mark.gpus_param_tests -@pytest.mark.parametrize([ - 'gpus', 'expected_root_gpu', "distributed_backend"], test_root_gpu_data_for_0_devices_passing) +]) def test_root_gpu_property_0_passing( mocked_device_count_0, gpus, expected_root_gpu, distributed_backend): assert Trainer(gpus=gpus, distributed_backend=distributed_backend).root_gpu == expected_root_gpu # Asking for a gpu when non are available will result in a MisconfigurationException -test_root_gpu_data_for_0_devices_raising = [ +@pytest.mark.gpus_param_tests +@pytest.mark.parametrize([ + 'gpus', 'expected_root_gpu', "distributed_backend"], [ pytest.param(1, None, "ddp"), pytest.param(3, None, "ddp"), pytest.param(3, None, "ddp"), @@ -323,34 +314,27 @@ test_root_gpu_data_for_0_devices_raising = [ pytest.param([0, 1], None, "ddp"), pytest.param(-1, None, "ddp"), pytest.param('-1', None, "ddp") -] - - -@pytest.mark.gpus_param_tests -@pytest.mark.parametrize([ - 'gpus', 'expected_root_gpu', "distributed_backend"], test_root_gpu_data_for_0_devices_raising) +]) def test_root_gpu_property_0_raising( mocked_device_count_0, gpus, expected_root_gpu, distributed_backend): with pytest.raises(MisconfigurationException): Trainer(gpus=gpus, distributed_backend=distributed_backend).root_gpu -test_determine_root_gpu_device_data = [ +@pytest.mark.gpus_param_tests +@pytest.mark.parametrize(['gpus', 'expected_root_gpu'], [ pytest.param(None, None, id="No gpus, expect gpu root device to be None"), pytest.param([0], 0, id="Oth gpu, expect gpu root device to be 0."), 
pytest.param([1], 1, id="1st gpu, expect gpu root device to be 1."), pytest.param([3], 3, id="3rd gpu, expect gpu root device to be 3."), pytest.param([1, 2], 1, id="[1, 2] gpus, expect gpu root device to be 1."), -] - - -@pytest.mark.gpus_param_tests -@pytest.mark.parametrize(['gpus', 'expected_root_gpu'], test_determine_root_gpu_device_data) +]) def test_determine_root_gpu_device(gpus, expected_root_gpu): assert determine_root_gpu_device(gpus) == expected_root_gpu -test_parse_gpu_ids_data = [ +@pytest.mark.gpus_param_tests +@pytest.mark.parametrize(['gpus', 'expected_gpu_ids'], [ pytest.param(None, None), pytest.param(0, None), pytest.param(1, [0]), @@ -362,16 +346,13 @@ test_parse_gpu_ids_data = [ pytest.param('3', [3]), pytest.param('1, 3', [1, 3]), pytest.param('-1', list(range(PRETEND_N_OF_GPUS)), id="'-1' - use all gpus"), -] - - -@pytest.mark.gpus_param_tests -@pytest.mark.parametrize(['gpus', 'expected_gpu_ids'], test_parse_gpu_ids_data) +]) def test_parse_gpu_ids(mocked_device_count, gpus, expected_gpu_ids): assert parse_gpu_ids(gpus) == expected_gpu_ids -test_parse_gpu_invalid_inputs_data = [ +@pytest.mark.gpus_param_tests +@pytest.mark.parametrize(['gpus'], [ pytest.param(0.1), pytest.param(-2), pytest.param(False), @@ -380,11 +361,7 @@ test_parse_gpu_invalid_inputs_data = [ pytest.param([None]), pytest.param(['0']), pytest.param((0, 1)), -] - - -@pytest.mark.gpus_param_tests -@pytest.mark.parametrize(['gpus'], test_parse_gpu_invalid_inputs_data) +]) def test_parse_gpu_fail_on_unsupported_inputs(mocked_device_count, gpus): with pytest.raises(MisconfigurationException): parse_gpu_ids(gpus) diff --git a/tests/trainer/test_callbacks.py b/tests/trainer/test_callbacks.py index 23cea98107..0c69c155a8 100644 --- a/tests/trainer/test_callbacks.py +++ b/tests/trainer/test_callbacks.py @@ -1,5 +1,8 @@ +import os + import tests.models.utils as tutils from pytorch_lightning import Trainer, LightningModule +from pytorch_lightning.callbacks import ModelCheckpoint from tests.models import ( TestModelBase, LightTrainDataloader, diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 1336e2c289..349b625bce 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -27,6 +27,28 @@ from pytorch_lightning.trainer.logging import TrainerLoggingMixin from pytorch_lightning.utilities.debugging import MisconfigurationException +def test_hparams_save_load(tmpdir): + model = DictHparamsModel({'in_features': 28 * 28, 'out_features': 10}) + + # logger file to get meta + trainer_options = dict( + default_save_path=tmpdir, + max_epochs=2, + ) + + # fit model + trainer = Trainer(**trainer_options) + result = trainer.fit(model) + + assert result == 1 + + # try to load the model now + pretrained_model = tutils.load_model_from_checkpoint( + trainer.checkpoint_callback.dirpath, + module_class=DictHparamsModel + ) + + def test_no_val_module(tmpdir): """Tests use case where trainer saves the model, and user loads it from tags independently.""" tutils.reset_seed() @@ -126,7 +148,8 @@ def test_gradient_accumulation_scheduling(tmpdir): assert Trainer(accumulate_grad_batches={1: 2.5, 3: 5}) # test optimizer call freq matches scheduler - def _optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None): + def _optimizer_step(self, epoch, batch_idx, optimizer, + optimizer_idx, second_order_closure=None): # only test the first 12 batches in epoch if batch_idx < 12: if epoch == 0: @@ -255,11 +278,11 @@ def 
test_model_checkpoint_options(tmp_path): assert len(file_lists) == len(losses), "Should save all models when save_top_k=-1" # verify correct naming - for fname in {'_epoch=4_val_loss=2.50.ckpt', - '_epoch=3_val_loss=5.00.ckpt', - '_epoch=2_val_loss=2.80.ckpt', - '_epoch=1_val_loss=9.00.ckpt', - '_epoch=0_val_loss=10.00.ckpt'}: + for fname in {'epoch=4.ckpt', + 'epoch=3.ckpt', + 'epoch=2.ckpt', + 'epoch=1.ckpt', + 'epoch=0.ckpt'}: assert fname in file_lists save_dir = tmp_path / "2" @@ -286,7 +309,7 @@ def test_model_checkpoint_options(tmp_path): # ----------------- # CASE K=1 (2.5, epoch 4) - checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=1, verbose=1, prefix='test_prefix') + checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=1, verbose=1, prefix='test_prefix_') checkpoint_callback.save_function = mock_save_function trainer = Trainer() @@ -299,7 +322,7 @@ def test_model_checkpoint_options(tmp_path): file_lists = set(os.listdir(save_dir)) assert len(file_lists) == 1, "Should save 1 model when save_top_k=1" - assert 'test_prefix_epoch=4_val_loss=2.50.ckpt' in file_lists + assert 'test_prefix_epoch=4.ckpt' in file_lists save_dir = tmp_path / "4" save_dir.mkdir() @@ -322,8 +345,8 @@ def test_model_checkpoint_options(tmp_path): file_lists = set(os.listdir(save_dir)) assert len(file_lists) == 3, 'Should save 2 model when save_top_k=2' - for fname in {'_epoch=4_val_loss=2.50.ckpt', - '_epoch=2_val_loss=2.80.ckpt', + for fname in {'epoch=4.ckpt', + 'epoch=2.ckpt', 'other_file.ckpt'}: assert fname in file_lists @@ -368,9 +391,9 @@ def test_model_checkpoint_options(tmp_path): file_lists = set(os.listdir(save_dir)) assert len(file_lists) == 3, 'Should save 3 models when save_top_k=3' - for fname in {'_epoch=0_val_loss=2.80.ckpt', - '_epoch=0_val_loss=2.50.ckpt', - '_epoch=0_val_loss=5.00.ckpt'}: + for fname in {'epoch=0.ckpt', + 'epoch=0.ckpt', + 'epoch=0.ckpt'}: assert fname in file_lists @@ -620,25 +643,3 @@ def test_default_args(tmpdir): assert isinstance(trainer, Trainer) assert trainer.max_epochs == 5 - - -def test_hparams_save_load(tmpdir): - model = DictHparamsModel({'in_features': 28 * 28, 'out_features': 10}) - - # logger file to get meta - trainer_options = dict( - default_save_path=tmpdir, - max_epochs=2, - ) - - # fit model - trainer = Trainer(**trainer_options) - result = trainer.fit(model) - - assert result == 1 - - # try to load the model now - pretrained_model = tutils.load_model_from_checkpoint( - trainer.checkpoint_callback.dirpath, - module_class=DictHparamsModel - )
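For reference, below is a minimal usage sketch of the behaviour this patch introduces: template-based checkpoint filenames via `format_checkpoint_name`, and checkpointing that still runs when the model defines no `validation_step` (in which case the `Trainer`'s auto-created callback monitors `loss` instead of `val_loss`). The save location, metric values, and the `LitModel` module are hypothetical, and the arguments follow the 0.7-era API shown in the diff rather than any later release.

```python
import os
import tempfile

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# Hypothetical save location; ModelCheckpoint creates the directory if needed.
save_dir = tempfile.mkdtemp()

# Filenames come from a template: '{epoch}' plus any logged metric key.
checkpoint_callback = ModelCheckpoint(
    filepath=os.path.join(save_dir, '{epoch}-{val_loss:.2f}'),
    monitor='val_loss',  # the auto-created default callback monitors 'loss' when no validation_step exists
    save_top_k=1,
    verbose=True,
)

# Mirrors the doctests added in this patch: the template is filled from the metrics dict.
print(checkpoint_callback.format_checkpoint_name(2, dict(val_loss=0.123456)))
# e.g. <save_dir>/epoch=2-val_loss=0.12.ckpt

trainer = Trainer(checkpoint_callback=checkpoint_callback, max_epochs=5)
# trainer.fit(LitModel())  # LitModel stands in for any LightningModule
```

Note that the `'loss'` fallback added in `callback_config.py` only applies to the default callback the `Trainer` builds when `checkpoint_callback=True`; an explicitly passed callback keeps whatever `monitor` key the user supplies.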