Type `trainer.connectors.checkpoint_connector` (#9419)

Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
jjenniferdai 2021-09-14 18:02:19 -07:00 committed by GitHub
parent 5735b85147
commit 637f59f1d2
8 changed files with 28 additions and 24 deletions
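In short: the commit adds the checkpoint connector to the mypy module list (first hunk below) and replaces its ad-hoc `Union[str, Path]` path annotations with the shared `_PATH` alias imported from `pytorch_lightning.utilities.types`. A minimal sketch of the pattern, assuming `_PATH` is the `Union[str, Path]` alias; the `load_checkpoint` helper here is illustrative, not the library API:

import torch
from pathlib import Path
from typing import Any, Dict, Union

_PATH = Union[str, Path]  # assumed to mirror pytorch_lightning.utilities.types._PATH

def load_checkpoint(checkpoint_path: _PATH) -> Dict[str, Any]:
    # Illustrative helper: a str and a pathlib.Path both satisfy the _PATH annotation.
    return torch.load(checkpoint_path, map_location="cpu")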


@@ -65,6 +65,7 @@ module = [
"pytorch_lightning.callbacks.pruning",
"pytorch_lightning.loops.optimization.*",
"pytorch_lightning.loops.evaluation_loop",
"pytorch_lightning.trainer.connectors.checkpoint_connector",
"pytorch_lightning.trainer.connectors.logger_connector.*",
"pytorch_lightning.trainer.progress",
"pytorch_lightning.tuner.auto_gpu_select",


@@ -35,7 +35,7 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.distributed import log, rank_zero_info, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
-from pytorch_lightning.utilities.types import LRSchedulerTypeTuple
+from pytorch_lightning.utilities.types import _PATH, LRSchedulerTypeTuple
from pytorch_lightning.utilities.warnings import rank_zero_warn, WarningCache
warning_cache = WarningCache()
@@ -668,7 +668,7 @@ class DeepSpeedPlugin(DDPPlugin):
def _multi_device(self) -> bool:
return self.num_processes > 1 or self.num_nodes > 1
-def save_checkpoint(self, checkpoint: Dict, filepath: str) -> None:
+def save_checkpoint(self, checkpoint: Dict, filepath: _PATH) -> None:
"""Save model/training states as a checkpoint file through state-dump and file-write.
Args:
@@ -689,7 +689,7 @@ class DeepSpeedPlugin(DDPPlugin):
checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys}
self.deepspeed_engine.save_checkpoint(filepath, client_state=checkpoint)
-def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Optional[Dict[str, Any]]:
+def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
if self.load_full_weights and self.zero_stage_3:
# Broadcast to ensure we load from the rank 0 checkpoint
# This doesn't have to be the case when using deepspeed sharded checkpointing


@@ -20,6 +20,7 @@ from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.types import _PATH
if _TPU_AVAILABLE:
import torch_xla.core.xla_model as xm
@@ -62,10 +63,10 @@ class SingleTPUPlugin(SingleDevicePlugin):
self.tpu_local_core_rank = xm.get_local_ordinal()
self.tpu_global_core_rank = xm.get_ordinal()
-def save(self, state_dict: Dict, path: str) -> None:
+def save(self, state_dict: Dict, path: _PATH) -> None:
xm.save(state_dict, path)
-def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
+def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: _PATH) -> None:
"""Save model/training states as a checkpoint file through state-dump and file-write.
Args:


@@ -35,7 +35,7 @@ from pytorch_lightning.utilities.data import has_len
from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.seed import reset_seed
-from pytorch_lightning.utilities.types import STEP_OUTPUT
+from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT
if _TPU_AVAILABLE:
import torch_xla.core.xla_env_vars as xenv
@@ -207,7 +207,7 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
self.mp_queue.put(results)
self.lightning_module.add_to_queue(self.mp_queue) # adds the `callback_metrics` to the queue
-def save(self, state_dict: Dict, path: str) -> None:
+def save(self, state_dict: Dict, path: _PATH) -> None:
xm.save(state_dict, path)
def broadcast(self, obj: object, src: int = 0) -> object:
@@ -303,7 +303,7 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
if self.tpu_global_core_rank == 0 and int(os.getenv(xenv.TPUVM_MODE, 0)) == 1:
print()
-def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
+def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: _PATH) -> None:
"""Save model/training states as a checkpoint file through state-dump and file-write.
Args:


@@ -13,7 +13,6 @@
# limitations under the License.
import contextlib
from abc import ABC, abstractmethod
-from pathlib import Path
from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, TypeVar, Union
import torch
@@ -152,7 +151,7 @@ class TrainingTypePlugin(ABC):
"""
return self._results
-def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
+def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
torch.cuda.empty_cache()
return self.checkpoint_io.load_checkpoint(checkpoint_path)


@@ -14,13 +14,13 @@
import os
import re
-from pathlib import Path
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional
import torch
from torchmetrics import Metric
import pytorch_lightning as pl
+from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.loops.fit_loop import FitLoop
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_deprecation, rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
@@ -34,10 +34,10 @@ if _OMEGACONF_AVAILABLE:
class CheckpointConnector:
-def __init__(self, trainer, resume_from_checkpoint: Optional[Union[str, Path]] = None):
+def __init__(self, trainer: "pl.Trainer", resume_from_checkpoint: Optional[_PATH] = None) -> None:
self.trainer = trainer
self.resume_checkpoint_path = resume_from_checkpoint
-self._loaded_checkpoint = {}
+self._loaded_checkpoint: Dict[str, Any] = {}
@property
def hpc_resume_path(self) -> Optional[str]:
@@ -64,7 +64,7 @@ class CheckpointConnector:
rank_zero_info(f"Restoring states from the checkpoint path at {checkpoint_path}")
self._loaded_checkpoint = self._load_and_validate_checkpoint(checkpoint_path)
-def _load_and_validate_checkpoint(self, checkpoint_path: _PATH) -> Optional[Dict[str, Any]]:
+def _load_and_validate_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
loaded_checkpoint = self.trainer.training_type_plugin.load_checkpoint(checkpoint_path)
if any(key in loaded_checkpoint for key in DEPRECATED_CHECKPOINT_KEYS):
raise ValueError(
@@ -89,7 +89,7 @@ class CheckpointConnector:
# wait for all to catch up
self.trainer.training_type_plugin.barrier("CheckpointConnector.resume_end")
-def restore(self, checkpoint_path: Optional[Union[Path, str]] = None) -> None:
+def restore(self, checkpoint_path: Optional[_PATH] = None) -> None:
"""Attempt to restore everything at once from a 'PyTorch-Lightning checkpoint' file through file-read and
state-restore, in this priority:
@@ -152,7 +152,7 @@ class CheckpointConnector:
if isinstance(module, Metric):
module.reset()
-def restore_model_weights(self, checkpoint_path: Optional[Union[str, Path]]) -> None:
+def restore_model_weights(self, checkpoint_path: Optional[_PATH]) -> None:
"""Restore only the model weights."""
checkpoint = self._loaded_checkpoint
if checkpoint_path is not None:
@@ -197,6 +197,7 @@ class CheckpointConnector:
# crash if max_epochs is lower then the current epoch from the checkpoint
if (
FitLoop._is_max_limit_enabled(self.trainer.max_epochs)
+and self.trainer.max_epochs is not None
and self.trainer.current_epoch > self.trainer.max_epochs
):
raise MisconfigurationException(
@@ -273,7 +274,7 @@ class CheckpointConnector:
# PRIVATE OPS
# ----------------------------------
-def hpc_save(self, folderpath: str, logger):
+def hpc_save(self, folderpath: str, logger: LightningLoggerBase) -> str:
# make sure the checkpoint folder exists
folderpath = str(folderpath) # because the tests pass a path object
fs = get_filesystem(folderpath)
@@ -387,7 +388,7 @@ class CheckpointConnector:
return checkpoint
-def hpc_load(self, checkpoint_path: str) -> None:
+def hpc_load(self, checkpoint_path: _PATH) -> None:
"""Attempts to restore the full training and model state from a HPC checkpoint file.
.. deprecated:: v1.4 Will be removed in v1.6. Use :meth:`restore` instead.
@@ -398,7 +399,7 @@ class CheckpointConnector:
)
self.restore(checkpoint_path)
-def max_ckpt_version_in_folder(self, dir_path: Union[str, Path], name_key: str = "ckpt_") -> Optional[int]:
+def max_ckpt_version_in_folder(self, dir_path: _PATH, name_key: str = "ckpt_") -> Optional[int]:
"""List up files in `dir_path` with `name_key`, then yield maximum suffix number.
Args:
@@ -428,7 +429,7 @@ class CheckpointConnector:
return max(ckpt_vs)
-def get_max_ckpt_path_from_folder(self, folder_path: Union[str, Path]) -> str:
+def get_max_ckpt_path_from_folder(self, folder_path: _PATH) -> str:
"""Get path of maximum-epoch checkpoint in the folder."""
max_suffix = self.max_ckpt_version_in_folder(folder_path)
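With the connector signatures now taking `Optional[_PATH]`, a checkpoint location can be given as either a string or a `pathlib.Path` and still type-check. A hedged usage sketch, assuming the `resume_from_checkpoint` Trainer argument that fed this connector at the time; the checkpoint path is a placeholder:

from pathlib import Path
import pytorch_lightning as pl

# Both forms satisfy Optional[_PATH]; "checkpoints/last.ckpt" is a placeholder path.
trainer = pl.Trainer(resume_from_checkpoint="checkpoints/last.ckpt")
trainer = pl.Trainer(resume_from_checkpoint=Path("checkpoints/last.ckpt"))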


@@ -48,7 +48,7 @@ from pytorch_lightning.utilities.argparse import (
)
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.model_helpers import is_overridden
-from pytorch_lightning.utilities.types import _PATH
+from pytorch_lightning.utilities.types import _PATH, LRSchedulerTypeUnion
class TrainerProperties(ABC):
@@ -182,11 +182,11 @@ class TrainerProperties(ABC):
self.accelerator.optimizers = new_optims
@property
-def lr_schedulers(self) -> Optional[list]:
+def lr_schedulers(self) -> List[LRSchedulerTypeUnion]:
return self.accelerator.lr_schedulers
@lr_schedulers.setter
-def lr_schedulers(self, new_schedulers: Optional[list]) -> None:
+def lr_schedulers(self, new_schedulers: List[LRSchedulerTypeUnion]) -> None:
self.accelerator.lr_schedulers = new_schedulers
@property


@@ -43,5 +43,7 @@ TRAIN_DATALOADERS = Union[
Dict[str, Sequence[DataLoader]],
]
EVAL_DATALOADERS = Union[DataLoader, Sequence[DataLoader]]
+# todo: improve LRSchedulerType naming/typing
LRSchedulerTypeTuple = (_LRScheduler, ReduceLROnPlateau)
+LRSchedulerTypeUnion = Union[_LRScheduler, ReduceLROnPlateau]
LRSchedulerType = Union[Type[_LRScheduler], Type[ReduceLROnPlateau]]
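The three scheduler aliases split the roles: `LRSchedulerTypeTuple` is a runtime tuple for isinstance() checks, the new `LRSchedulerTypeUnion` annotates scheduler instances (as in the `lr_schedulers` property above), and `LRSchedulerType` annotates scheduler classes. A short sketch of that split; the `describe` helper is illustrative only:

from typing import Type, Union

from torch.optim.lr_scheduler import _LRScheduler, ReduceLROnPlateau

LRSchedulerTypeTuple = (_LRScheduler, ReduceLROnPlateau)  # runtime isinstance() checks
LRSchedulerTypeUnion = Union[_LRScheduler, ReduceLROnPlateau]  # annotating scheduler instances
LRSchedulerType = Union[Type[_LRScheduler], Type[ReduceLROnPlateau]]  # annotating scheduler classes

def describe(scheduler: LRSchedulerTypeUnion) -> str:
    # Illustrative helper: the annotation uses the Union, the runtime check uses the tuple.
    assert isinstance(scheduler, LRSchedulerTypeTuple)
    return type(scheduler).__name__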