Add function to remove checkpoint to allow override for extended classes (#16067)

This commit is contained in:
Sean Naren 2022-12-15 16:20:35 +00:00 committed by GitHub
parent 3b323c842d
commit 10cc6773e6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 7 additions and 3 deletions

View File

@ -649,7 +649,7 @@ class ModelCheckpoint(Checkpoint):
previous, self.last_model_path = self.last_model_path, filepath
self._save_checkpoint(trainer, filepath)
if previous and previous != filepath:
trainer.strategy.remove_checkpoint(previous)
self._remove_checkpoint(trainer, previous)
def _save_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None:
assert self.monitor
@ -668,7 +668,7 @@ class ModelCheckpoint(Checkpoint):
previous, self.best_model_path = self.best_model_path, filepath
self._save_checkpoint(trainer, filepath)
if self.save_top_k == 1 and previous and previous != filepath:
trainer.strategy.remove_checkpoint(previous)
self._remove_checkpoint(trainer, previous)
def _update_best_and_save(
self, current: Tensor, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]
@ -710,7 +710,7 @@ class ModelCheckpoint(Checkpoint):
self._save_checkpoint(trainer, filepath)
if del_filepath is not None and filepath != del_filepath:
trainer.strategy.remove_checkpoint(del_filepath)
self._remove_checkpoint(trainer, del_filepath)
def to_yaml(self, filepath: Optional[_PATH] = None) -> None:
"""Saves the `best_k_models` dict containing the checkpoint paths with the corresponding scores to a YAML
@ -727,3 +727,7 @@ class ModelCheckpoint(Checkpoint):
state to diverge between ranks."""
exists = self._fs.exists(filepath)
return trainer.strategy.broadcast(exists)
def _remove_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
"""Calls the strategy to remove the checkpoint file."""
trainer.strategy.remove_checkpoint(filepath)