Add function to remove checkpoint to allow override for extended classes (#16067)
This commit is contained in:
parent
3b323c842d
commit
10cc6773e6
|
@ -649,7 +649,7 @@ class ModelCheckpoint(Checkpoint):
|
|||
previous, self.last_model_path = self.last_model_path, filepath
|
||||
self._save_checkpoint(trainer, filepath)
|
||||
if previous and previous != filepath:
|
||||
trainer.strategy.remove_checkpoint(previous)
|
||||
self._remove_checkpoint(trainer, previous)
|
||||
|
||||
def _save_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None:
|
||||
assert self.monitor
|
||||
|
@ -668,7 +668,7 @@ class ModelCheckpoint(Checkpoint):
|
|||
previous, self.best_model_path = self.best_model_path, filepath
|
||||
self._save_checkpoint(trainer, filepath)
|
||||
if self.save_top_k == 1 and previous and previous != filepath:
|
||||
trainer.strategy.remove_checkpoint(previous)
|
||||
self._remove_checkpoint(trainer, previous)
|
||||
|
||||
def _update_best_and_save(
|
||||
self, current: Tensor, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]
|
||||
|
@ -710,7 +710,7 @@ class ModelCheckpoint(Checkpoint):
|
|||
self._save_checkpoint(trainer, filepath)
|
||||
|
||||
if del_filepath is not None and filepath != del_filepath:
|
||||
trainer.strategy.remove_checkpoint(del_filepath)
|
||||
self._remove_checkpoint(trainer, del_filepath)
|
||||
|
||||
def to_yaml(self, filepath: Optional[_PATH] = None) -> None:
|
||||
"""Saves the `best_k_models` dict containing the checkpoint paths with the corresponding scores to a YAML
|
||||
|
@ -727,3 +727,7 @@ class ModelCheckpoint(Checkpoint):
|
|||
state to diverge between ranks."""
|
||||
exists = self._fs.exists(filepath)
|
||||
return trainer.strategy.broadcast(exists)
|
||||
|
||||
def _remove_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
|
||||
"""Calls the strategy to remove the checkpoint file."""
|
||||
trainer.strategy.remove_checkpoint(filepath)
|
||||
|
|
Loading…
Reference in New Issue