From 10cc6773e66515dcbda126c442cc4d9a9bbf669d Mon Sep 17 00:00:00 2001 From: Sean Naren Date: Thu, 15 Dec 2022 16:20:35 +0000 Subject: [PATCH] Add function to remove checkpoint to allow override for extended classes (#16067) --- src/pytorch_lightning/callbacks/model_checkpoint.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/pytorch_lightning/callbacks/model_checkpoint.py b/src/pytorch_lightning/callbacks/model_checkpoint.py index 0a7b400bb9..d7227c78f4 100644 --- a/src/pytorch_lightning/callbacks/model_checkpoint.py +++ b/src/pytorch_lightning/callbacks/model_checkpoint.py @@ -649,7 +649,7 @@ class ModelCheckpoint(Checkpoint): previous, self.last_model_path = self.last_model_path, filepath self._save_checkpoint(trainer, filepath) if previous and previous != filepath: - trainer.strategy.remove_checkpoint(previous) + self._remove_checkpoint(trainer, previous) def _save_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None: assert self.monitor @@ -668,7 +668,7 @@ class ModelCheckpoint(Checkpoint): previous, self.best_model_path = self.best_model_path, filepath self._save_checkpoint(trainer, filepath) if self.save_top_k == 1 and previous and previous != filepath: - trainer.strategy.remove_checkpoint(previous) + self._remove_checkpoint(trainer, previous) def _update_best_and_save( self, current: Tensor, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor] @@ -710,7 +710,7 @@ class ModelCheckpoint(Checkpoint): self._save_checkpoint(trainer, filepath) if del_filepath is not None and filepath != del_filepath: - trainer.strategy.remove_checkpoint(del_filepath) + self._remove_checkpoint(trainer, del_filepath) def to_yaml(self, filepath: Optional[_PATH] = None) -> None: """Saves the `best_k_models` dict containing the checkpoint paths with the corresponding scores to a YAML @@ -727,3 +727,7 @@ class ModelCheckpoint(Checkpoint): state to diverge between ranks.""" exists = self._fs.exists(filepath) return trainer.strategy.broadcast(exists) + + def _remove_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None: + """Calls the strategy to remove the checkpoint file.""" + trainer.strategy.remove_checkpoint(filepath)