diff --git a/pyproject.toml b/pyproject.toml
index b16992880d..d205af7f0a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -65,6 +65,7 @@ module = [
     "pytorch_lightning.callbacks.pruning",
     "pytorch_lightning.loops.optimization.*",
     "pytorch_lightning.loops.evaluation_loop",
+    "pytorch_lightning.trainer.connectors.checkpoint_connector",
     "pytorch_lightning.trainer.connectors.logger_connector.*",
     "pytorch_lightning.trainer.progress",
     "pytorch_lightning.tuner.auto_gpu_select",
diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py
index 3088efb9fe..77af58843c 100644
--- a/pytorch_lightning/plugins/training_type/deepspeed.py
+++ b/pytorch_lightning/plugins/training_type/deepspeed.py
@@ -35,7 +35,7 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.distributed import log, rank_zero_info, rank_zero_only
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
-from pytorch_lightning.utilities.types import LRSchedulerTypeTuple
+from pytorch_lightning.utilities.types import _PATH, LRSchedulerTypeTuple
 from pytorch_lightning.utilities.warnings import rank_zero_warn, WarningCache
 
 warning_cache = WarningCache()
@@ -668,7 +668,7 @@ class DeepSpeedPlugin(DDPPlugin):
     def _multi_device(self) -> bool:
         return self.num_processes > 1 or self.num_nodes > 1
 
-    def save_checkpoint(self, checkpoint: Dict, filepath: str) -> None:
+    def save_checkpoint(self, checkpoint: Dict, filepath: _PATH) -> None:
         """Save model/training states as a checkpoint file through state-dump and file-write.
 
         Args:
@@ -689,7 +689,7 @@ class DeepSpeedPlugin(DDPPlugin):
         checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys}
         self.deepspeed_engine.save_checkpoint(filepath, client_state=checkpoint)
 
-    def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Optional[Dict[str, Any]]:
+    def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
         if self.load_full_weights and self.zero_stage_3:
             # Broadcast to ensure we load from the rank 0 checkpoint
             # This doesn't have to be the case when using deepspeed sharded checkpointing
diff --git a/pytorch_lightning/plugins/training_type/single_tpu.py b/pytorch_lightning/plugins/training_type/single_tpu.py
index b6f7d4000d..638cc089f9 100644
--- a/pytorch_lightning/plugins/training_type/single_tpu.py
+++ b/pytorch_lightning/plugins/training_type/single_tpu.py
@@ -20,6 +20,7 @@ from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
 from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.types import _PATH
 
 if _TPU_AVAILABLE:
     import torch_xla.core.xla_model as xm
@@ -62,10 +63,10 @@
         self.tpu_local_core_rank = xm.get_local_ordinal()
         self.tpu_global_core_rank = xm.get_ordinal()
 
-    def save(self, state_dict: Dict, path: str) -> None:
+    def save(self, state_dict: Dict, path: _PATH) -> None:
         xm.save(state_dict, path)
 
-    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
+    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: _PATH) -> None:
         """Save model/training states as a checkpoint file through state-dump and file-write.
 
         Args:
diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index 9140a995c8..43831fa2ac 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -35,7 +35,7 @@ from pytorch_lightning.utilities.data import has_len
 from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.seed import reset_seed
-from pytorch_lightning.utilities.types import STEP_OUTPUT
+from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT
 
 if _TPU_AVAILABLE:
     import torch_xla.core.xla_env_vars as xenv
@@ -207,7 +207,7 @@
             self.mp_queue.put(results)
         self.lightning_module.add_to_queue(self.mp_queue)  # adds the `callback_metrics` to the queue
 
-    def save(self, state_dict: Dict, path: str) -> None:
+    def save(self, state_dict: Dict, path: _PATH) -> None:
         xm.save(state_dict, path)
 
     def broadcast(self, obj: object, src: int = 0) -> object:
@@ -303,7 +303,7 @@
         if self.tpu_global_core_rank == 0 and int(os.getenv(xenv.TPUVM_MODE, 0)) == 1:
             print()
 
-    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
+    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: _PATH) -> None:
         """Save model/training states as a checkpoint file through state-dump and file-write.
 
         Args:
diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py
index 39b47abfc6..13d6f93f5f 100644
--- a/pytorch_lightning/plugins/training_type/training_type_plugin.py
+++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 import contextlib
 from abc import ABC, abstractmethod
-from pathlib import Path
 from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, TypeVar, Union
 
 import torch
@@ -152,7 +151,7 @@ class TrainingTypePlugin(ABC):
         """
         return self._results
 
-    def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
+    def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
         torch.cuda.empty_cache()
         return self.checkpoint_io.load_checkpoint(checkpoint_path)
 
diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index b48081481c..5376a81b65 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -14,13 +14,13 @@
 import os
 import re
-from pathlib import Path
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional
 
 import torch
 from torchmetrics import Metric
 
 import pytorch_lightning as pl
+from pytorch_lightning.loggers import LightningLoggerBase
 from pytorch_lightning.loops.fit_loop import FitLoop
 from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_deprecation, rank_zero_info, rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
@@ -34,10 +34,10 @@ if _OMEGACONF_AVAILABLE:
 
 
 class CheckpointConnector:
-    def __init__(self, trainer, resume_from_checkpoint: Optional[Union[str, Path]] = None):
+    def __init__(self, trainer: "pl.Trainer", resume_from_checkpoint: Optional[_PATH] = None) -> None:
         self.trainer = trainer
         self.resume_checkpoint_path = resume_from_checkpoint
-        self._loaded_checkpoint = {}
+        self._loaded_checkpoint: Dict[str, Any] = {}
 
     @property
     def hpc_resume_path(self) -> Optional[str]:
@@ -64,7 +64,7 @@
         rank_zero_info(f"Restoring states from the checkpoint path at {checkpoint_path}")
         self._loaded_checkpoint = self._load_and_validate_checkpoint(checkpoint_path)
 
-    def _load_and_validate_checkpoint(self, checkpoint_path: _PATH) -> Optional[Dict[str, Any]]:
+    def _load_and_validate_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
         loaded_checkpoint = self.trainer.training_type_plugin.load_checkpoint(checkpoint_path)
         if any(key in loaded_checkpoint for key in DEPRECATED_CHECKPOINT_KEYS):
             raise ValueError(
@@ -89,7 +89,7 @@
         # wait for all to catch up
         self.trainer.training_type_plugin.barrier("CheckpointConnector.resume_end")
 
-    def restore(self, checkpoint_path: Optional[Union[Path, str]] = None) -> None:
+    def restore(self, checkpoint_path: Optional[_PATH] = None) -> None:
         """Attempt to restore everything at once from a 'PyTorch-Lightning checkpoint' file through file-read and
         state-restore, in this priority:
 
@@ -152,7 +152,7 @@
                 if isinstance(module, Metric):
                     module.reset()
 
-    def restore_model_weights(self, checkpoint_path: Optional[Union[str, Path]]) -> None:
+    def restore_model_weights(self, checkpoint_path: Optional[_PATH]) -> None:
         """Restore only the model weights."""
         checkpoint = self._loaded_checkpoint
         if checkpoint_path is not None:
@@ -197,6 +197,7 @@
         # crash if max_epochs is lower then the current epoch from the checkpoint
         if (
             FitLoop._is_max_limit_enabled(self.trainer.max_epochs)
+            and self.trainer.max_epochs is not None
             and self.trainer.current_epoch > self.trainer.max_epochs
         ):
             raise MisconfigurationException(
@@ -273,7 +274,7 @@
     # PRIVATE OPS
     # ----------------------------------
 
-    def hpc_save(self, folderpath: str, logger):
+    def hpc_save(self, folderpath: str, logger: LightningLoggerBase) -> str:
         # make sure the checkpoint folder exists
         folderpath = str(folderpath)  # because the tests pass a path object
         fs = get_filesystem(folderpath)
@@ -387,7 +388,7 @@
         return checkpoint
 
-    def hpc_load(self, checkpoint_path: str) -> None:
+    def hpc_load(self, checkpoint_path: _PATH) -> None:
         """Attempts to restore the full training and model state from a HPC checkpoint file.
 
         .. deprecated:: v1.4
             Will be removed in v1.6. Use :meth:`restore` instead.
@@ -398,7 +399,7 @@
         )
         self.restore(checkpoint_path)
 
-    def max_ckpt_version_in_folder(self, dir_path: Union[str, Path], name_key: str = "ckpt_") -> Optional[int]:
+    def max_ckpt_version_in_folder(self, dir_path: _PATH, name_key: str = "ckpt_") -> Optional[int]:
         """List up files in `dir_path` with `name_key`, then yield maximum suffix number.
 
         Args:
@@ -428,7 +429,7 @@
 
         return max(ckpt_vs)
 
-    def get_max_ckpt_path_from_folder(self, folder_path: Union[str, Path]) -> str:
+    def get_max_ckpt_path_from_folder(self, folder_path: _PATH) -> str:
         """Get path of maximum-epoch checkpoint in the folder."""
 
         max_suffix = self.max_ckpt_version_in_folder(folder_path)
diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py
index e33eeb8b7c..2fca2b5a28 100644
--- a/pytorch_lightning/trainer/properties.py
+++ b/pytorch_lightning/trainer/properties.py
@@ -48,7 +48,7 @@ from pytorch_lightning.utilities.argparse import (
 )
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.model_helpers import is_overridden
-from pytorch_lightning.utilities.types import _PATH
+from pytorch_lightning.utilities.types import _PATH, LRSchedulerTypeUnion
 
 
 class TrainerProperties(ABC):
@@ -182,11 +182,11 @@
         self.accelerator.optimizers = new_optims
 
     @property
-    def lr_schedulers(self) -> Optional[list]:
+    def lr_schedulers(self) -> List[LRSchedulerTypeUnion]:
         return self.accelerator.lr_schedulers
 
     @lr_schedulers.setter
-    def lr_schedulers(self, new_schedulers: Optional[list]) -> None:
+    def lr_schedulers(self, new_schedulers: List[LRSchedulerTypeUnion]) -> None:
         self.accelerator.lr_schedulers = new_schedulers
 
     @property
diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py
index 7cb5577346..86c9028a5f 100644
--- a/pytorch_lightning/utilities/types.py
+++ b/pytorch_lightning/utilities/types.py
@@ -43,5 +43,7 @@ TRAIN_DATALOADERS = Union[
     Dict[str, Sequence[DataLoader]],
 ]
 EVAL_DATALOADERS = Union[DataLoader, Sequence[DataLoader]]
+# todo: improve LRSchedulerType naming/typing
 LRSchedulerTypeTuple = (_LRScheduler, ReduceLROnPlateau)
+LRSchedulerTypeUnion = Union[_LRScheduler, ReduceLROnPlateau]
 LRSchedulerType = Union[Type[_LRScheduler], Type[ReduceLROnPlateau]]
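
Note on the recurring change above: the patch converges every path-like
parameter on the `_PATH` alias from `pytorch_lightning.utilities.types`. The
alias definition itself is not shown in these hunks; the sketch below is a
minimal illustration, assuming the usual `_PATH = Union[str, Path]`, and uses
a hypothetical `save_checkpoint` stand-in rather than the plugin
implementation:

    from pathlib import Path
    from typing import Any, Dict, Union

    import torch

    # Assumed definition, mirroring pytorch_lightning/utilities/types.py.
    _PATH = Union[str, Path]

    def save_checkpoint(checkpoint: Dict[str, Any], filepath: _PATH) -> None:
        # torch.save accepts both str and os.PathLike, so the widened
        # annotation requires no casts in the body or at call sites.
        torch.save(checkpoint, filepath)

    # Both spellings now type-check against the same signature:
    save_checkpoint({"state_dict": {}}, "model.ckpt")
    save_checkpoint({"state_dict": {}}, Path("model.ckpt"))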