Type `trainer.connectors.checkpoint_connector` (#9419)

Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>

commit 637f59f1d2, parent 5735b85147
pyproject.toml
@@ -65,6 +65,7 @@ module = [
     "pytorch_lightning.callbacks.pruning",
     "pytorch_lightning.loops.optimization.*",
     "pytorch_lightning.loops.evaluation_loop",
+    "pytorch_lightning.trainer.connectors.checkpoint_connector",
     "pytorch_lightning.trainer.connectors.logger_connector.*",
     "pytorch_lightning.trainer.progress",
     "pytorch_lightning.tuner.auto_gpu_select",
pytorch_lightning/plugins/training_type/deepspeed.py
@@ -35,7 +35,7 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.distributed import log, rank_zero_info, rank_zero_only
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
-from pytorch_lightning.utilities.types import LRSchedulerTypeTuple
+from pytorch_lightning.utilities.types import _PATH, LRSchedulerTypeTuple
 from pytorch_lightning.utilities.warnings import rank_zero_warn, WarningCache
 
 warning_cache = WarningCache()
@@ -668,7 +668,7 @@ class DeepSpeedPlugin(DDPPlugin):
     def _multi_device(self) -> bool:
         return self.num_processes > 1 or self.num_nodes > 1
 
-    def save_checkpoint(self, checkpoint: Dict, filepath: str) -> None:
+    def save_checkpoint(self, checkpoint: Dict, filepath: _PATH) -> None:
         """Save model/training states as a checkpoint file through state-dump and file-write.
 
         Args:
@@ -689,7 +689,7 @@ class DeepSpeedPlugin(DDPPlugin):
         checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys}
         self.deepspeed_engine.save_checkpoint(filepath, client_state=checkpoint)
 
-    def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Optional[Dict[str, Any]]:
+    def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
         if self.load_full_weights and self.zero_stage_3:
             # Broadcast to ensure we load from the rank 0 checkpoint
             # This doesn't have to be the case when using deepspeed sharded checkpointing
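The net effect of replacing `str` and `Union[str, Path]` with `_PATH` is a single shared alias for everything path-like. A minimal sketch of what the annotation buys, assuming `_PATH = Union[str, Path]` as in `pytorch_lightning.utilities.types` (the free function here is illustrative, not the plugin method):

    from pathlib import Path
    from typing import Any, Dict, Union

    _PATH = Union[str, Path]  # assumed definition of the alias

    def save_checkpoint(checkpoint: Dict[str, Any], filepath: _PATH) -> None:
        # both call styles below satisfy the one alias, so every plugin
        # signature stays consistent without repeating the Union
        print(f"would write {len(checkpoint)} keys to {filepath}")

    save_checkpoint({}, "model.ckpt")        # plain string path
    save_checkpoint({}, Path("model.ckpt"))  # pathlib.Path object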
pytorch_lightning/plugins/training_type/single_tpu.py
@@ -20,6 +20,7 @@ from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
 from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.types import _PATH
 
 if _TPU_AVAILABLE:
     import torch_xla.core.xla_model as xm
@@ -62,10 +63,10 @@ class SingleTPUPlugin(SingleDevicePlugin):
         self.tpu_local_core_rank = xm.get_local_ordinal()
         self.tpu_global_core_rank = xm.get_ordinal()
 
-    def save(self, state_dict: Dict, path: str) -> None:
+    def save(self, state_dict: Dict, path: _PATH) -> None:
         xm.save(state_dict, path)
 
-    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
+    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: _PATH) -> None:
         """Save model/training states as a checkpoint file through state-dump and file-write.
 
         Args:
pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -35,7 +35,7 @@ from pytorch_lightning.utilities.data import has_len
 from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.seed import reset_seed
-from pytorch_lightning.utilities.types import STEP_OUTPUT
+from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT
 
 if _TPU_AVAILABLE:
     import torch_xla.core.xla_env_vars as xenv
@@ -207,7 +207,7 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
         self.mp_queue.put(results)
         self.lightning_module.add_to_queue(self.mp_queue)  # adds the `callback_metrics` to the queue
 
-    def save(self, state_dict: Dict, path: str) -> None:
+    def save(self, state_dict: Dict, path: _PATH) -> None:
         xm.save(state_dict, path)
 
     def broadcast(self, obj: object, src: int = 0) -> object:
@@ -303,7 +303,7 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
         if self.tpu_global_core_rank == 0 and int(os.getenv(xenv.TPUVM_MODE, 0)) == 1:
             print()
 
-    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
+    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: _PATH) -> None:
         """Save model/training states as a checkpoint file through state-dump and file-write.
 
         Args:
pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 import contextlib
 from abc import ABC, abstractmethod
-from pathlib import Path
 from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, TypeVar, Union
 
 import torch
@@ -152,7 +151,7 @@ class TrainingTypePlugin(ABC):
         """
         return self._results
 
-    def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
+    def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
         torch.cuda.empty_cache()
         return self.checkpoint_io.load_checkpoint(checkpoint_path)
 
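With the base hook in `TrainingTypePlugin` now typed `(checkpoint_path: _PATH) -> Dict[str, Any]`, overrides such as `DeepSpeedPlugin.load_checkpoint` above drop the `Optional` return, and callers no longer need a `None` check. A minimal conforming override, with a hypothetical subclass name and the `_PATH` alias again assumed to be `Union[str, Path]`:

    from pathlib import Path
    from typing import Any, Dict, Union

    import torch

    _PATH = Union[str, Path]  # assumed alias definition

    class MyCheckpointPlugin:  # hypothetical stand-in for a plugin subclass
        def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
            # matches the narrowed base signature: always a dict, never None
            return torch.load(checkpoint_path, map_location="cpu")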
pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -14,13 +14,13 @@
 
 import os
 import re
-from pathlib import Path
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional
 
 import torch
 from torchmetrics import Metric
 
 import pytorch_lightning as pl
+from pytorch_lightning.loggers import LightningLoggerBase
 from pytorch_lightning.loops.fit_loop import FitLoop
 from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_deprecation, rank_zero_info, rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
@@ -34,10 +34,10 @@ if _OMEGACONF_AVAILABLE:
 
 
 class CheckpointConnector:
-    def __init__(self, trainer, resume_from_checkpoint: Optional[Union[str, Path]] = None):
+    def __init__(self, trainer: "pl.Trainer", resume_from_checkpoint: Optional[_PATH] = None) -> None:
         self.trainer = trainer
         self.resume_checkpoint_path = resume_from_checkpoint
-        self._loaded_checkpoint = {}
+        self._loaded_checkpoint: Dict[str, Any] = {}
 
     @property
     def hpc_resume_path(self) -> Optional[str]:
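The quoted `"pl.Trainer"` annotation is a forward reference: the string defers evaluation, so the connector can be annotated with the `Trainer` type even though importing it eagerly would be circular (the trainer module imports this connector). A generic sketch of the pattern, not this module's code:

    class Connector:
        def __init__(self, trainer: "Trainer") -> None:  # quoted: Trainer is defined later
            self.trainer = trainer

    class Trainer:
        def __init__(self) -> None:
            # at runtime the string annotation is never evaluated, so order is fine
            self.connector = Connector(self)

    t = Trainer()
    assert t.connector.trainer is t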
@@ -64,7 +64,7 @@ class CheckpointConnector:
         rank_zero_info(f"Restoring states from the checkpoint path at {checkpoint_path}")
         self._loaded_checkpoint = self._load_and_validate_checkpoint(checkpoint_path)
 
-    def _load_and_validate_checkpoint(self, checkpoint_path: _PATH) -> Optional[Dict[str, Any]]:
+    def _load_and_validate_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
         loaded_checkpoint = self.trainer.training_type_plugin.load_checkpoint(checkpoint_path)
         if any(key in loaded_checkpoint for key in DEPRECATED_CHECKPOINT_KEYS):
             raise ValueError(
@@ -89,7 +89,7 @@ class CheckpointConnector:
         # wait for all to catch up
         self.trainer.training_type_plugin.barrier("CheckpointConnector.resume_end")
 
-    def restore(self, checkpoint_path: Optional[Union[Path, str]] = None) -> None:
+    def restore(self, checkpoint_path: Optional[_PATH] = None) -> None:
         """Attempt to restore everything at once from a 'PyTorch-Lightning checkpoint' file through file-read and
         state-restore, in this priority:
 
@@ -152,7 +152,7 @@ class CheckpointConnector:
             if isinstance(module, Metric):
                 module.reset()
 
-    def restore_model_weights(self, checkpoint_path: Optional[Union[str, Path]]) -> None:
+    def restore_model_weights(self, checkpoint_path: Optional[_PATH]) -> None:
         """Restore only the model weights."""
         checkpoint = self._loaded_checkpoint
         if checkpoint_path is not None:
@@ -197,6 +197,7 @@ class CheckpointConnector:
         # crash if max_epochs is lower then the current epoch from the checkpoint
         if (
             FitLoop._is_max_limit_enabled(self.trainer.max_epochs)
+            and self.trainer.max_epochs is not None
             and self.trainer.current_epoch > self.trainer.max_epochs
         ):
             raise MisconfigurationException(
@@ -273,7 +274,7 @@ class CheckpointConnector:
     # PRIVATE OPS
     # ----------------------------------
 
-    def hpc_save(self, folderpath: str, logger):
+    def hpc_save(self, folderpath: str, logger: LightningLoggerBase) -> str:
         # make sure the checkpoint folder exists
         folderpath = str(folderpath)  # because the tests pass a path object
         fs = get_filesystem(folderpath)
@@ -387,7 +388,7 @@ class CheckpointConnector:
 
         return checkpoint
 
-    def hpc_load(self, checkpoint_path: str) -> None:
+    def hpc_load(self, checkpoint_path: _PATH) -> None:
         """Attempts to restore the full training and model state from a HPC checkpoint file.
 
         .. deprecated:: v1.4 Will be removed in v1.6. Use :meth:`restore` instead.
@@ -398,7 +399,7 @@ class CheckpointConnector:
         )
         self.restore(checkpoint_path)
 
-    def max_ckpt_version_in_folder(self, dir_path: Union[str, Path], name_key: str = "ckpt_") -> Optional[int]:
+    def max_ckpt_version_in_folder(self, dir_path: _PATH, name_key: str = "ckpt_") -> Optional[int]:
         """List up files in `dir_path` with `name_key`, then yield maximum suffix number.
 
         Args:
@@ -428,7 +429,7 @@ class CheckpointConnector:
 
         return max(ckpt_vs)
 
-    def get_max_ckpt_path_from_folder(self, folder_path: Union[str, Path]) -> str:
+    def get_max_ckpt_path_from_folder(self, folder_path: _PATH) -> str:
         """Get path of maximum-epoch checkpoint in the folder."""
 
         max_suffix = self.max_ckpt_version_in_folder(folder_path)
pytorch_lightning/trainer/properties.py
@@ -48,7 +48,7 @@ from pytorch_lightning.utilities.argparse import (
 )
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.model_helpers import is_overridden
-from pytorch_lightning.utilities.types import _PATH
+from pytorch_lightning.utilities.types import _PATH, LRSchedulerTypeUnion
 
 
 class TrainerProperties(ABC):
@@ -182,11 +182,11 @@ class TrainerProperties(ABC):
         self.accelerator.optimizers = new_optims
 
     @property
-    def lr_schedulers(self) -> Optional[list]:
+    def lr_schedulers(self) -> List[LRSchedulerTypeUnion]:
        return self.accelerator.lr_schedulers
 
     @lr_schedulers.setter
-    def lr_schedulers(self, new_schedulers: Optional[list]) -> None:
+    def lr_schedulers(self, new_schedulers: List[LRSchedulerTypeUnion]) -> None:
         self.accelerator.lr_schedulers = new_schedulers
 
     @property
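The property and its setter now advertise exactly what the accelerator stores instead of the vague `Optional[list]`. A self-contained sketch of a value that satisfies the new annotation, using plain torch objects (no Lightning imports needed):

    import torch
    from typing import List, Union
    from torch.optim import SGD
    from torch.optim.lr_scheduler import ReduceLROnPlateau, _LRScheduler

    # as introduced in utilities/types.py below
    LRSchedulerTypeUnion = Union[_LRScheduler, ReduceLROnPlateau]

    params = [torch.zeros(1, requires_grad=True)]
    schedulers: List[LRSchedulerTypeUnion] = [ReduceLROnPlateau(SGD(params, lr=0.1))]
    # a setter typed List[LRSchedulerTypeUnion] accepts this list as-is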
pytorch_lightning/utilities/types.py
@@ -43,5 +43,7 @@ TRAIN_DATALOADERS = Union[
     Dict[str, Sequence[DataLoader]],
 ]
 EVAL_DATALOADERS = Union[DataLoader, Sequence[DataLoader]]
+# todo: improve LRSchedulerType naming/typing
 LRSchedulerTypeTuple = (_LRScheduler, ReduceLROnPlateau)
+LRSchedulerTypeUnion = Union[_LRScheduler, ReduceLROnPlateau]
 LRSchedulerType = Union[Type[_LRScheduler], Type[ReduceLROnPlateau]]
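Two parallel spellings are kept deliberately: `isinstance()` needs a runtime tuple of classes, while annotations need a `typing` construct, and `ReduceLROnPlateau` does not subclass `_LRScheduler`, so neither form can be derived from a single base class. A brief sketch of how the two work together:

    import torch
    from typing import Union
    from torch.optim import SGD
    from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR, _LRScheduler

    LRSchedulerTypeTuple = (_LRScheduler, ReduceLROnPlateau)       # runtime isinstance() checks
    LRSchedulerTypeUnion = Union[_LRScheduler, ReduceLROnPlateau]  # static annotations

    def is_scheduler(obj: object) -> bool:
        return isinstance(obj, LRSchedulerTypeTuple)

    opt = SGD([torch.zeros(1, requires_grad=True)], lr=0.1)
    assert is_scheduler(StepLR(opt, step_size=1))  # an _LRScheduler subclass
    assert is_scheduler(ReduceLROnPlateau(opt))    # covered only by the explicit pair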