`ModelCheckpoint`'s `save_last` now ignores `every_n_epochs` (#12418)

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
This commit is contained in:
Carlos Mocholí 2022-03-24 20:06:52 +01:00 committed by GitHub
parent dcc973e019
commit 71e0ddb62f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 104 additions and 78 deletions

View File

@ -382,6 +382,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Move `Strategy.process_dataloader` function call from `fit/evaluation/predict_loop.py` to `data_connector.py` ([#12251](https://github.com/PyTorchLightning/pytorch-lightning/pull/12251))
- `ModelCheckpoint(save_last=True, every_n_epochs=N)` now saves a "last" checkpoint every epoch (disregarding `every_n_epochs`) instead of only once at the end of training ([#12418](https://github.com/PyTorchLightning/pytorch-lightning/pull/12418))
- The strategies that support `sync_batchnorm` now only apply it when fitting ([#11919](https://github.com/PyTorchLightning/pytorch-lightning/pull/11919))
@ -861,6 +864,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed the case where `logger=None` is passed to the Trainer ([#12249](https://github.com/PyTorchLightning/pytorch-lightning/pull/12249))
- Fixed bug where the global step tracked by `ModelCheckpoint` was still set even if no checkpoint was saved ([#12418](https://github.com/PyTorchLightning/pytorch-lightning/pull/12418))
-
- Fixed bug where `ModelCheckpoint` was overriding the `epoch` and `step` logged values ([#12418](https://github.com/PyTorchLightning/pytorch-lightning/pull/12418))
- Fixed bug where monitoring the default `epoch` and `step` values with `ModelCheckpoint` would fail ([#12418](https://github.com/PyTorchLightning/pytorch-lightning/pull/12418))
- Fixed initializing optimizers unnecessarily in `DDPFullyShardedStrategy` ([#12267](https://github.com/PyTorchLightning/pytorch-lightning/pull/12267))

View File

@ -100,8 +100,8 @@ class ModelCheckpoint(Callback):
based on either the maximization or the minimization of the monitored quantity.
For ``'val_acc'``, this should be ``'max'``, for ``'val_loss'`` this should be ``'min'``, etc.
auto_insert_metric_name: When ``True``, the checkpoints filenames will contain the metric name.
For example, ``filename='checkpoint_{epoch:02d}-{acc:02d}`` with epoch 1 and acc 80 will resolve to
``checkpoint_epoch=01-acc=80.ckp``. Is useful to set it to ``False`` when metric names contain ``/``
For example, ``filename='checkpoint_{epoch:02d}-{acc:02.0f}`` with epoch ``1`` and acc ``1.12`` will resolve
to ``checkpoint_epoch=01-acc=01.ckpt``. Is useful to set it to ``False`` when metric names contain ``/``
as this will result in extra folders.
save_weights_only: if ``True``, then only the model's weights will be
saved. Otherwise, the optimizer states, lr-scheduler states, etc are added in the checkpoint too.
@ -116,7 +116,8 @@ class ModelCheckpoint(Callback):
This must be mutually exclusive with ``every_n_train_steps`` and ``every_n_epochs``.
every_n_epochs: Number of epochs between checkpoints.
This value must be ``None`` or non-negative.
To disable saving after each epoch, set ``every_n_epochs = 0``.
To disable saving top-k checkpoints, set ``every_n_epochs = 0``.
This argument does not impact the saving of ``save_last=True`` checkpoints.
If all of ``every_n_epochs``, ``every_n_train_steps`` and
``train_time_interval`` are ``None``, we save a checkpoint at the end of every epoch
(equivalent to ``every_n_epochs = 1``).
@ -295,28 +296,25 @@ class ModelCheckpoint(Callback):
if not skip_time:
self._last_time_checked = now
self.save_checkpoint(trainer)
monitor_candidates = self._monitor_candidates(trainer)
self._save_topk_checkpoint(trainer, monitor_candidates)
self._save_last_checkpoint(trainer, monitor_candidates)
def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Save a checkpoint at the end of the training epoch."""
if (
not self._should_skip_saving_checkpoint(trainer)
and self._save_on_train_epoch_end
and self._every_n_epochs > 0
and (trainer.current_epoch + 1) % self._every_n_epochs == 0
):
self.save_checkpoint(trainer)
if not self._should_skip_saving_checkpoint(trainer) and self._save_on_train_epoch_end:
monitor_candidates = self._monitor_candidates(trainer)
if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
self._save_topk_checkpoint(trainer, monitor_candidates)
self._save_last_checkpoint(trainer, monitor_candidates)
def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Save a checkpoint at the end of the validation stage."""
if (
self._should_skip_saving_checkpoint(trainer)
or self._save_on_train_epoch_end
or self._every_n_epochs < 1
or (trainer.current_epoch + 1) % self._every_n_epochs != 0
):
return
self.save_checkpoint(trainer)
if not self._should_skip_saving_checkpoint(trainer) and not self._save_on_train_epoch_end:
monitor_candidates = self._monitor_candidates(trainer)
if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
self._save_topk_checkpoint(trainer, monitor_candidates)
self._save_last_checkpoint(trainer, monitor_candidates)
def on_save_checkpoint(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
@ -352,26 +350,41 @@ class ModelCheckpoint(Callback):
self.last_model_path = callback_state.get("last_model_path", self.last_model_path)
self.best_model_path = callback_state["best_model_path"]
def save_checkpoint(self, trainer: "pl.Trainer") -> None:
def save_checkpoint(self, trainer: "pl.Trainer") -> None: # pragma: no-cover
"""Performs the main logic around saving a checkpoint.
This method runs on all ranks. It is the responsibility of `trainer.save_checkpoint` to correctly handle the
behaviour in distributed training, i.e., saving only on rank 0 for data parallel use cases.
"""
self._validate_monitor_key(trainer)
# what can be monitored
monitor_candidates = self._monitor_candidates(trainer, epoch=trainer.current_epoch, step=trainer.global_step)
# callback supports multiple simultaneous modes
# here we call each mode sequentially
# Mode 1: save the top k checkpoints
self._save_top_k_checkpoint(trainer, monitor_candidates)
# Mode 2: save monitor=None checkpoints
self._save_none_monitor_checkpoint(trainer, monitor_candidates)
# Mode 3: save last checkpoints
# TODO: unused method. deprecate it
monitor_candidates = self._monitor_candidates(trainer)
self._save_topk_checkpoint(trainer, monitor_candidates)
self._save_last_checkpoint(trainer, monitor_candidates)
def _save_topk_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
if self.save_top_k == 0:
return
# validate metric
if self.monitor is not None:
if self.monitor not in monitor_candidates:
m = (
f"`ModelCheckpoint(monitor={self.monitor!r})` could not find the monitored key in the returned"
f" metrics: {list(monitor_candidates)}."
f" HINT: Did you call `log({self.monitor!r}, value)` in the `LightningModule`?"
)
if trainer.fit_loop.epoch_loop.val_loop._has_run:
raise MisconfigurationException(m)
warning_cache.warn(m)
self._save_monitor_checkpoint(trainer, monitor_candidates)
else:
self._save_none_monitor_checkpoint(trainer, monitor_candidates)
def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
trainer.save_checkpoint(filepath, self.save_weights_only)
self._last_global_step_saved = trainer.global_step
# notify loggers
if trainer.is_global_zero:
for logger in trainer.loggers:
@ -594,21 +607,6 @@ class ModelCheckpoint(Callback):
if self.save_top_k != 0 and self._fs.isdir(dirpath) and len(self._fs.ls(dirpath)) > 0:
rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
def _validate_monitor_key(self, trainer: "pl.Trainer") -> None:
metrics = trainer.callback_metrics
# validate metric
if self.monitor is not None and not self._is_valid_monitor_key(metrics):
m = (
f"ModelCheckpoint(monitor='{self.monitor}') not found in the returned metrics:"
f" {list(metrics.keys())}. "
f"HINT: Did you call self.log('{self.monitor}', value) in the LightningModule?"
)
if not trainer.fit_loop.epoch_loop.val_loop._has_run:
warning_cache.warn(m)
else:
raise MisconfigurationException(m)
def _get_metric_interpolated_filepath_name(
self, monitor_candidates: Dict[str, _METRIC], trainer: "pl.Trainer", del_filepath: Optional[str] = None
) -> str:
@ -621,51 +619,46 @@ class ModelCheckpoint(Callback):
return filepath
def _monitor_candidates(self, trainer: "pl.Trainer", epoch: int, step: int) -> Dict[str, _METRIC]:
def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, _METRIC]:
monitor_candidates = deepcopy(trainer.callback_metrics)
monitor_candidates.update(epoch=epoch, step=step)
# cast to int if necessary because `self.log("epoch", 123)` will convert it to float. if it's not a tensor
# or does not exist we overwrite it as it's likely an error
epoch = monitor_candidates.get("epoch")
monitor_candidates["epoch"] = (
epoch.int() if isinstance(epoch, torch.Tensor) else torch.tensor(trainer.current_epoch)
)
step = monitor_candidates.get("step")
monitor_candidates["step"] = step.int() if isinstance(step, torch.Tensor) else torch.tensor(trainer.global_step)
return monitor_candidates
def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
if not self.save_last:
return
self._last_global_step_saved = monitor_candidates.get("step", trainer.global_step)
filepath = self.format_checkpoint_name(monitor_candidates, self.CHECKPOINT_NAME_LAST)
# set the last model path before saving because it will be part of the state.
previous, self.last_model_path = self.last_model_path, filepath
trainer.save_checkpoint(filepath, self.save_weights_only)
self._save_checkpoint(trainer, filepath)
if previous and previous != filepath:
trainer.strategy.remove_checkpoint(previous)
def _save_top_k_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
if self.monitor is None or self.save_top_k == 0:
return
self._last_global_step_saved = monitor_candidates.get("step", trainer.global_step)
def _save_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
current = monitor_candidates.get(self.monitor)
if self.check_monitor_top_k(trainer, current):
self._update_best_and_save(current, trainer, monitor_candidates)
elif self.verbose:
epoch = monitor_candidates.get("epoch")
step = monitor_candidates.get("step")
rank_zero_info(f"Epoch {epoch:d}, global step {step:d}: {self.monitor} was not in top {self.save_top_k}")
epoch = monitor_candidates["epoch"]
step = monitor_candidates["step"]
rank_zero_info(f"Epoch {epoch:d}, global step {step:d}: {self.monitor!r} was not in top {self.save_top_k}")
def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
if self.monitor is not None or self.save_top_k == 0:
return
self._last_global_step_saved = monitor_candidates.get("step", trainer.global_step)
filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer)
# set the best model path before saving because it will be part of the state.
previous, self.best_model_path = self.best_model_path, filepath
trainer.save_checkpoint(filepath, self.save_weights_only)
self._save_checkpoint(trainer, filepath)
if self.save_top_k == 1 and previous and previous != filepath:
trainer.strategy.remove_checkpoint(previous)
def _is_valid_monitor_key(self, metrics: Dict[str, _METRIC]) -> bool:
return self.monitor in metrics or len(metrics) == 0
def _update_best_and_save(
self, current: torch.Tensor, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]
) -> None:
@ -697,13 +690,13 @@ class ModelCheckpoint(Callback):
self.best_model_score = self.best_k_models[self.best_model_path]
if self.verbose:
epoch = monitor_candidates.get("epoch")
step = monitor_candidates.get("step")
epoch = monitor_candidates["epoch"]
step = monitor_candidates["step"]
rank_zero_info(
f"Epoch {epoch:d}, global step {step:d}: {self.monitor} reached {current:0.5f}"
f' (best {self.best_model_score:0.5f}), saving model to "{filepath}" as top {k}'
f"Epoch {epoch:d}, global step {step:d}: {self.monitor!r} reached {current:0.5f}"
f" (best {self.best_model_score:0.5f}), saving model to {filepath!r} as top {k}"
)
trainer.save_checkpoint(filepath, self.save_weights_only)
self._save_checkpoint(trainer, filepath)
if del_filepath is not None and filepath != del_filepath:
trainer.strategy.remove_checkpoint(del_filepath)

View File

@ -109,7 +109,6 @@ def test_model_checkpoint_score_and_ckpt(
def validation_step(self, batch, batch_idx):
log_value = self.val_logs[self.current_epoch, batch_idx]
self.log("val_log", log_value)
self.log("epoch", self.current_epoch, on_epoch=True)
return super().validation_step(batch, batch_idx)
def configure_optimizers(self):
@ -1086,7 +1085,7 @@ def test_hparams_type(tmpdir, use_omegaconf):
super().__init__()
self.save_hyperparameters(hparams)
model_checkpoint = ModelCheckpoint(dirpath=tmpdir, save_top_k=1, monitor="foo")
model_checkpoint = ModelCheckpoint(dirpath=tmpdir, save_top_k=1)
trainer = Trainer(
max_epochs=1,
default_root_dir=tmpdir,
@ -1281,3 +1280,24 @@ def test_last_global_step_saved():
trainer.callback_metrics = {"foo": 123}
model_checkpoint.save_checkpoint(trainer)
assert model_checkpoint._last_global_step_saved == 0
@pytest.mark.parametrize("every_n_epochs", (0, 5))
def test_save_last_every_n_epochs_interaction(tmpdir, every_n_epochs):
"""Test that `save_last` ignores `every_n_epochs`."""
mc = ModelCheckpoint(every_n_epochs=every_n_epochs, save_last=True, save_top_k=0, save_on_train_epoch_end=True)
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=2,
callbacks=mc,
limit_train_batches=1,
limit_val_batches=0,
enable_progress_bar=False,
enable_model_summary=False,
logger=False,
)
model = BoringModel()
with patch.object(trainer, "save_checkpoint") as save_mock:
trainer.fit(model)
assert mc.last_model_path # a "last" ckpt was saved
assert save_mock.call_count == trainer.max_epochs

View File

@ -159,8 +159,7 @@ def test_mlflow_logger_dirs_creation(tmpdir):
assert set(os.listdir(tmpdir / exp_id)) == {run_id, "meta.yaml"}
class CustomModel(BoringModel):
def training_epoch_end(self, *args, **kwargs):
super().training_epoch_end(*args, **kwargs)
def on_train_epoch_end(self, *args, **kwargs):
self.log("epoch", self.current_epoch)
model = CustomModel()

View File

@ -106,7 +106,7 @@ class CustomClassificationModelDP(ClassificationModel):
def test_model_properties_fit_ckpt_path(tmpdir):
"""Test that properties like `current_epoch` and `global_step` in model and trainer are always the same."""
model = BoringModel()
checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, monitor="val_loss", save_last=True)
checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, save_last=True)
trainer_args = dict(
default_root_dir=tmpdir,
max_epochs=1,

View File

@ -403,6 +403,7 @@ def test_model_freeze_unfreeze():
assert param.requires_grad
@pytest.mark.xfail(reason="FIXME(@carmocca): this test wasn't running and is now broken")
@pytest.mark.parametrize("url_ckpt", [True, False])
def test_fit_ckpt_path_epoch_restored(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
"""Verify resuming from checkpoint runs the right number of epochs."""
@ -429,7 +430,7 @@ def test_fit_ckpt_path_epoch_restored(monkeypatch, tmpdir, tmpdir_server, url_ck
max_epochs=2,
limit_train_batches=0.65,
limit_val_batches=1,
callbacks=[ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on", save_top_k=-1)],
callbacks=[ModelCheckpoint(dirpath=tmpdir, save_top_k=-1)],
default_root_dir=tmpdir,
val_check_interval=1.0,
enable_progress_bar=False,
@ -449,6 +450,7 @@ def test_fit_ckpt_path_epoch_restored(monkeypatch, tmpdir, tmpdir_server, url_ck
ip, port = tmpdir_server
checkpoints = [f"http://{ip}:{port}/" + ckpt.name for ckpt in checkpoints]
assert checkpoints
for ckpt in checkpoints:
next_model = TestModel()
state = pl_load(ckpt)