`ModelCheckpoint`'s `save_last` now ignores `every_n_epochs` (#12418)
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
This commit is contained in:
parent
dcc973e019
commit
71e0ddb62f
12
CHANGELOG.md
12
CHANGELOG.md
|
@ -382,6 +382,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Move `Strategy.process_dataloader` function call from `fit/evaluation/predict_loop.py` to `data_connector.py` ([#12251](https://github.com/PyTorchLightning/pytorch-lightning/pull/12251))
|
||||
|
||||
|
||||
- `ModelCheckpoint(save_last=True, every_n_epochs=N)` now saves a "last" checkpoint every epoch (disregarding `every_n_epochs`) instead of only once at the end of training ([#12418](https://github.com/PyTorchLightning/pytorch-lightning/pull/12418))
|
||||
|
||||
|
||||
- The strategies that support `sync_batchnorm` now only apply it when fitting ([#11919](https://github.com/PyTorchLightning/pytorch-lightning/pull/11919))
|
||||
|
||||
|
||||
|
@ -861,6 +864,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Fixed the case where `logger=None` is passed to the Trainer ([#12249](https://github.com/PyTorchLightning/pytorch-lightning/pull/12249))
|
||||
|
||||
|
||||
- Fixed bug where the global step tracked by `ModelCheckpoint` was still set even if no checkpoint was saved ([#12418](https://github.com/PyTorchLightning/pytorch-lightning/pull/12418))
|
||||
-
|
||||
|
||||
- Fixed bug where `ModelCheckpoint` was overriding the `epoch` and `step` logged values ([#12418](https://github.com/PyTorchLightning/pytorch-lightning/pull/12418))
|
||||
|
||||
|
||||
- Fixed bug where monitoring the default `epoch` and `step` values with `ModelCheckpoint` would fail ([#12418](https://github.com/PyTorchLightning/pytorch-lightning/pull/12418))
|
||||
|
||||
|
||||
- Fixed initializing optimizers unnecessarily in `DDPFullyShardedStrategy` ([#12267](https://github.com/PyTorchLightning/pytorch-lightning/pull/12267))
|
||||
|
||||
|
||||
|
|
|
@ -100,8 +100,8 @@ class ModelCheckpoint(Callback):
|
|||
based on either the maximization or the minimization of the monitored quantity.
|
||||
For ``'val_acc'``, this should be ``'max'``, for ``'val_loss'`` this should be ``'min'``, etc.
|
||||
auto_insert_metric_name: When ``True``, the checkpoints filenames will contain the metric name.
|
||||
For example, ``filename='checkpoint_{epoch:02d}-{acc:02d}`` with epoch 1 and acc 80 will resolve to
|
||||
``checkpoint_epoch=01-acc=80.ckp``. Is useful to set it to ``False`` when metric names contain ``/``
|
||||
For example, ``filename='checkpoint_{epoch:02d}-{acc:02.0f}`` with epoch ``1`` and acc ``1.12`` will resolve
|
||||
to ``checkpoint_epoch=01-acc=01.ckpt``. Is useful to set it to ``False`` when metric names contain ``/``
|
||||
as this will result in extra folders.
|
||||
save_weights_only: if ``True``, then only the model's weights will be
|
||||
saved. Otherwise, the optimizer states, lr-scheduler states, etc are added in the checkpoint too.
|
||||
|
@ -116,7 +116,8 @@ class ModelCheckpoint(Callback):
|
|||
This must be mutually exclusive with ``every_n_train_steps`` and ``every_n_epochs``.
|
||||
every_n_epochs: Number of epochs between checkpoints.
|
||||
This value must be ``None`` or non-negative.
|
||||
To disable saving after each epoch, set ``every_n_epochs = 0``.
|
||||
To disable saving top-k checkpoints, set ``every_n_epochs = 0``.
|
||||
This argument does not impact the saving of ``save_last=True`` checkpoints.
|
||||
If all of ``every_n_epochs``, ``every_n_train_steps`` and
|
||||
``train_time_interval`` are ``None``, we save a checkpoint at the end of every epoch
|
||||
(equivalent to ``every_n_epochs = 1``).
|
||||
|
@ -295,28 +296,25 @@ class ModelCheckpoint(Callback):
|
|||
if not skip_time:
|
||||
self._last_time_checked = now
|
||||
|
||||
self.save_checkpoint(trainer)
|
||||
monitor_candidates = self._monitor_candidates(trainer)
|
||||
self._save_topk_checkpoint(trainer, monitor_candidates)
|
||||
self._save_last_checkpoint(trainer, monitor_candidates)
|
||||
|
||||
def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||
"""Save a checkpoint at the end of the training epoch."""
|
||||
if (
|
||||
not self._should_skip_saving_checkpoint(trainer)
|
||||
and self._save_on_train_epoch_end
|
||||
and self._every_n_epochs > 0
|
||||
and (trainer.current_epoch + 1) % self._every_n_epochs == 0
|
||||
):
|
||||
self.save_checkpoint(trainer)
|
||||
if not self._should_skip_saving_checkpoint(trainer) and self._save_on_train_epoch_end:
|
||||
monitor_candidates = self._monitor_candidates(trainer)
|
||||
if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
|
||||
self._save_topk_checkpoint(trainer, monitor_candidates)
|
||||
self._save_last_checkpoint(trainer, monitor_candidates)
|
||||
|
||||
def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||
"""Save a checkpoint at the end of the validation stage."""
|
||||
if (
|
||||
self._should_skip_saving_checkpoint(trainer)
|
||||
or self._save_on_train_epoch_end
|
||||
or self._every_n_epochs < 1
|
||||
or (trainer.current_epoch + 1) % self._every_n_epochs != 0
|
||||
):
|
||||
return
|
||||
self.save_checkpoint(trainer)
|
||||
if not self._should_skip_saving_checkpoint(trainer) and not self._save_on_train_epoch_end:
|
||||
monitor_candidates = self._monitor_candidates(trainer)
|
||||
if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
|
||||
self._save_topk_checkpoint(trainer, monitor_candidates)
|
||||
self._save_last_checkpoint(trainer, monitor_candidates)
|
||||
|
||||
def on_save_checkpoint(
|
||||
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
|
||||
|
@ -352,26 +350,41 @@ class ModelCheckpoint(Callback):
|
|||
self.last_model_path = callback_state.get("last_model_path", self.last_model_path)
|
||||
self.best_model_path = callback_state["best_model_path"]
|
||||
|
||||
def save_checkpoint(self, trainer: "pl.Trainer") -> None:
|
||||
def save_checkpoint(self, trainer: "pl.Trainer") -> None: # pragma: no-cover
|
||||
"""Performs the main logic around saving a checkpoint.
|
||||
|
||||
This method runs on all ranks. It is the responsibility of `trainer.save_checkpoint` to correctly handle the
|
||||
behaviour in distributed training, i.e., saving only on rank 0 for data parallel use cases.
|
||||
"""
|
||||
self._validate_monitor_key(trainer)
|
||||
|
||||
# what can be monitored
|
||||
monitor_candidates = self._monitor_candidates(trainer, epoch=trainer.current_epoch, step=trainer.global_step)
|
||||
|
||||
# callback supports multiple simultaneous modes
|
||||
# here we call each mode sequentially
|
||||
# Mode 1: save the top k checkpoints
|
||||
self._save_top_k_checkpoint(trainer, monitor_candidates)
|
||||
# Mode 2: save monitor=None checkpoints
|
||||
self._save_none_monitor_checkpoint(trainer, monitor_candidates)
|
||||
# Mode 3: save last checkpoints
|
||||
# TODO: unused method. deprecate it
|
||||
monitor_candidates = self._monitor_candidates(trainer)
|
||||
self._save_topk_checkpoint(trainer, monitor_candidates)
|
||||
self._save_last_checkpoint(trainer, monitor_candidates)
|
||||
|
||||
def _save_topk_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
|
||||
if self.save_top_k == 0:
|
||||
return
|
||||
|
||||
# validate metric
|
||||
if self.monitor is not None:
|
||||
if self.monitor not in monitor_candidates:
|
||||
m = (
|
||||
f"`ModelCheckpoint(monitor={self.monitor!r})` could not find the monitored key in the returned"
|
||||
f" metrics: {list(monitor_candidates)}."
|
||||
f" HINT: Did you call `log({self.monitor!r}, value)` in the `LightningModule`?"
|
||||
)
|
||||
if trainer.fit_loop.epoch_loop.val_loop._has_run:
|
||||
raise MisconfigurationException(m)
|
||||
warning_cache.warn(m)
|
||||
self._save_monitor_checkpoint(trainer, monitor_candidates)
|
||||
else:
|
||||
self._save_none_monitor_checkpoint(trainer, monitor_candidates)
|
||||
|
||||
def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
|
||||
trainer.save_checkpoint(filepath, self.save_weights_only)
|
||||
|
||||
self._last_global_step_saved = trainer.global_step
|
||||
|
||||
# notify loggers
|
||||
if trainer.is_global_zero:
|
||||
for logger in trainer.loggers:
|
||||
|
@ -594,21 +607,6 @@ class ModelCheckpoint(Callback):
|
|||
if self.save_top_k != 0 and self._fs.isdir(dirpath) and len(self._fs.ls(dirpath)) > 0:
|
||||
rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
|
||||
|
||||
def _validate_monitor_key(self, trainer: "pl.Trainer") -> None:
|
||||
metrics = trainer.callback_metrics
|
||||
|
||||
# validate metric
|
||||
if self.monitor is not None and not self._is_valid_monitor_key(metrics):
|
||||
m = (
|
||||
f"ModelCheckpoint(monitor='{self.monitor}') not found in the returned metrics:"
|
||||
f" {list(metrics.keys())}. "
|
||||
f"HINT: Did you call self.log('{self.monitor}', value) in the LightningModule?"
|
||||
)
|
||||
if not trainer.fit_loop.epoch_loop.val_loop._has_run:
|
||||
warning_cache.warn(m)
|
||||
else:
|
||||
raise MisconfigurationException(m)
|
||||
|
||||
def _get_metric_interpolated_filepath_name(
|
||||
self, monitor_candidates: Dict[str, _METRIC], trainer: "pl.Trainer", del_filepath: Optional[str] = None
|
||||
) -> str:
|
||||
|
@ -621,51 +619,46 @@ class ModelCheckpoint(Callback):
|
|||
|
||||
return filepath
|
||||
|
||||
def _monitor_candidates(self, trainer: "pl.Trainer", epoch: int, step: int) -> Dict[str, _METRIC]:
|
||||
def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, _METRIC]:
|
||||
monitor_candidates = deepcopy(trainer.callback_metrics)
|
||||
monitor_candidates.update(epoch=epoch, step=step)
|
||||
# cast to int if necessary because `self.log("epoch", 123)` will convert it to float. if it's not a tensor
|
||||
# or does not exist we overwrite it as it's likely an error
|
||||
epoch = monitor_candidates.get("epoch")
|
||||
monitor_candidates["epoch"] = (
|
||||
epoch.int() if isinstance(epoch, torch.Tensor) else torch.tensor(trainer.current_epoch)
|
||||
)
|
||||
step = monitor_candidates.get("step")
|
||||
monitor_candidates["step"] = step.int() if isinstance(step, torch.Tensor) else torch.tensor(trainer.global_step)
|
||||
return monitor_candidates
|
||||
|
||||
def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
|
||||
if not self.save_last:
|
||||
return
|
||||
self._last_global_step_saved = monitor_candidates.get("step", trainer.global_step)
|
||||
|
||||
filepath = self.format_checkpoint_name(monitor_candidates, self.CHECKPOINT_NAME_LAST)
|
||||
# set the last model path before saving because it will be part of the state.
|
||||
previous, self.last_model_path = self.last_model_path, filepath
|
||||
trainer.save_checkpoint(filepath, self.save_weights_only)
|
||||
self._save_checkpoint(trainer, filepath)
|
||||
if previous and previous != filepath:
|
||||
trainer.strategy.remove_checkpoint(previous)
|
||||
|
||||
def _save_top_k_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
|
||||
if self.monitor is None or self.save_top_k == 0:
|
||||
return
|
||||
self._last_global_step_saved = monitor_candidates.get("step", trainer.global_step)
|
||||
|
||||
def _save_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
|
||||
current = monitor_candidates.get(self.monitor)
|
||||
if self.check_monitor_top_k(trainer, current):
|
||||
self._update_best_and_save(current, trainer, monitor_candidates)
|
||||
elif self.verbose:
|
||||
epoch = monitor_candidates.get("epoch")
|
||||
step = monitor_candidates.get("step")
|
||||
rank_zero_info(f"Epoch {epoch:d}, global step {step:d}: {self.monitor} was not in top {self.save_top_k}")
|
||||
epoch = monitor_candidates["epoch"]
|
||||
step = monitor_candidates["step"]
|
||||
rank_zero_info(f"Epoch {epoch:d}, global step {step:d}: {self.monitor!r} was not in top {self.save_top_k}")
|
||||
|
||||
def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
|
||||
if self.monitor is not None or self.save_top_k == 0:
|
||||
return
|
||||
self._last_global_step_saved = monitor_candidates.get("step", trainer.global_step)
|
||||
|
||||
filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer)
|
||||
# set the best model path before saving because it will be part of the state.
|
||||
previous, self.best_model_path = self.best_model_path, filepath
|
||||
trainer.save_checkpoint(filepath, self.save_weights_only)
|
||||
self._save_checkpoint(trainer, filepath)
|
||||
if self.save_top_k == 1 and previous and previous != filepath:
|
||||
trainer.strategy.remove_checkpoint(previous)
|
||||
|
||||
def _is_valid_monitor_key(self, metrics: Dict[str, _METRIC]) -> bool:
|
||||
return self.monitor in metrics or len(metrics) == 0
|
||||
|
||||
def _update_best_and_save(
|
||||
self, current: torch.Tensor, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]
|
||||
) -> None:
|
||||
|
@ -697,13 +690,13 @@ class ModelCheckpoint(Callback):
|
|||
self.best_model_score = self.best_k_models[self.best_model_path]
|
||||
|
||||
if self.verbose:
|
||||
epoch = monitor_candidates.get("epoch")
|
||||
step = monitor_candidates.get("step")
|
||||
epoch = monitor_candidates["epoch"]
|
||||
step = monitor_candidates["step"]
|
||||
rank_zero_info(
|
||||
f"Epoch {epoch:d}, global step {step:d}: {self.monitor} reached {current:0.5f}"
|
||||
f' (best {self.best_model_score:0.5f}), saving model to "{filepath}" as top {k}'
|
||||
f"Epoch {epoch:d}, global step {step:d}: {self.monitor!r} reached {current:0.5f}"
|
||||
f" (best {self.best_model_score:0.5f}), saving model to {filepath!r} as top {k}"
|
||||
)
|
||||
trainer.save_checkpoint(filepath, self.save_weights_only)
|
||||
self._save_checkpoint(trainer, filepath)
|
||||
|
||||
if del_filepath is not None and filepath != del_filepath:
|
||||
trainer.strategy.remove_checkpoint(del_filepath)
|
||||
|
|
|
@ -109,7 +109,6 @@ def test_model_checkpoint_score_and_ckpt(
|
|||
def validation_step(self, batch, batch_idx):
|
||||
log_value = self.val_logs[self.current_epoch, batch_idx]
|
||||
self.log("val_log", log_value)
|
||||
self.log("epoch", self.current_epoch, on_epoch=True)
|
||||
return super().validation_step(batch, batch_idx)
|
||||
|
||||
def configure_optimizers(self):
|
||||
|
@ -1086,7 +1085,7 @@ def test_hparams_type(tmpdir, use_omegaconf):
|
|||
super().__init__()
|
||||
self.save_hyperparameters(hparams)
|
||||
|
||||
model_checkpoint = ModelCheckpoint(dirpath=tmpdir, save_top_k=1, monitor="foo")
|
||||
model_checkpoint = ModelCheckpoint(dirpath=tmpdir, save_top_k=1)
|
||||
trainer = Trainer(
|
||||
max_epochs=1,
|
||||
default_root_dir=tmpdir,
|
||||
|
@ -1281,3 +1280,24 @@ def test_last_global_step_saved():
|
|||
trainer.callback_metrics = {"foo": 123}
|
||||
model_checkpoint.save_checkpoint(trainer)
|
||||
assert model_checkpoint._last_global_step_saved == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("every_n_epochs", (0, 5))
|
||||
def test_save_last_every_n_epochs_interaction(tmpdir, every_n_epochs):
|
||||
"""Test that `save_last` ignores `every_n_epochs`."""
|
||||
mc = ModelCheckpoint(every_n_epochs=every_n_epochs, save_last=True, save_top_k=0, save_on_train_epoch_end=True)
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
max_epochs=2,
|
||||
callbacks=mc,
|
||||
limit_train_batches=1,
|
||||
limit_val_batches=0,
|
||||
enable_progress_bar=False,
|
||||
enable_model_summary=False,
|
||||
logger=False,
|
||||
)
|
||||
model = BoringModel()
|
||||
with patch.object(trainer, "save_checkpoint") as save_mock:
|
||||
trainer.fit(model)
|
||||
assert mc.last_model_path # a "last" ckpt was saved
|
||||
assert save_mock.call_count == trainer.max_epochs
|
||||
|
|
|
@ -159,8 +159,7 @@ def test_mlflow_logger_dirs_creation(tmpdir):
|
|||
assert set(os.listdir(tmpdir / exp_id)) == {run_id, "meta.yaml"}
|
||||
|
||||
class CustomModel(BoringModel):
|
||||
def training_epoch_end(self, *args, **kwargs):
|
||||
super().training_epoch_end(*args, **kwargs)
|
||||
def on_train_epoch_end(self, *args, **kwargs):
|
||||
self.log("epoch", self.current_epoch)
|
||||
|
||||
model = CustomModel()
|
||||
|
|
|
@ -106,7 +106,7 @@ class CustomClassificationModelDP(ClassificationModel):
|
|||
def test_model_properties_fit_ckpt_path(tmpdir):
|
||||
"""Test that properties like `current_epoch` and `global_step` in model and trainer are always the same."""
|
||||
model = BoringModel()
|
||||
checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, monitor="val_loss", save_last=True)
|
||||
checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, save_last=True)
|
||||
trainer_args = dict(
|
||||
default_root_dir=tmpdir,
|
||||
max_epochs=1,
|
||||
|
|
|
@ -403,6 +403,7 @@ def test_model_freeze_unfreeze():
|
|||
assert param.requires_grad
|
||||
|
||||
|
||||
@pytest.mark.xfail(reason="FIXME(@carmocca): this test wasn't running and is now broken")
|
||||
@pytest.mark.parametrize("url_ckpt", [True, False])
|
||||
def test_fit_ckpt_path_epoch_restored(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
|
||||
"""Verify resuming from checkpoint runs the right number of epochs."""
|
||||
|
@ -429,7 +430,7 @@ def test_fit_ckpt_path_epoch_restored(monkeypatch, tmpdir, tmpdir_server, url_ck
|
|||
max_epochs=2,
|
||||
limit_train_batches=0.65,
|
||||
limit_val_batches=1,
|
||||
callbacks=[ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on", save_top_k=-1)],
|
||||
callbacks=[ModelCheckpoint(dirpath=tmpdir, save_top_k=-1)],
|
||||
default_root_dir=tmpdir,
|
||||
val_check_interval=1.0,
|
||||
enable_progress_bar=False,
|
||||
|
@ -449,6 +450,7 @@ def test_fit_ckpt_path_epoch_restored(monkeypatch, tmpdir, tmpdir_server, url_ck
|
|||
ip, port = tmpdir_server
|
||||
checkpoints = [f"http://{ip}:{port}/" + ckpt.name for ckpt in checkpoints]
|
||||
|
||||
assert checkpoints
|
||||
for ckpt in checkpoints:
|
||||
next_model = TestModel()
|
||||
state = pl_load(ckpt)
|
||||
|
|
Loading…
Reference in New Issue