From f1e11d8b3874067016693c50ae253ec79eecda09 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 5 Apr 2020 09:56:26 -0400 Subject: [PATCH] model_checkpoint to save all models (#1359) * model_checkpoint to save all models * changelog * raise if Co-authored-by: jamesjjcondon Co-authored-by: J. Borovec --- CHANGELOG.md | 3 ++- pytorch_lightning/callbacks/model_checkpoint.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c329a33e91..f445e99417 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -56,7 +56,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed -- `Trainer.add_argparse_args` classmethod fixed. Now it adds a type for the arguments ([#1147](https://github.com/PyTorchLightning/pytorch-lightning/pull/1147)). +- Fixed `model_checkpoint` when saving all models ([#1359](https://github.com/PyTorchLightning/pytorch-lightning/pull/1359)) +- `Trainer.add_argparse_args` classmethod fixed. Now it adds a type for the arguments ([#1147](https://github.com/PyTorchLightning/pytorch-lightning/pull/1147)) - Fixed bug related to type cheking of `ReduceLROnPlateau` lr schedulers([#1114](https://github.com/PyTorchLightning/pytorch-lightning/issues/1114)) - Fixed a bug to ensure lightning checkpoints to be backward compatible ([#1132](https://github.com/PyTorchLightning/pytorch-lightning/pull/1132)) - Fixed a bug that created an extra dataloader with active `reload_dataloaders_every_epoch` ([#1181](https://github.com/PyTorchLightning/pytorch-lightning/issues/1181) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 90f649394f..5a2fbb1ce6 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -82,7 +82,7 @@ class ModelCheckpoint(Callback): save_top_k: int = 1, save_weights_only: bool = False, mode: str = 'auto', period: int = 1, prefix: str = ''): 
super().__init__() - if save_top_k and os.path.isdir(filepath) and len(os.listdir(filepath)) > 0: + if save_top_k > 0 and os.path.isdir(filepath) and len(os.listdir(filepath)) > 0: warnings.warn( f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0." "All files in this directory will be deleted when a checkpoint is saved!" @@ -219,7 +219,7 @@ class ModelCheckpoint(Callback): def _do_check_save(self, filepath, current, epoch): # remove kth - if len(self.best_k_models) == self.save_top_k: + if len(self.best_k_models) == self.save_top_k and self.save_top_k > 0: delpath = self.kth_best_model self.best_k_models.pop(self.kth_best_model) self._del_model(delpath)