From 444b21dc3d46f3e186529fafda5b0751daacdf6f Mon Sep 17 00:00:00 2001 From: jjenniferdai <89552168+jjenniferdai@users.noreply.github.com> Date: Sat, 25 Sep 2021 08:31:48 -0700 Subject: [PATCH] Optimize non-empty directory warning check in model checkpoint callback (#9615) * pt1 dir empty check * clean imports * bring back resolve mkdir? * original doc * warningcache * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * cp callback after resolve * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * move global_zero check outside warn fn Co-authored-by: ananthsub * move global_zero check outside warn fn 2 Co-authored-by: ananthsub * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: ananthsub --- .../callbacks/model_checkpoint.py | 24 ++++++++++--------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 42cd078d21..cff2116446 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -24,8 +24,7 @@ import re import time from copy import deepcopy from datetime import timedelta -from pathlib import Path -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional from weakref import proxy import numpy as np @@ -37,7 +36,7 @@ from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.types import _METRIC, STEP_OUTPUT +from pytorch_lightning.utilities.types import _METRIC, _PATH, STEP_OUTPUT from pytorch_lightning.utilities.warnings import WarningCache log = logging.getLogger(__name__) @@ -203,7 +202,7 @@ class ModelCheckpoint(Callback): def __init__( self, - dirpath: Optional[Union[str, Path]] = None, + dirpath: Optional[_PATH] = None, filename: Optional[str] = None, monitor: Optional[str] = None, verbose: bool = False, @@ -267,6 +266,8 @@ class ModelCheckpoint(Callback): def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: """When pretrain routine starts we build the ckpt dir on the fly.""" self.__resolve_ckpt_dir(trainer) + if trainer.is_global_zero: + self.__warn_if_dir_not_empty(self.dirpath) def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self._last_time_checked = time.monotonic() @@ -440,11 +441,8 @@ class ModelCheckpoint(Callback): " will duplicate the last checkpoint saved." ) - def __init_ckpt_dir(self, dirpath: Optional[Union[str, Path]], filename: Optional[str]) -> None: - self._fs = get_filesystem(str(dirpath) if dirpath else "") - - if self.save_top_k != 0 and dirpath is not None and self._fs.isdir(dirpath) and len(self._fs.ls(dirpath)) > 0: - rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.") + def __init_ckpt_dir(self, dirpath: Optional[_PATH], filename: Optional[str]) -> None: + self._fs = get_filesystem(dirpath if dirpath else "") if dirpath and self._fs.protocol == "file": dirpath = os.path.realpath(dirpath) @@ -619,6 +617,10 @@ class ModelCheckpoint(Callback): if not trainer.fast_dev_run and trainer.should_rank_save_checkpoint: self._fs.makedirs(self.dirpath, exist_ok=True) + def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None: + if self.save_top_k != 0 and self._fs.isdir(dirpath) and len(self._fs.ls(dirpath)) > 0: + rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.") + def _validate_monitor_key(self, trainer: "pl.Trainer") -> None: metrics = trainer.callback_metrics @@ -735,7 +737,7 @@ class ModelCheckpoint(Callback): if del_filepath is not None and filepath != del_filepath: trainer.training_type_plugin.remove_checkpoint(del_filepath) - def to_yaml(self, filepath: Optional[Union[str, Path]] = None) -> None: + def to_yaml(self, filepath: Optional[_PATH] = None) -> None: """Saves the `best_k_models` dict containing the checkpoint paths with the corresponding scores to a YAML file.""" best_k = {k: v.item() for k, v in self.best_k_models.items()} @@ -744,7 +746,7 @@ class ModelCheckpoint(Callback): with self._fs.open(filepath, "w") as fp: yaml.dump(best_k, fp) - def file_exists(self, filepath: Union[str, Path], trainer: "pl.Trainer") -> bool: + def file_exists(self, filepath: _PATH, trainer: "pl.Trainer") -> bool: """Checks if a file exists on rank 0 and broadcasts the result to all other ranks, preventing the internal state to diverge between ranks.""" exists = self._fs.exists(filepath)