diff --git a/CHANGELOG.md b/CHANGELOG.md index 031b573461..edd196cc32 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -382,6 +382,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Move `Strategy.process_dataloader` function call from `fit/evaluation/predict_loop.py` to `data_connector.py` ([#12251](https://github.com/PyTorchLightning/pytorch-lightning/pull/12251)) +- `ModelCheckpoint(save_last=True, every_n_epochs=N)` now saves a "last" checkpoint every epoch (disregarding `every_n_epochs`) instead of only once at the end of training ([#12418](https://github.com/PyTorchLightning/pytorch-lightning/pull/12418)) + + - The strategies that support `sync_batchnorm` now only apply it when fitting ([#11919](https://github.com/PyTorchLightning/pytorch-lightning/pull/11919)) @@ -861,6 +864,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed the case where `logger=None` is passed to the Trainer ([#12249](https://github.com/PyTorchLightning/pytorch-lightning/pull/12249)) +- Fixed bug where the global step tracked by `ModelCheckpoint` was still set even if no checkpoint was saved ([#12418](https://github.com/PyTorchLightning/pytorch-lightning/pull/12418)) +- + +- Fixed bug where `ModelCheckpoint` was overriding the `epoch` and `step` logged values ([#12418](https://github.com/PyTorchLightning/pytorch-lightning/pull/12418)) + + +- Fixed bug where monitoring the default `epoch` and `step` values with `ModelCheckpoint` would fail ([#12418](https://github.com/PyTorchLightning/pytorch-lightning/pull/12418)) + + - Fixed initializing optimizers unnecessarily in `DDPFullyShardedStrategy` ([#12267](https://github.com/PyTorchLightning/pytorch-lightning/pull/12267)) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 57ed9098bd..ac05e6f66e 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -100,8 +100,8 @@ class ModelCheckpoint(Callback): based on either the maximization or the minimization of the monitored quantity. For ``'val_acc'``, this should be ``'max'``, for ``'val_loss'`` this should be ``'min'``, etc. auto_insert_metric_name: When ``True``, the checkpoints filenames will contain the metric name. - For example, ``filename='checkpoint_{epoch:02d}-{acc:02d}`` with epoch 1 and acc 80 will resolve to - ``checkpoint_epoch=01-acc=80.ckp``. Is useful to set it to ``False`` when metric names contain ``/`` + For example, ``filename='checkpoint_{epoch:02d}-{acc:02.0f}`` with epoch ``1`` and acc ``1.12`` will resolve + to ``checkpoint_epoch=01-acc=01.ckpt``. Is useful to set it to ``False`` when metric names contain ``/`` as this will result in extra folders. save_weights_only: if ``True``, then only the model's weights will be saved. Otherwise, the optimizer states, lr-scheduler states, etc are added in the checkpoint too. @@ -116,7 +116,8 @@ class ModelCheckpoint(Callback): This must be mutually exclusive with ``every_n_train_steps`` and ``every_n_epochs``. every_n_epochs: Number of epochs between checkpoints. This value must be ``None`` or non-negative. - To disable saving after each epoch, set ``every_n_epochs = 0``. + To disable saving top-k checkpoints, set ``every_n_epochs = 0``. + This argument does not impact the saving of ``save_last=True`` checkpoints. If all of ``every_n_epochs``, ``every_n_train_steps`` and ``train_time_interval`` are ``None``, we save a checkpoint at the end of every epoch (equivalent to ``every_n_epochs = 1``). @@ -295,28 +296,25 @@ class ModelCheckpoint(Callback): if not skip_time: self._last_time_checked = now - self.save_checkpoint(trainer) + monitor_candidates = self._monitor_candidates(trainer) + self._save_topk_checkpoint(trainer, monitor_candidates) + self._save_last_checkpoint(trainer, monitor_candidates) def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: """Save a checkpoint at the end of the training epoch.""" - if ( - not self._should_skip_saving_checkpoint(trainer) - and self._save_on_train_epoch_end - and self._every_n_epochs > 0 - and (trainer.current_epoch + 1) % self._every_n_epochs == 0 - ): - self.save_checkpoint(trainer) + if not self._should_skip_saving_checkpoint(trainer) and self._save_on_train_epoch_end: + monitor_candidates = self._monitor_candidates(trainer) + if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0: + self._save_topk_checkpoint(trainer, monitor_candidates) + self._save_last_checkpoint(trainer, monitor_candidates) def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: """Save a checkpoint at the end of the validation stage.""" - if ( - self._should_skip_saving_checkpoint(trainer) - or self._save_on_train_epoch_end - or self._every_n_epochs < 1 - or (trainer.current_epoch + 1) % self._every_n_epochs != 0 - ): - return - self.save_checkpoint(trainer) + if not self._should_skip_saving_checkpoint(trainer) and not self._save_on_train_epoch_end: + monitor_candidates = self._monitor_candidates(trainer) + if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0: + self._save_topk_checkpoint(trainer, monitor_candidates) + self._save_last_checkpoint(trainer, monitor_candidates) def on_save_checkpoint( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any] @@ -352,26 +350,41 @@ class ModelCheckpoint(Callback): self.last_model_path = callback_state.get("last_model_path", self.last_model_path) self.best_model_path = callback_state["best_model_path"] - def save_checkpoint(self, trainer: "pl.Trainer") -> None: + def save_checkpoint(self, trainer: "pl.Trainer") -> None: # pragma: no-cover """Performs the main logic around saving a checkpoint. This method runs on all ranks. It is the responsibility of `trainer.save_checkpoint` to correctly handle the behaviour in distributed training, i.e., saving only on rank 0 for data parallel use cases. """ - self._validate_monitor_key(trainer) - - # what can be monitored - monitor_candidates = self._monitor_candidates(trainer, epoch=trainer.current_epoch, step=trainer.global_step) - - # callback supports multiple simultaneous modes - # here we call each mode sequentially - # Mode 1: save the top k checkpoints - self._save_top_k_checkpoint(trainer, monitor_candidates) - # Mode 2: save monitor=None checkpoints - self._save_none_monitor_checkpoint(trainer, monitor_candidates) - # Mode 3: save last checkpoints + # TODO: unused method. deprecate it + monitor_candidates = self._monitor_candidates(trainer) + self._save_topk_checkpoint(trainer, monitor_candidates) self._save_last_checkpoint(trainer, monitor_candidates) + def _save_topk_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None: + if self.save_top_k == 0: + return + + # validate metric + if self.monitor is not None: + if self.monitor not in monitor_candidates: + m = ( + f"`ModelCheckpoint(monitor={self.monitor!r})` could not find the monitored key in the returned" + f" metrics: {list(monitor_candidates)}." + f" HINT: Did you call `log({self.monitor!r}, value)` in the `LightningModule`?" + ) + if trainer.fit_loop.epoch_loop.val_loop._has_run: + raise MisconfigurationException(m) + warning_cache.warn(m) + self._save_monitor_checkpoint(trainer, monitor_candidates) + else: + self._save_none_monitor_checkpoint(trainer, monitor_candidates) + + def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None: + trainer.save_checkpoint(filepath, self.save_weights_only) + + self._last_global_step_saved = trainer.global_step + # notify loggers if trainer.is_global_zero: for logger in trainer.loggers: @@ -594,21 +607,6 @@ class ModelCheckpoint(Callback): if self.save_top_k != 0 and self._fs.isdir(dirpath) and len(self._fs.ls(dirpath)) > 0: rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.") - def _validate_monitor_key(self, trainer: "pl.Trainer") -> None: - metrics = trainer.callback_metrics - - # validate metric - if self.monitor is not None and not self._is_valid_monitor_key(metrics): - m = ( - f"ModelCheckpoint(monitor='{self.monitor}') not found in the returned metrics:" - f" {list(metrics.keys())}. " - f"HINT: Did you call self.log('{self.monitor}', value) in the LightningModule?" - ) - if not trainer.fit_loop.epoch_loop.val_loop._has_run: - warning_cache.warn(m) - else: - raise MisconfigurationException(m) - def _get_metric_interpolated_filepath_name( self, monitor_candidates: Dict[str, _METRIC], trainer: "pl.Trainer", del_filepath: Optional[str] = None ) -> str: @@ -621,51 +619,46 @@ class ModelCheckpoint(Callback): return filepath - def _monitor_candidates(self, trainer: "pl.Trainer", epoch: int, step: int) -> Dict[str, _METRIC]: + def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, _METRIC]: monitor_candidates = deepcopy(trainer.callback_metrics) - monitor_candidates.update(epoch=epoch, step=step) + # cast to int if necessary because `self.log("epoch", 123)` will convert it to float. if it's not a tensor + # or does not exist we overwrite it as it's likely an error + epoch = monitor_candidates.get("epoch") + monitor_candidates["epoch"] = ( + epoch.int() if isinstance(epoch, torch.Tensor) else torch.tensor(trainer.current_epoch) + ) + step = monitor_candidates.get("step") + monitor_candidates["step"] = step.int() if isinstance(step, torch.Tensor) else torch.tensor(trainer.global_step) return monitor_candidates def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None: if not self.save_last: return - self._last_global_step_saved = monitor_candidates.get("step", trainer.global_step) filepath = self.format_checkpoint_name(monitor_candidates, self.CHECKPOINT_NAME_LAST) # set the last model path before saving because it will be part of the state. previous, self.last_model_path = self.last_model_path, filepath - trainer.save_checkpoint(filepath, self.save_weights_only) + self._save_checkpoint(trainer, filepath) if previous and previous != filepath: trainer.strategy.remove_checkpoint(previous) - def _save_top_k_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None: - if self.monitor is None or self.save_top_k == 0: - return - self._last_global_step_saved = monitor_candidates.get("step", trainer.global_step) - + def _save_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None: current = monitor_candidates.get(self.monitor) if self.check_monitor_top_k(trainer, current): self._update_best_and_save(current, trainer, monitor_candidates) elif self.verbose: - epoch = monitor_candidates.get("epoch") - step = monitor_candidates.get("step") - rank_zero_info(f"Epoch {epoch:d}, global step {step:d}: {self.monitor} was not in top {self.save_top_k}") + epoch = monitor_candidates["epoch"] + step = monitor_candidates["step"] + rank_zero_info(f"Epoch {epoch:d}, global step {step:d}: {self.monitor!r} was not in top {self.save_top_k}") def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None: - if self.monitor is not None or self.save_top_k == 0: - return - self._last_global_step_saved = monitor_candidates.get("step", trainer.global_step) - filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer) # set the best model path before saving because it will be part of the state. previous, self.best_model_path = self.best_model_path, filepath - trainer.save_checkpoint(filepath, self.save_weights_only) + self._save_checkpoint(trainer, filepath) if self.save_top_k == 1 and previous and previous != filepath: trainer.strategy.remove_checkpoint(previous) - def _is_valid_monitor_key(self, metrics: Dict[str, _METRIC]) -> bool: - return self.monitor in metrics or len(metrics) == 0 - def _update_best_and_save( self, current: torch.Tensor, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC] ) -> None: @@ -697,13 +690,13 @@ class ModelCheckpoint(Callback): self.best_model_score = self.best_k_models[self.best_model_path] if self.verbose: - epoch = monitor_candidates.get("epoch") - step = monitor_candidates.get("step") + epoch = monitor_candidates["epoch"] + step = monitor_candidates["step"] rank_zero_info( - f"Epoch {epoch:d}, global step {step:d}: {self.monitor} reached {current:0.5f}" - f' (best {self.best_model_score:0.5f}), saving model to "{filepath}" as top {k}' + f"Epoch {epoch:d}, global step {step:d}: {self.monitor!r} reached {current:0.5f}" + f" (best {self.best_model_score:0.5f}), saving model to {filepath!r} as top {k}" ) - trainer.save_checkpoint(filepath, self.save_weights_only) + self._save_checkpoint(trainer, filepath) if del_filepath is not None and filepath != del_filepath: trainer.strategy.remove_checkpoint(del_filepath) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 3dadf0b733..544cbb5aff 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -109,7 +109,6 @@ def test_model_checkpoint_score_and_ckpt( def validation_step(self, batch, batch_idx): log_value = self.val_logs[self.current_epoch, batch_idx] self.log("val_log", log_value) - self.log("epoch", self.current_epoch, on_epoch=True) return super().validation_step(batch, batch_idx) def configure_optimizers(self): @@ -1086,7 +1085,7 @@ def test_hparams_type(tmpdir, use_omegaconf): super().__init__() self.save_hyperparameters(hparams) - model_checkpoint = ModelCheckpoint(dirpath=tmpdir, save_top_k=1, monitor="foo") + model_checkpoint = ModelCheckpoint(dirpath=tmpdir, save_top_k=1) trainer = Trainer( max_epochs=1, default_root_dir=tmpdir, @@ -1281,3 +1280,24 @@ def test_last_global_step_saved(): trainer.callback_metrics = {"foo": 123} model_checkpoint.save_checkpoint(trainer) assert model_checkpoint._last_global_step_saved == 0 + + +@pytest.mark.parametrize("every_n_epochs", (0, 5)) +def test_save_last_every_n_epochs_interaction(tmpdir, every_n_epochs): + """Test that `save_last` ignores `every_n_epochs`.""" + mc = ModelCheckpoint(every_n_epochs=every_n_epochs, save_last=True, save_top_k=0, save_on_train_epoch_end=True) + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=2, + callbacks=mc, + limit_train_batches=1, + limit_val_batches=0, + enable_progress_bar=False, + enable_model_summary=False, + logger=False, + ) + model = BoringModel() + with patch.object(trainer, "save_checkpoint") as save_mock: + trainer.fit(model) + assert mc.last_model_path # a "last" ckpt was saved + assert save_mock.call_count == trainer.max_epochs diff --git a/tests/loggers/test_mlflow.py b/tests/loggers/test_mlflow.py index 5ce5ceb75a..77afe361b0 100644 --- a/tests/loggers/test_mlflow.py +++ b/tests/loggers/test_mlflow.py @@ -159,8 +159,7 @@ def test_mlflow_logger_dirs_creation(tmpdir): assert set(os.listdir(tmpdir / exp_id)) == {run_id, "meta.yaml"} class CustomModel(BoringModel): - def training_epoch_end(self, *args, **kwargs): - super().training_epoch_end(*args, **kwargs) + def on_train_epoch_end(self, *args, **kwargs): self.log("epoch", self.current_epoch) model = CustomModel() diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index e5259c4047..e5c91e0f71 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -106,7 +106,7 @@ class CustomClassificationModelDP(ClassificationModel): def test_model_properties_fit_ckpt_path(tmpdir): """Test that properties like `current_epoch` and `global_step` in model and trainer are always the same.""" model = BoringModel() - checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, monitor="val_loss", save_last=True) + checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, save_last=True) trainer_args = dict( default_root_dir=tmpdir, max_epochs=1, diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index fca1aa0c37..314e4f3578 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -403,6 +403,7 @@ def test_model_freeze_unfreeze(): assert param.requires_grad +@pytest.mark.xfail(reason="FIXME(@carmocca): this test wasn't running and is now broken") @pytest.mark.parametrize("url_ckpt", [True, False]) def test_fit_ckpt_path_epoch_restored(monkeypatch, tmpdir, tmpdir_server, url_ckpt): """Verify resuming from checkpoint runs the right number of epochs.""" @@ -429,7 +430,7 @@ def test_fit_ckpt_path_epoch_restored(monkeypatch, tmpdir, tmpdir_server, url_ck max_epochs=2, limit_train_batches=0.65, limit_val_batches=1, - callbacks=[ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on", save_top_k=-1)], + callbacks=[ModelCheckpoint(dirpath=tmpdir, save_top_k=-1)], default_root_dir=tmpdir, val_check_interval=1.0, enable_progress_bar=False, @@ -449,6 +450,7 @@ def test_fit_ckpt_path_epoch_restored(monkeypatch, tmpdir, tmpdir_server, url_ck ip, port = tmpdir_server checkpoints = [f"http://{ip}:{port}/" + ckpt.name for ckpt in checkpoints] + assert checkpoints for ckpt in checkpoints: next_model = TestModel() state = pl_load(ckpt)