diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 8388c5c4fc..e869e8847a 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -6,7 +6,7 @@ Save the model as often as requested. """ import os -import shutil +import glob import logging as log import warnings @@ -20,17 +20,19 @@ class ModelCheckpoint(Callback): Save the model after every epoch. Args: - filepath: path to save the model file. + dirpath: path to save the model file. Can contain named formatting options to be auto-filled. Example:: # save epoch and val_loss in name ModelCheckpoint(filepath='{epoch:02d}-{val_loss:.2f}.hdf5') - # saves file like: /path/epoch_2-val_loss_0.2.hdf5 - monitor (str): quantity to monitor. - verbose (bool): verbosity mode, False or True. - save_top_k (int): if `save_top_k == k`, + # saves file like: /my/path/here/sample-mnist_epoch=02_val_loss=0.32.ckpt + # if such model already exits, the file will be: /my/path/here/sample-mnist-v0_epoch=02_val_loss=0.32.ckpt + + monitor: quantity to monitor. + verbose: verbosity mode, False or True. + save_top_k: if `save_top_k == k`, the best k models according to the quantity monitored will be saved. if ``save_top_k == 0``, no models are saved. @@ -39,7 +41,7 @@ class ModelCheckpoint(Callback): if ``save_top_k >= 2`` and the callback is called multiple times inside an epoch, the name of the saved file will be appended with a version count starting with `v0`. - mode (str): one of {auto, min, max}. + mode: one of {auto, min, max}. If ``save_top_k != 0``, the decision to overwrite the current save file is made based on either the maximization or the @@ -47,35 +49,46 @@ class ModelCheckpoint(Callback): this should be `max`, for `val_loss` this should be `min`, etc. In `auto` mode, the direction is automatically inferred from the name of the monitored quantity. - save_weights_only (bool): if True, then only the model's weights will be + save_weights_only: if True, then only the model's weights will be saved (`model.save_weights(filepath)`), else the full model is saved (`model.save(filepath)`). - period (int): Interval (number of epochs) between checkpoints. + period: Interval (number of epochs) between checkpoints. + prefix: String name for particular model - Example:: + Example: from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint # saves checkpoints to my_path whenever 'val_loss' has a new min - checkpoint_callback = ModelCheckpoint(filepath='my_path') + checkpoint_callback = ModelCheckpoint('my_path') Trainer(checkpoint_callback=checkpoint_callback) """ + #: checkpoint extension + EXTENSION = '.ckpt' - def __init__(self, filepath, monitor: str = 'val_loss', verbose: bool = False, - save_top_k: int = 1, save_weights_only: bool = False, - mode: str = 'auto', period: int = 1, prefix: str = ''): + def __init__( + self, + dirpath: str, + monitor: str = 'val_loss', + verbose: bool = False, + save_top_k: int = 1, + save_weights_only: bool = False, + mode: str = 'auto', + period: int = 1, + prefix: str = '' + ): super().__init__() - if save_top_k and os.path.isdir(filepath) and len(os.listdir(filepath)) > 0: + if save_top_k and os.path.isdir(dirpath) and len(os.listdir(dirpath)) > 0: warnings.warn( - f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0." + f"Checkpoint directory {dirpath} exists and is not empty with save_top_k != 0." "All files in this directory will be deleted when a checkpoint is saved!" ) self.monitor = monitor self.verbose = verbose - self.filepath = filepath - os.makedirs(filepath, exist_ok=True) + self.dirpath = dirpath + os.makedirs(dirpath, exist_ok=True) self.save_top_k = save_top_k self.save_weights_only = save_weights_only self.period = period @@ -87,6 +100,14 @@ class ModelCheckpoint(Callback): self.best = 0 self.save_function = None + # this create unique prefix if the give already exists + existing_checkpoints = sorted(glob.glob(os.path.join(self.dirpath, '*' + self.EXTENSION))) + existing_names = set(os.path.basename(ckpt).split('_epoch=')[0] for ckpt in existing_checkpoints) + version_cnt = 0 + while self.prefix in existing_names: + self.prefix = f'{prefix}-v{version_cnt}' + version_cnt += 1 + mode_dict = { 'min': (np.less, np.Inf, 'min'), 'max': (np.greater, -np.Inf, 'max'), @@ -102,15 +123,13 @@ class ModelCheckpoint(Callback): self.monitor_op, self.kth_value, self.mode = mode_dict[mode] - def _del_model(self, filepath): - try: - shutil.rmtree(filepath) - except OSError: - os.remove(filepath) + def _del_model(self, filepath: str) -> None: + # shutil.rmtree(filepath) + os.remove(filepath) - def _save_model(self, filepath): + def _save_model(self, filepath: str) -> None: # make paths - os.makedirs(os.path.dirname(filepath), exist_ok=True) + os.makedirs(self.dirpath, exist_ok=True) # delegate the saving to the model if self.save_function is not None: @@ -118,13 +137,20 @@ class ModelCheckpoint(Callback): else: raise ValueError(".save_function() not set") - def check_monitor_top_k(self, current): + def check_monitor_top_k(self, current: float) -> bool: less_than_k_models = len(self.best_k_models) < self.save_top_k if less_than_k_models: return True return self.monitor_op(current, self.best_k_models[self.kth_best_model]) - def on_validation_end(self, trainer, pl_module): + def _get_available_filepath(self, current: float, epoch: int) -> str: + current_str = f'{current:.2f}' if current else 'NaN' + fname = f'{self.prefix}_epoch={epoch}_{self.monitor}={current_str}' + filepath = os.path.join(self.dirpath, fname + self.EXTENSION) + assert not os.path.isfile(filepath) + return filepath + + def on_validation_end(self, trainer, pl_module) -> None: # only run on main process if trainer.proc_rank != 0: return @@ -138,35 +164,27 @@ class ModelCheckpoint(Callback): return if self.epochs_since_last_check >= self.period: self.epochs_since_last_check = 0 - filepath = f'{self.filepath}/{self.prefix}_ckpt_epoch_{epoch}.ckpt' - version_cnt = 0 - while os.path.isfile(filepath): - # this epoch called before - filepath = f'{self.filepath}/{self.prefix}_ckpt_epoch_{epoch}_v{version_cnt}.ckpt' - version_cnt += 1 + current = logs.get(self.monitor) + filepath = self._get_available_filepath(current, epoch) if self.save_top_k != -1: - current = logs.get(self.monitor) if current is None: - warnings.warn( - f'Can save best model only with {self.monitor} available,' - ' skipping.', RuntimeWarning) + warnings.warn(f'Can save best model only with {self.monitor} available,' + ' skipping.', RuntimeWarning) else: if self.check_monitor_top_k(current): self._do_check_save(filepath, current, epoch) else: if self.verbose > 0: - log.info( - f'\nEpoch {epoch:05d}: {self.monitor}' - f' was not in top {self.save_top_k}') + log.info('Epoch %05d: %s was not in top %i', epoch, self.monitor, self.save_top_k) else: if self.verbose > 0: - log.info(f'\nEpoch {epoch:05d}: saving model to {filepath}') + log.info('Epoch %05d: saving model to %s', epoch, filepath) self._save_model(filepath) - def _do_check_save(self, filepath, current, epoch): + def _do_check_save(self, filepath: str, current: float, epoch: int) -> None: # remove kth if len(self.best_k_models) == self.save_top_k: delpath = self.kth_best_model @@ -185,8 +203,6 @@ class ModelCheckpoint(Callback): self.best = _op(self.best_k_models.values()) if self.verbose > 0: - log.info( - f'\nEpoch {epoch:05d}: {self.monitor} reached' - f' {current:0.5f} (best {self.best:0.5f}), saving model to' - f' {filepath} as top {self.save_top_k}') + log.info('Epoch {epoch:05d}: %s reached %0.5f (best %0.5f), saving model to %s as top %i', + epoch, self.monitor, current, self.best, filepath, self.save_top_k) self._save_model(filepath) diff --git a/pytorch_lightning/trainer/callback_config.py b/pytorch_lightning/trainer/callback_config.py index 8d7c3d41c3..4560e2bca4 100644 --- a/pytorch_lightning/trainer/callback_config.py +++ b/pytorch_lightning/trainer/callback_config.py @@ -50,7 +50,7 @@ class TrainerCallbackConfigMixin(ABC): self.ckpt_path = ckpt_path self.checkpoint_callback = ModelCheckpoint( - filepath=ckpt_path + dirpath=ckpt_path ) elif self.checkpoint_callback is False: self.checkpoint_callback = None @@ -62,7 +62,7 @@ class TrainerCallbackConfigMixin(ABC): self.checkpoint_callback.save_function = self.save_checkpoint # if checkpoint callback used, then override the weights path - self.weights_save_path = self.checkpoint_callback.filepath + self.weights_save_path = self.checkpoint_callback.dirpath # if weights_save_path is still none here, set to current working dir if self.weights_save_path is None: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 9b451e8ef8..33daf8de06 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -810,7 +810,7 @@ class Trainer(TrainerIOMixin, self.amp_level = amp_level self.precision = precision - assert self.precision == 32 or self.precision == 16, 'only 32 or 16 bit precision supported' + assert self.precision in (16, 32), 'only 32 or 16 bit precision supported' if self.precision == 16 and num_tpu_cores is None: use_amp = True diff --git a/tests/models/utils.py b/tests/models/utils.py index db517b3e28..6e4fa52635 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -32,7 +32,7 @@ def run_model_test_no_loggers(trainer_options, model, min_acc=0.50): # test model loading pretrained_model = load_model(trainer.logger, - trainer.checkpoint_callback.filepath, + trainer.checkpoint_callback.dirpath, path_expt=trainer_options.get('default_save_path')) # test new model accuracy @@ -70,7 +70,7 @@ def run_model_test(trainer_options, model, on_gpu=True): assert result == 1, 'amp + ddp model failed to complete' # test model loading - pretrained_model = load_model(logger, trainer.checkpoint_callback.filepath) + pretrained_model = load_model(logger, trainer.checkpoint_callback.dirpath) # test new model accuracy test_loaders = model.test_dataloader() diff --git a/tests/test_restore_models.py b/tests/test_restore_models.py index c0c429c240..cf3a6773ca 100644 --- a/tests/test_restore_models.py +++ b/tests/test_restore_models.py @@ -1,3 +1,4 @@ +import glob import logging as log import os @@ -52,7 +53,7 @@ def test_running_test_pretrained_model_ddp(tmpdir): # correct result and ok accuracy assert result == 1, 'training failed to complete' pretrained_model = tutils.load_model(logger, - trainer.checkpoint_callback.filepath, + trainer.checkpoint_callback.dirpath, module_class=LightningTestModel) # run test set @@ -96,7 +97,7 @@ def test_running_test_pretrained_model(tmpdir): # correct result and ok accuracy assert result == 1, 'training failed to complete' pretrained_model = tutils.load_model( - logger, trainer.checkpoint_callback.filepath, module_class=LightningTestModel + logger, trainer.checkpoint_callback.dirpath, module_class=LightningTestModel ) new_trainer = Trainer(**trainer_options) @@ -132,9 +133,7 @@ def test_load_model_from_checkpoint(tmpdir): assert result == 1, 'training failed to complete' # load last checkpoint - last_checkpoint = os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_1.ckpt") - if not os.path.isfile(last_checkpoint): - last_checkpoint = os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_0.ckpt") + last_checkpoint = sorted(glob.glob(os.path.join(trainer.checkpoint_callback.dirpath, "*.ckpt")))[-1] pretrained_model = LightningTestModel.load_from_checkpoint(last_checkpoint) # test that hparams loaded correctly @@ -186,7 +185,7 @@ def test_running_test_pretrained_model_dp(tmpdir): # correct result and ok accuracy assert result == 1, 'training failed to complete' pretrained_model = tutils.load_model(logger, - trainer.checkpoint_callback.filepath, + trainer.checkpoint_callback.dirpath, module_class=LightningTestModel) new_trainer = Trainer(**trainer_options) @@ -346,7 +345,7 @@ def test_load_model_with_missing_hparams(tmpdir): model = LightningTestModelWithoutHyperparametersArg() trainer.fit(model) - last_checkpoint = os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_0.ckpt") + last_checkpoint = sorted(glob.glob(os.path.join(trainer.checkpoint_callback.dirpath, "*.ckpt")))[-1] # try to load a checkpoint that has hparams but model is missing hparams arg with pytest.raises(MisconfigurationException, match=r".*__init__ is missing the argument 'hparams'.*"): diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index d95fce3e5c..55eb803b71 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1,3 +1,4 @@ +import glob import math import os import pytest @@ -257,8 +258,12 @@ def test_model_checkpoint_options(tmp_path): assert len(file_lists) == len(losses), "Should save all models when save_top_k=-1" # verify correct naming - for i in range(0, len(losses)): - assert f"_ckpt_epoch_{i}.ckpt" in file_lists + for fname in {'_epoch=4_val_loss=2.50.ckpt', + '_epoch=3_val_loss=5.00.ckpt', + '_epoch=2_val_loss=2.80.ckpt', + '_epoch=1_val_loss=9.00.ckpt', + '_epoch=0_val_loss=10.00.ckpt'}: + assert fname in file_lists save_dir = tmp_path / "2" save_dir.mkdir() @@ -297,7 +302,7 @@ def test_model_checkpoint_options(tmp_path): file_lists = set(os.listdir(save_dir)) assert len(file_lists) == 1, "Should save 1 model when save_top_k=1" - assert 'test_prefix_ckpt_epoch_4.ckpt' in file_lists + assert 'test_prefix_epoch=4_val_loss=2.50.ckpt' in file_lists save_dir = tmp_path / "4" save_dir.mkdir() @@ -320,9 +325,10 @@ def test_model_checkpoint_options(tmp_path): file_lists = set(os.listdir(save_dir)) assert len(file_lists) == 3, 'Should save 2 model when save_top_k=2' - assert '_ckpt_epoch_4.ckpt' in file_lists - assert '_ckpt_epoch_2.ckpt' in file_lists - assert 'other_file.ckpt' in file_lists + for fname in {'_epoch=4_val_loss=2.50.ckpt', + '_epoch=2_val_loss=2.80.ckpt', + 'other_file.ckpt'}: + assert fname in file_lists save_dir = tmp_path / "5" save_dir.mkdir() @@ -365,9 +371,10 @@ def test_model_checkpoint_options(tmp_path): file_lists = set(os.listdir(save_dir)) assert len(file_lists) == 3, 'Should save 3 models when save_top_k=3' - assert '_ckpt_epoch_0_v2.ckpt' in file_lists - assert '_ckpt_epoch_0_v1.ckpt' in file_lists - assert '_ckpt_epoch_0.ckpt' in file_lists + for fname in {'_epoch=0_val_loss=2.80.ckpt', + '_epoch=0_val_loss=2.50.ckpt', + '_epoch=0_val_loss=5.00.ckpt'}: + assert fname in file_lists def test_model_freeze_unfreeze(): @@ -388,7 +395,7 @@ def test_resume_from_checkpoint_epoch_restored(tmpdir): hparams = tutils.get_hparams() - def new_model(): + def _new_model(): # Create a model that tracks epochs and batches seen model = LightningTestModel(hparams) model.num_epochs_seen = 0 @@ -406,7 +413,7 @@ def test_resume_from_checkpoint_epoch_restored(tmpdir): model.on_batch_start = types.MethodType(increment_batch, model) return model - model = new_model() + model = _new_model() trainer_options = dict( show_progress_bar=False, @@ -417,7 +424,7 @@ def test_resume_from_checkpoint_epoch_restored(tmpdir): logger=False, default_save_path=tmpdir, early_stop_callback=False, - val_check_interval=0.5, + val_check_interval=1., ) # fit model @@ -430,15 +437,10 @@ def test_resume_from_checkpoint_epoch_restored(tmpdir): assert model.num_batches_seen == training_batches * 2 # Other checkpoints can be uncommented if/when resuming mid-epoch is supported - checkpoints = [ - # os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_0.ckpt"), - os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_0_v0.ckpt"), - # os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_1.ckpt"), - os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_1_v0.ckpt"), - ] + checkpoints = sorted(glob.glob(os.path.join(trainer.checkpoint_callback.dirpath, '*.ckpt'))) for check in checkpoints: - next_model = new_model() + next_model = _new_model() state = torch.load(check) # Resume training