Optimize non-empty directory warning check in model checkpoint callback (#9615)

* pt1 dir empty check

* clean imports

* bring back resolve mkdir?

* original doc

* warningcache

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* cp callback after resolve

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* move global_zero check outside warn fn

Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>

* move global_zero check outside warn fn 2

Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
jjenniferdai 2021-09-25 08:31:48 -07:00 committed by GitHub
parent a3def9d228
commit 444b21dc3d
1 changed file with 13 additions and 11 deletions
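
In short: the "checkpoint directory exists and is not empty" warning used to run inside __init_ckpt_dir, on every rank, at construction time, before the checkpoint directory was resolved. It now runs once per fit, from on_pretrain_routine_start, on global rank zero only, after __resolve_ckpt_dir has settled on the final dirpath. A minimal, self-contained sketch of the check itself, using fsspec directly rather than Lightning's get_filesystem wrapper (the standalone warn_if_dir_not_empty function is our illustration, not library API):

import warnings

import fsspec


def warn_if_dir_not_empty(dirpath: str, save_top_k: int = 1) -> None:
    # Local filesystem for the sketch; get_filesystem picks the matching
    # fsspec filesystem from the path's protocol (file, s3, gs, ...).
    fs = fsspec.filesystem("file")
    # save_top_k == 0 disables checkpoint saving, so a populated directory is harmless.
    if save_top_k != 0 and fs.isdir(dirpath) and len(fs.ls(dirpath)) > 0:
        warnings.warn(f"Checkpoint directory {dirpath} exists and is not empty.")


warn_if_dir_not_empty("/tmp")  # warns on most machines, since /tmp is rarely empty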


@@ -24,8 +24,7 @@ import re
 import time
 from copy import deepcopy
 from datetime import timedelta
-from pathlib import Path
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional
 from weakref import proxy
 
 import numpy as np
@@ -37,7 +36,7 @@ from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_info, rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.types import _METRIC, STEP_OUTPUT
+from pytorch_lightning.utilities.types import _METRIC, _PATH, STEP_OUTPUT
 from pytorch_lightning.utilities.warnings import WarningCache
 
 log = logging.getLogger(__name__)
@@ -203,7 +202,7 @@ class ModelCheckpoint(Callback):
 
     def __init__(
         self,
-        dirpath: Optional[Union[str, Path]] = None,
+        dirpath: Optional[_PATH] = None,
         filename: Optional[str] = None,
         monitor: Optional[str] = None,
         verbose: bool = False,
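
_PATH consolidates the repeated Optional[Union[str, Path]] spellings, which is why the pathlib.Path and typing.Union imports could be dropped above. To the best of our knowledge the alias in pytorch_lightning/utilities/types.py is simply:

from pathlib import Path
from typing import Union

_PATH = Union[str, Path]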
@@ -267,6 +266,8 @@
     def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """When pretrain routine starts we build the ckpt dir on the fly."""
         self.__resolve_ckpt_dir(trainer)
+        if trainer.is_global_zero:
+            self.__warn_if_dir_not_empty(self.dirpath)
 
     def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self._last_time_checked = time.monotonic()
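
The placement of the trainer.is_global_zero test at the call site, rather than inside the helper, is what the "move global_zero check outside warn fn" commits above are about: rank_zero_warn suppresses only the message on non-zero ranks, so the directory listing guarding it would still execute everywhere. A hedged toy sketch of the difference; RANK and dir_not_empty are our stand-ins for trainer.is_global_zero and the fs.isdir/fs.ls pair:

import os
import warnings

RANK = int(os.environ.get("RANK", "0"))  # stand-in for trainer.is_global_zero


def dir_not_empty(dirpath: str) -> bool:
    # Stand-in for fs.isdir(dirpath) and len(fs.ls(dirpath)) > 0; on a remote
    # object store this listing is the expensive part worth skipping.
    return os.path.isdir(dirpath) and len(os.listdir(dirpath)) > 0


def warn_before(dirpath: str) -> None:
    # Old shape: every rank pays for the listing; only the warning is rank-gated.
    if dir_not_empty(dirpath):
        if RANK == 0:  # roughly what rank_zero_warn does internally
            warnings.warn(f"Checkpoint directory {dirpath} exists and is not empty.")


def warn_after(dirpath: str) -> None:
    # New shape: non-zero ranks return before touching the filesystem at all.
    if RANK != 0:
        return
    if dir_not_empty(dirpath):
        warnings.warn(f"Checkpoint directory {dirpath} exists and is not empty.")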
@@ -440,11 +441,8 @@
                     " will duplicate the last checkpoint saved."
                 )
 
-    def __init_ckpt_dir(self, dirpath: Optional[Union[str, Path]], filename: Optional[str]) -> None:
-        self._fs = get_filesystem(str(dirpath) if dirpath else "")
-
-        if self.save_top_k != 0 and dirpath is not None and self._fs.isdir(dirpath) and len(self._fs.ls(dirpath)) > 0:
-            rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
+    def __init_ckpt_dir(self, dirpath: Optional[_PATH], filename: Optional[str]) -> None:
+        self._fs = get_filesystem(dirpath if dirpath else "")
 
         if dirpath and self._fs.protocol == "file":
             dirpath = os.path.realpath(dirpath)
@@ -619,6 +617,10 @@
         if not trainer.fast_dev_run and trainer.should_rank_save_checkpoint:
             self._fs.makedirs(self.dirpath, exist_ok=True)
 
+    def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None:
+        if self.save_top_k != 0 and self._fs.isdir(dirpath) and len(self._fs.ls(dirpath)) > 0:
+            rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
+
     def _validate_monitor_key(self, trainer: "pl.Trainer") -> None:
         metrics = trainer.callback_metrics
 
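
Note that, compared with the deleted check in __init_ckpt_dir, the helper drops the dirpath is not None guard: __resolve_ckpt_dir always sets self.dirpath before the hook calls it, so the parameter can be a plain _PATH. A quick demo of the fsspec calls the helper relies on (local filesystem only, for illustration):

import tempfile

import fsspec

fs = fsspec.filesystem("file")

with tempfile.TemporaryDirectory() as d:
    print(fs.isdir(d), fs.ls(d))  # True [] -> directory empty, no warning

# After cleanup the path is gone: isdir() is False and the check
# short-circuits before calling the more expensive ls().
print(fs.isdir(d))  # False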
@@ -735,7 +737,7 @@
         if del_filepath is not None and filepath != del_filepath:
             trainer.training_type_plugin.remove_checkpoint(del_filepath)
 
-    def to_yaml(self, filepath: Optional[Union[str, Path]] = None) -> None:
+    def to_yaml(self, filepath: Optional[_PATH] = None) -> None:
         """Saves the `best_k_models` dict containing the checkpoint paths with the corresponding scores to a YAML
         file."""
         best_k = {k: v.item() for k, v in self.best_k_models.items()}
@@ -744,7 +746,7 @@
         with self._fs.open(filepath, "w") as fp:
             yaml.dump(best_k, fp)
 
-    def file_exists(self, filepath: Union[str, Path], trainer: "pl.Trainer") -> bool:
+    def file_exists(self, filepath: _PATH, trainer: "pl.Trainer") -> bool:
         """Checks if a file exists on rank 0 and broadcasts the result to all other ranks, preventing the internal
         state to diverge between ranks."""
         exists = self._fs.exists(filepath)
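
file_exists pairs the filesystem check with a broadcast so every rank acts on the same answer, even if the filesystem is only consistent (or only visible) from rank 0. A hedged sketch of the same pattern with raw torch.distributed instead of Lightning's training_type_plugin.broadcast; the function name and setup are ours:

import os

import torch.distributed as dist


def file_exists_synced(filepath: str) -> bool:
    # Assumes dist.init_process_group(...) has already been called.
    # Rank 0 performs the actual check...
    exists = [os.path.exists(filepath) if dist.get_rank() == 0 else None]
    # ...and broadcasts it, so no rank acts on a diverging local view.
    dist.broadcast_object_list(exists, src=0)
    return bool(exists[0])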