Optimize non-empty directory warning check in model checkpoint callback (#9615)
* pt1 dir empty check
* clean imports
* bring back resolve mkdir?
* original doc
* warningcache
* cp callback after resolve
* move global_zero check outside warn fn
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Parent: a3def9d228
Commit: 444b21dc3d
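
Reviewer note (summary of the change): the non-empty-directory warning previously ran inside __init_ckpt_dir, i.e. at callback construction time and on every rank, and it required a directory listing via self._fs.ls. This commit extracts the check into a __warn_if_dir_not_empty helper and calls it from on_pretrain_routine_start behind a trainer.is_global_zero guard, so the (potentially remote) listing happens once per job, after the checkpoint directory has been resolved. It also swaps the ad-hoc Union[str, Path] annotations for the shared _PATH alias from pytorch_lightning.utilities.types.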
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -24,8 +24,7 @@ import re
 import time
 from copy import deepcopy
 from datetime import timedelta
-from pathlib import Path
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional
 from weakref import proxy
 
 import numpy as np
@@ -37,7 +36,7 @@ from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_info, rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.types import _METRIC, STEP_OUTPUT
+from pytorch_lightning.utilities.types import _METRIC, _PATH, STEP_OUTPUT
 from pytorch_lightning.utilities.warnings import WarningCache
 
 log = logging.getLogger(__name__)
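
Reviewer note: _PATH is imported here for the annotation changes below. Its definition is not part of this diff; judging from the Union[str, Path] annotations it replaces, it is presumably a shared alias along these lines (a reconstruction, not a copy of pytorch_lightning/utilities/types.py):

    from pathlib import Path
    from typing import Union

    # Presumed shared alias: anything usable as a filesystem path.
    _PATH = Union[str, Path]
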
@@ -203,7 +202,7 @@ class ModelCheckpoint(Callback):
 
     def __init__(
         self,
-        dirpath: Optional[Union[str, Path]] = None,
+        dirpath: Optional[_PATH] = None,
         filename: Optional[str] = None,
         monitor: Optional[str] = None,
         verbose: bool = False,
@ -267,6 +266,8 @@ class ModelCheckpoint(Callback):
|
|||
def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||
"""When pretrain routine starts we build the ckpt dir on the fly."""
|
||||
self.__resolve_ckpt_dir(trainer)
|
||||
if trainer.is_global_zero:
|
||||
self.__warn_if_dir_not_empty(self.dirpath)
|
||||
|
||||
def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||
self._last_time_checked = time.monotonic()
|
||||
|
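
Reviewer note: after this hunk, constructing ModelCheckpoint no longer touches storage; the warning fires when the pretrain routine starts, and only from global rank zero. A sketch of the observable behavior (the surrounding script and model are assumed, not from this diff):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint

    # No directory listing, hence no warning, at construction time anymore.
    ckpt_cb = ModelCheckpoint(dirpath="checkpoints/")
    trainer = Trainer(callbacks=[ckpt_cb])
    # trainer.fit(model)  # if "checkpoints/" already contains files, rank 0
    #                     # warns: "Checkpoint directory ... exists and is not empty."
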
@@ -440,11 +441,8 @@ class ModelCheckpoint(Callback):
             " will duplicate the last checkpoint saved."
         )
 
-    def __init_ckpt_dir(self, dirpath: Optional[Union[str, Path]], filename: Optional[str]) -> None:
-        self._fs = get_filesystem(str(dirpath) if dirpath else "")
-
-        if self.save_top_k != 0 and dirpath is not None and self._fs.isdir(dirpath) and len(self._fs.ls(dirpath)) > 0:
-            rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
+    def __init_ckpt_dir(self, dirpath: Optional[_PATH], filename: Optional[str]) -> None:
+        self._fs = get_filesystem(dirpath if dirpath else "")
 
         if dirpath and self._fs.protocol == "file":
             dirpath = os.path.realpath(dirpath)
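
Reviewer note: __init_ckpt_dir now hands dirpath to get_filesystem without str(), so the helper is expected to accept path-like input, and the method keeps only the local-path normalization. A standalone sketch of that normalization using plain fsspec, which get_filesystem wraps (the function name is illustrative; some fsspec versions report protocol as a tuple such as ("file", "local"), which the membership test also covers):

    import os
    import fsspec

    def normalize_dirpath(dirpath):
        # Illustrative helper, not part of the callback's API.
        if not dirpath:
            return dirpath
        fs, _ = fsspec.core.url_to_fs(str(dirpath))
        if "file" in fs.protocol:
            # Resolve symlinks so checkpoint paths compare consistently later.
            dirpath = os.path.realpath(dirpath)
        return dirpath
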
@@ -619,6 +617,10 @@ class ModelCheckpoint(Callback):
         if not trainer.fast_dev_run and trainer.should_rank_save_checkpoint:
             self._fs.makedirs(self.dirpath, exist_ok=True)
 
+    def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None:
+        if self.save_top_k != 0 and self._fs.isdir(dirpath) and len(self._fs.ls(dirpath)) > 0:
+            rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
+
     def _validate_monitor_key(self, trainer: "pl.Trainer") -> None:
         metrics = trainer.callback_metrics
 
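
Reviewer note: the extracted helper carries over the exact condition from the old __init_ckpt_dir, including the save_top_k != 0 guard (with save_top_k == 0 no checkpoints are written, so a pre-populated directory is harmless). The same logic rendered standalone with plain fsspec, for illustration only:

    import warnings
    import fsspec

    def warn_if_dir_not_empty(dirpath, save_top_k=-1):
        # Mirrors ModelCheckpoint.__warn_if_dir_not_empty on a local filesystem.
        fs = fsspec.filesystem("file")
        if save_top_k != 0 and fs.isdir(dirpath) and len(fs.ls(dirpath)) > 0:
            warnings.warn(f"Checkpoint directory {dirpath} exists and is not empty.")
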
@@ -735,7 +737,7 @@ class ModelCheckpoint(Callback):
         if del_filepath is not None and filepath != del_filepath:
             trainer.training_type_plugin.remove_checkpoint(del_filepath)
 
-    def to_yaml(self, filepath: Optional[Union[str, Path]] = None) -> None:
+    def to_yaml(self, filepath: Optional[_PATH] = None) -> None:
         """Saves the `best_k_models` dict containing the checkpoint paths with the corresponding scores to a YAML
         file."""
         best_k = {k: v.item() for k, v in self.best_k_models.items()}
@@ -744,7 +746,7 @@ class ModelCheckpoint(Callback):
         with self._fs.open(filepath, "w") as fp:
             yaml.dump(best_k, fp)
 
-    def file_exists(self, filepath: Union[str, Path], trainer: "pl.Trainer") -> bool:
+    def file_exists(self, filepath: _PATH, trainer: "pl.Trainer") -> bool:
         """Checks if a file exists on rank 0 and broadcasts the result to all other ranks, preventing the internal
         state to diverge between ranks."""
         exists = self._fs.exists(filepath)
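
Reviewer note: file_exists keeps its contract of checking on rank 0 and broadcasting the result (per its docstring; the broadcast itself presumably goes through the trainer and is not shown in this hunk). A standalone sketch of the same check-then-broadcast idea with raw torch.distributed, assuming an initialized process group (broadcast_object_list exists in torch >= 1.8):

    import os
    import torch.distributed as dist

    def file_exists_all_ranks(filepath):
        # Rank 0 performs the filesystem check; all ranks receive its answer,
        # keeping internal state consistent across the job.
        flag = [os.path.exists(filepath) if dist.get_rank() == 0 else None]
        dist.broadcast_object_list(flag, src=0)
        return bool(flag[0])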