diff --git a/CHANGELOG.md b/CHANGELOG.md
index ce7d562d35..2462ffbdbb 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -205,6 +205,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Executing the `optimizer_closure` is now required when overriding the `optimizer_step` hook ([#9360](https://github.com/PyTorchLightning/pytorch-lightning/pull/9360))
 
+- Removed `TrainerProperties` mixin and moved property definitions directly into `Trainer` ([#9495](https://github.com/PyTorchLightning/pytorch-lightning/pull/9495))
+
+
 - Changed logging of `LightningModule` and `LightningDataModule` hyperparameters to raise an exception only if there are colliding keys with different values ([#9496](https://github.com/PyTorchLightning/pytorch-lightning/pull/9496))
 
 
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 4846b7c117..c5516fb4d4 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -12,27 +12,33 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Trainer to automate the training."""
+import inspect
 import logging
 import os
 import traceback
 import warnings
+from argparse import ArgumentParser, Namespace
 from datetime import timedelta
 from pathlib import Path
-from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Callable, cast, Dict, Iterable, List, Optional, Tuple, Union
 from weakref import proxy
 
 import torch
+from torch.optim import Optimizer
 
 import pytorch_lightning as pl
 from pytorch_lightning.accelerators import Accelerator, IPUAccelerator
-from pytorch_lightning.callbacks import Callback
+from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint, ProgressBarBase
+from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
 from pytorch_lightning.core.datamodule import LightningDataModule
+from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.loggers import LightningLoggerBase
-from pytorch_lightning.loops import TrainingBatchLoop, TrainingEpochLoop
+from pytorch_lightning.loggers.base import LoggerCollection
+from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
+from pytorch_lightning.loops import PredictionLoop, TrainingBatchLoop, TrainingEpochLoop
 from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop
-from pytorch_lightning.loops.dataloader.prediction_loop import PredictionLoop
 from pytorch_lightning.loops.fit_loop import FitLoop
-from pytorch_lightning.plugins import DDPSpawnPlugin, PLUGIN_INPUT
+from pytorch_lightning.plugins import DDPSpawnPlugin, ParallelPlugin, PLUGIN_INPUT, PrecisionPlugin, TrainingTypePlugin
 from pytorch_lightning.profiler import (
     AdvancedProfiler,
     BaseProfiler,
@@ -50,6 +56,7 @@ from pytorch_lightning.trainer.connectors.data_connector import DataConnector
 from pytorch_lightning.trainer.connectors.debugging_connector import DebuggingConnector
 from pytorch_lightning.trainer.connectors.env_vars_connector import _defaults_from_env_vars
 from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
+from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
 from pytorch_lightning.trainer.connectors.model_connector import ModelConnector
 from pytorch_lightning.trainer.connectors.optimizer_connector import OptimizerConnector
 from pytorch_lightning.trainer.connectors.signal_connector import SignalConnector
@@ -58,8 +65,7 @@ from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
 from pytorch_lightning.trainer.deprecated_api import DeprecatedTrainerAttributes
 from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
 from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
-from pytorch_lightning.trainer.properties import TrainerProperties
-from pytorch_lightning.trainer.states import TrainerFn, TrainerState, TrainerStatus
+from pytorch_lightning.trainer.states import RunningStage, TrainerFn, TrainerState, TrainerStatus
 from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus
 from pytorch_lightning.tuner.lr_finder import _LRFinder
 from pytorch_lightning.tuner.tuning import Tuner
@@ -68,18 +74,33 @@ from pytorch_lightning.utilities import (
     _TPU_AVAILABLE,
     device_parser,
     DeviceType,
+    DistributedType,
     parsing,
     rank_zero_deprecation,
     rank_zero_info,
     rank_zero_warn,
 )
+from pytorch_lightning.utilities.argparse import (
+    add_argparse_args,
+    from_argparse_args,
+    parse_argparser,
+    parse_env_variables,
+)
+from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.debugging import InternalDebugger
 from pytorch_lightning.utilities.distributed import distributed_available
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _fault_tolerant_training
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.seed import reset_seed
-from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT, EVAL_DATALOADERS, TRAIN_DATALOADERS
+from pytorch_lightning.utilities.types import (
+    _EVALUATE_OUTPUT,
+    _PATH,
+    _PREDICT_OUTPUT,
+    EVAL_DATALOADERS,
+    LRSchedulerTypeUnion,
+    TRAIN_DATALOADERS,
+)
 
 log = logging.getLogger(__name__)
 # warnings to ignore in trainer
@@ -89,13 +110,15 @@ warnings.filterwarnings(
 
 
 class Trainer(
-    TrainerProperties,
     TrainerCallbackHookMixin,
     TrainerModelHooksMixin,
     TrainerOptimizersMixin,
     TrainerDataLoadingMixin,
     DeprecatedTrainerAttributes,
 ):
+    # Needed because of LightningOptimizer
+    _lightning_optimizers = None
+
     @_defaults_from_env_vars
     def __init__(
         self,
@@ -1393,3 +1416,579 @@ class Trainer(
         # save a checkpoint for fault tolerant training. we don't use `log_dir` to minimize the chances of failure.
         file_path = os.path.join(self.default_root_dir, ".pl_auto_save.ckpt")
         self.save_checkpoint(file_path)
+
+    """
+    Accelerator properties
+    """
+
+    @property
+    def accelerator(self) -> Accelerator:
+        return self.accelerator_connector.accelerator
+
+    @property
+    def distributed_backend(self) -> Optional[str]:
+        # for backward compatibility
+        return self.accelerator_connector.distributed_backend
+
+    @property
+    def training_type_plugin(self) -> TrainingTypePlugin:
+        return self.accelerator.training_type_plugin
+
+    @property
+    def precision_plugin(self) -> PrecisionPlugin:
+        return self.accelerator.precision_plugin
+
+    @property
+    def global_rank(self) -> int:
+        return self.accelerator.training_type_plugin.global_rank
+
+    @property
+    def local_rank(self) -> int:
+        # some training types define a local rank
+        return getattr(self.accelerator.training_type_plugin, "local_rank", 0)
+
+    @property
+    def node_rank(self) -> int:
+        # some training types define a local rank
+        return getattr(self.accelerator.training_type_plugin, "node_rank", 0)
+
+    @property
+    def world_size(self) -> int:
+        # some training types define a world size
+        return getattr(self.accelerator.training_type_plugin, "world_size", 1)
+
+    @property
+    def should_rank_save_checkpoint(self) -> bool:
+        return self.accelerator.training_type_plugin.should_rank_save_checkpoint
+
+    @property
+    def _distrib_type(self) -> DistributedType:
+        return self.accelerator_connector._distrib_type
+
+    @property
+    def _device_type(self) -> DeviceType:
+        return self.accelerator_connector._device_type
+
+    @property
+    def num_nodes(self) -> int:
+        return self.accelerator_connector.num_nodes
+
+    @property
+    def num_processes(self) -> int:
+        return self.accelerator_connector.num_processes
+
+    @property
+    def root_gpu(self) -> Optional[int]:
+        return self.accelerator_connector.root_gpu
+
+    @property
+    def tpu_cores(self) -> int:
+        return self.accelerator_connector.tpu_cores
+
+    @property
+    def ipus(self) -> int:
+        return self.accelerator_connector.num_ipus
+
+    @property
+    def num_gpus(self) -> int:
+        return self.accelerator_connector.num_gpus
+
+    @property
+    def devices(self) -> Optional[Union[List[int], str, int]]:
+        return self.accelerator_connector.devices
+
+    @property
+    def data_parallel_device_ids(self) -> Optional[List[int]]:
+        return self.accelerator_connector.parallel_device_ids
+
+    @property
+    def lightning_module(self) -> "pl.LightningModule":
+        return self.accelerator.lightning_module
+
+    @property
+    def optimizers(self) -> List[Optimizer]:
+        return self.accelerator.optimizers
+
+    @optimizers.setter
+    def optimizers(self, new_optims: Optional[List[Optimizer]]) -> None:
+        # Necessary to rewrap optimizers to lightning
+        # They will be re-created when accessing
+        # the `lightning_optimizers` trainer property
+        self._lightning_optimizers = None
+
+        self.accelerator.optimizers = new_optims
+
+    @property
+    def lr_schedulers(self) -> List[LRSchedulerTypeUnion]:
+        return self.accelerator.lr_schedulers
+
+    @lr_schedulers.setter
+    def lr_schedulers(self, new_schedulers: List[LRSchedulerTypeUnion]) -> None:
+        self.accelerator.lr_schedulers = new_schedulers
+
+    @property
+    def optimizer_frequencies(self) -> list:
+        return self.accelerator.optimizer_frequencies
+
+    @optimizer_frequencies.setter
+    def optimizer_frequencies(self, new_freqs: list) -> None:
+        self.accelerator.optimizer_frequencies = new_freqs
+
+    @property
+    def amp_backend(self) -> Optional[str]:
+        return self.accelerator.amp_backend
+
+    @property
+    def precision(self) -> Union[str, int]:
+        return self.accelerator.precision
+
+    @property
+    def scaler(self):
+        return self.accelerator.scaler
+
+    @property
+    def gpus(self) -> Optional[Union[List[int], str, int]]:
+        return self.accelerator_connector.gpus
+
+    @property
+    def model(self) -> torch.nn.Module:
+        """The LightningModule, but possibly wrapped into DataParallel or DistributedDataParallel.
+
+        To access the pure LightningModule, use
+        :meth:`~pytorch_lightning.trainer.trainer.Trainer.lightning_module` instead.
+        """
+        return self.accelerator.model
+
+    @model.setter
+    def model(self, model: torch.nn.Module) -> None:
+        """Setter for the model, pass-through to accelerator and plugin where the model reference is stored. Used
+        by the Tuner to reset the state of Trainer and Accelerator.
+
+        Args:
+            model: The LightningModule, possibly wrapped into DataParallel or DistributedDataParallel, depending
+                on the backend.
+        """
+        self.accelerator.model = model
+
+    """
+    General properties
+    """
+
+    @property
+    def log_dir(self) -> Optional[str]:
+        if self.logger is None:
+            dirpath = self.default_root_dir
+        elif isinstance(self.logger, TensorBoardLogger):
+            dirpath = self.logger.log_dir
+        elif isinstance(self.logger, LoggerCollection):
+            dirpath = self.default_root_dir
+        else:
+            dirpath = self.logger.save_dir
+
+        dirpath = self.accelerator.broadcast(dirpath)
+        return dirpath
+
+    @property
+    def use_amp(self) -> bool:
+        return self.precision == 16
+
+    @property
+    def is_global_zero(self) -> bool:
+        return self.global_rank == 0
+
+    @property
+    def slurm_job_id(self) -> Optional[int]:
+        job_id = os.environ.get("SLURM_JOB_ID")
+        if job_id:
+            try:
+                job_id = int(job_id)
+            except ValueError:
+                job_id = None
+
+            # in interactive mode, don't make logs use the same job id
+            in_slurm_interactive_mode = os.environ.get("SLURM_JOB_NAME") == "bash"
+            if in_slurm_interactive_mode:
+                job_id = None
+        return job_id
+
+    @property
+    def lightning_optimizers(self) -> List[LightningOptimizer]:
+        if self._lightning_optimizers is None:
+            self.convert_to_lightning_optimizers()
+        return self._lightning_optimizers
+
+    @property
+    def distributed_sampler_kwargs(self) -> Optional[dict]:
+        if isinstance(self.training_type_plugin, ParallelPlugin):
+            return self.training_type_plugin.distributed_sampler_kwargs
+
+    @property
+    def data_parallel(self) -> bool:
+        return self._distrib_type in (
+            DistributedType.DP,
+            DistributedType.DDP,
+            DistributedType.DDP_SPAWN,
+            DistributedType.DDP2,
+        )
+
+    @property
+    def progress_bar_callback(self) -> Optional[ProgressBarBase]:
+        return self._progress_bar_callback
+
+    @property
+    def progress_bar_dict(self) -> dict:
+        """Read-only for progress bar metrics."""
+        rank_zero_deprecation(
+            "`trainer.progress_bar_dict` is deprecated in v1.5 and will be removed in v1.7."
+            " Use `ProgressBarBase.get_metrics` instead."
+        )
+        ref_model = self.lightning_module
+        ref_model = cast(pl.LightningModule, ref_model)
+        if self.progress_bar_callback:
+            return self.progress_bar_callback.get_metrics(self, ref_model)
+        return self.progress_bar_metrics
+
+    @property
+    def _should_reload_dl_epoch(self) -> bool:
+        """Check if dataloader should be reloaded in the current epoch."""
+        n_epochs = self.reload_dataloaders_every_n_epochs
+        return n_epochs and (not self.current_epoch % n_epochs)
+
+    @property
+    def disable_validation(self) -> bool:
+        """Check if validation is disabled during training."""
+        rank_zero_deprecation(
+            "`trainer.disable_validation` is deprecated in v1.4 and will be removed in v1.6."
+            " Use `not trainer.enable_validation` instead."
+        )
+        return not self.enable_validation
+
+    @property
+    def enable_validation(self) -> bool:
+        """Check if we should run validation during training."""
+        model_ref = self.lightning_module
+        val_loop_enabled = is_overridden("validation_step", model_ref) and self.limit_val_batches > 0
+        return val_loop_enabled
+
+    @property
+    def default_root_dir(self) -> str:
+        """The default location to save artifacts of loggers, checkpoints etc.
+
+        It is used as a fallback if logger or checkpoint callback do not define specific save paths.
+        """
+        if get_filesystem(self._default_root_dir).protocol == "file":
+            return os.path.normpath(self._default_root_dir)
+        return self._default_root_dir
+
+    @property
+    def weights_save_path(self) -> str:
+        """
+        The default root location to save weights (checkpoints), e.g., when the
+        :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` does not define a file path.
+        """
+        if get_filesystem(self._weights_save_path).protocol == "file":
+            return os.path.normpath(self._weights_save_path)
+        return self._weights_save_path
+
+    @property
+    def early_stopping_callback(self) -> Optional[EarlyStopping]:
+        """The first :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping` callback in the
+        Trainer.callbacks list, or ``None`` if it doesn't exist."""
+        callbacks = self.early_stopping_callbacks
+        return callbacks[0] if len(callbacks) > 0 else None
+
+    @property
+    def early_stopping_callbacks(self) -> List[EarlyStopping]:
+        """A list of all instances of :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping` found in
+        the Trainer.callbacks list."""
+        return [c for c in self.callbacks if isinstance(c, EarlyStopping)]
+
+    @property
+    def prediction_writer_callbacks(self) -> List[BasePredictionWriter]:
+        """A list of all instances of :class:`~pytorch_lightning.callbacks.prediction_writer.BasePredictionWriter`
+        found in the Trainer.callbacks list."""
+        return [cb for cb in self.callbacks if isinstance(cb, BasePredictionWriter)]
+
+    @property
+    def checkpoint_callback(self) -> Optional[ModelCheckpoint]:
+        """The first :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callback in the
+        Trainer.callbacks list, or ``None`` if it doesn't exist."""
+        callbacks = self.checkpoint_callbacks
+        return callbacks[0] if len(callbacks) > 0 else None
+
+    @property
+    def checkpoint_callbacks(self) -> List[ModelCheckpoint]:
+        """A list of all instances of :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` found
+        in the Trainer.callbacks list."""
+        return [c for c in self.callbacks if isinstance(c, ModelCheckpoint)]
+
+    @property
+    def resume_from_checkpoint(self) -> Optional[Union[str, Path]]:
+        return self.checkpoint_connector.resume_checkpoint_path
+
+    def save_checkpoint(self, filepath: _PATH, weights_only: bool = False) -> None:
+        self.checkpoint_connector.save_checkpoint(filepath, weights_only)
+
+    """
+    Parsing properties
+    """
+
+    @classmethod
+    def default_attributes(cls) -> dict:
+        init_signature = inspect.signature(cls)
+        return {k: v.default for k, v in init_signature.parameters.items()}
+
+    @classmethod
+    def get_deprecated_arg_names(cls) -> List:
+        """Returns a list with deprecated Trainer arguments."""
+        depr_arg_names = []
+        for name, val in cls.__dict__.items():
+            if name.startswith("DEPRECATED") and isinstance(val, (tuple, list)):
+                depr_arg_names.extend(val)
+        return depr_arg_names
+
+    @classmethod
+    def from_argparse_args(cls: Any, args: Union[Namespace, ArgumentParser], **kwargs) -> Any:
+        return from_argparse_args(cls, args, **kwargs)
+
+    @classmethod
+    def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
+        return parse_argparser(cls, arg_parser)
+
+    @classmethod
+    def match_env_arguments(cls) -> Namespace:
+        return parse_env_variables(cls)
+
+    @classmethod
+    def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentParser:
+        return add_argparse_args(cls, parent_parser, **kwargs)
+
+    """
+    State properties
+    """
+
+    @property
+    def interrupted(self) -> bool:
+        return self.state.status == TrainerStatus.INTERRUPTED
+
+    @property
+    def training(self) -> bool:
+        return self.state.stage == RunningStage.TRAINING
+
+    @training.setter
+    def training(self, val: bool) -> None:
+        if val:
+            self.state.stage = RunningStage.TRAINING
+        elif self.training:
+            self.state.stage = None
+
+    @property
+    def testing(self) -> bool:
+        return self.state.stage == RunningStage.TESTING
+
+    @testing.setter
+    def testing(self, val: bool) -> None:
+        if val:
+            self.state.stage = RunningStage.TESTING
+        elif self.testing:
+            self.state.stage = None
+
+    @property
+    def predicting(self) -> bool:
+        return self.state.stage == RunningStage.PREDICTING
+
+    @predicting.setter
+    def predicting(self, val: bool) -> None:
+        if val:
+            self.state.stage = RunningStage.PREDICTING
+        elif self.predicting:
+            self.state.stage = None
+
+    @property
+    def tuning(self) -> bool:
+        return self.state.stage == RunningStage.TUNING
+
+    @tuning.setter
+    def tuning(self, val: bool) -> None:
+        if val:
+            self.state.stage = RunningStage.TUNING
+        elif self.tuning:
+            self.state.stage = None
+
+    @property
+    def validating(self) -> bool:
+        return self.state.stage == RunningStage.VALIDATING
+
+    @validating.setter
+    def validating(self, val: bool) -> None:
+        if val:
+            self.state.stage = RunningStage.VALIDATING
+        elif self.validating:
+            self.state.stage = None
+
+    @property
+    def evaluating(self) -> bool:
+        return self.state.stage and self.state.stage.evaluating
+
+    @property
+    def sanity_checking(self) -> bool:
+        return self.state.stage == RunningStage.SANITY_CHECKING
+
+    @sanity_checking.setter
+    def sanity_checking(self, val: bool) -> None:
+        if val:
+            self.state.stage = RunningStage.SANITY_CHECKING
+        elif self.sanity_checking:
+            self.state.stage = None
+
+    """
+    Loop properties
+    """
+
+    @property
+    def global_step(self) -> int:
+        return self.fit_loop.global_step
+
+    @property
+    def current_epoch(self) -> int:
+        return self.fit_loop.current_epoch
+
+    @property
+    def max_epochs(self) -> Optional[int]:
+        return self.fit_loop.max_epochs
+
+    @property
+    def min_epochs(self) -> Optional[int]:
+        return self.fit_loop.min_epochs
+
+    @property
+    def max_steps(self) -> Optional[int]:
+        return self.fit_loop.max_steps
+
+    @property
+    def min_steps(self) -> Optional[int]:
+        return self.fit_loop.min_steps
+
+    @property
+    def is_last_batch(self) -> bool:
+        return self.fit_loop.epoch_loop.is_last_batch
+
+    @property
+    def fit_loop(self) -> FitLoop:
+        return self._fit_loop
+
+    @fit_loop.setter
+    def fit_loop(self, loop: FitLoop):
+        """Attach a custom fit loop to this Trainer.
+
+        It will run with
+        :meth:`~pytorch_lighting.trainer.trainer.Trainer.fit`.
+        """
+        loop.trainer = self
+        self._fit_loop = loop
+
+    @property
+    def validate_loop(self) -> EvaluationLoop:
+        return self._validate_loop
+
+    @validate_loop.setter
+    def validate_loop(self, loop: EvaluationLoop):
+        """Attach a custom validation loop to this Trainer.
+
+        It will run with
+        :meth:`~pytorch_lighting.trainer.trainer.Trainer.validate`. Note that this loop is different from the one
+        running during training inside the :meth:`pytorch_lightning.trainer.trainer.Trainer.fit` call.
+        """
+        loop.trainer = self
+        self._validate_loop = loop
+
+    @property
+    def test_loop(self) -> EvaluationLoop:
+        return self._test_loop
+
+    @test_loop.setter
+    def test_loop(self, loop: EvaluationLoop):
+        """Attach a custom test loop to this Trainer.
+
+        It will run with
+        :meth:`~pytorch_lightning.trainer.trainer.Trainer.test`.
+        """
+        loop.trainer = self
+        self._test_loop = loop
+
+    @property
+    def predict_loop(self) -> PredictionLoop:
+        return self._predict_loop
+
+    @predict_loop.setter
+    def predict_loop(self, loop: PredictionLoop):
+        """Attach a custom prediction loop to this Trainer.
+
+        It will run with
+        :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`.
+        """
+        loop.trainer = self
+        self._predict_loop = loop
+
+    @property
+    def _evaluation_loop(self) -> EvaluationLoop:
+        if self.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING):
+            return self.fit_loop.epoch_loop.val_loop
+        if self.state.fn == TrainerFn.VALIDATING:
+            return self.validate_loop
+        if self.state.fn == TrainerFn.TESTING:
+            return self.test_loop
+        raise RuntimeError("The `Trainer._evaluation_loop` property isn't defined. Accessed outside of scope")
+
+    @property
+    def _active_loop(self) -> Optional[Union[FitLoop, EvaluationLoop, PredictionLoop]]:
+        if self.training:
+            return self.fit_loop
+        if self.sanity_checking or self.evaluating:
+            return self._evaluation_loop
+        if self.predicting:
+            return self.predict_loop
+
+    @property
+    def _ckpt_path(self) -> Optional[str]:
+        if self.state.fn == TrainerFn.VALIDATING:
+            return self.validated_ckpt_path
+        if self.state.fn == TrainerFn.TESTING:
+            return self.tested_ckpt_path
+        if self.state.fn == TrainerFn.PREDICTING:
+            return self.predicted_ckpt_path
+
+    """
+    Logging properties
+    """
+
+    @property
+    def callback_metrics(self) -> dict:
+        return self.logger_connector.callback_metrics
+
+    @property
+    def logged_metrics(self) -> dict:
+        return self.logger_connector.logged_metrics
+
+    @property
+    def progress_bar_metrics(self) -> dict:
+        return self.logger_connector.progress_bar_metrics
+
+    @property
+    def _results(self) -> Optional[ResultCollection]:
+        active_loop = self._active_loop
+        if active_loop is not None:
+            return active_loop._results
+
+    """
+    Other
+    """
+
+    # TODO: refactor this so that it can be done in LightningOptimizer
+    def __getstate__(self):
+        # remove lightning_optimizers
+        self._lightning_optimizers = None
+        return self.__dict__
+
+    def __setstate__(self, state):
+        self.__dict__ = state
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 5f1bdd1f34..14d10157c8 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -685,7 +685,6 @@ def test_tested_checkpoint_path(tmpdir, ckpt_path, save_top_k, fn):
 
     trainer_fn = getattr(trainer, fn)
     path_attr = f"{fn}{'d' if fn == 'validate' else 'ed'}_ckpt_path"
-    assert getattr(trainer, path_attr) is None
 
     if ckpt_path == "best":
         # ckpt_path is 'best', meaning we load the best weights
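
The snippet below is an illustrative usage sketch, not part of the patch: it only shows that the properties previously inherited from the `TrainerProperties` mixin are now defined directly on `Trainer`, so user code reads them exactly as before. It assumes `pytorch_lightning` at roughly this revision is installed; the constructor argument is arbitrary and `MyFitLoop` is a hypothetical subclass mentioned only in a comment.

import pytorch_lightning as pl

# Construct a Trainer with an arbitrary argument; the properties below are now
# plain attributes of Trainer rather than members of a separate mixin class.
trainer = pl.Trainer(max_epochs=3)

# Accelerator-backed properties
print(trainer.global_rank, trainer.world_size, trainer.precision)

# Callback lookups
print(trainer.checkpoint_callback)       # first ModelCheckpoint, or None
print(trainer.early_stopping_callbacks)  # all EarlyStopping instances

# Loop-backed properties; a custom loop can still be attached through the setter,
# e.g. `trainer.fit_loop = MyFitLoop()` where MyFitLoop subclasses FitLoop (hypothetical).
print(trainer.fit_loop.max_epochs, trainer.current_epoch, trainer.global_step)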