Deprecate TrainerProperties Mixin and move property definitions directly into `trainer.py` (#9495)
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> Co-authored-by: thomas chaton <thomas@grid.ai>
This commit is contained in:
parent
f5608e90d6
commit
290398f812
|
@ -205,6 +205,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Executing the `optimizer_closure` is now required when overriding the `optimizer_step` hook ([#9360](https://github.com/PyTorchLightning/pytorch-lightning/pull/9360))
|
||||
|
||||
|
||||
- Removed `TrainerProperties` mixin and moved property definitions directly into `Trainer` ([#9495](https://github.com/PyTorchLightning/pytorch-lightning/pull/9495))
|
||||
|
||||
|
||||
- Changed logging of `LightningModule` and `LightningDataModule` hyperparameters to raise an exception only if there are colliding keys with different values ([#9496](https://github.com/PyTorchLightning/pytorch-lightning/pull/9496))
|
||||
|
||||
|
||||
|
|
|
@ -12,27 +12,33 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Trainer to automate the training."""
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import traceback
|
||||
import warnings
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
||||
from typing import Any, Callable, cast, Dict, Iterable, List, Optional, Tuple, Union
|
||||
from weakref import proxy
|
||||
|
||||
import torch
|
||||
from torch.optim import Optimizer
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.accelerators import Accelerator, IPUAccelerator
|
||||
from pytorch_lightning.callbacks import Callback
|
||||
from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint, ProgressBarBase
|
||||
from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
|
||||
from pytorch_lightning.core.datamodule import LightningDataModule
|
||||
from pytorch_lightning.core.optimizer import LightningOptimizer
|
||||
from pytorch_lightning.loggers import LightningLoggerBase
|
||||
from pytorch_lightning.loops import TrainingBatchLoop, TrainingEpochLoop
|
||||
from pytorch_lightning.loggers.base import LoggerCollection
|
||||
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
|
||||
from pytorch_lightning.loops import PredictionLoop, TrainingBatchLoop, TrainingEpochLoop
|
||||
from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop
|
||||
from pytorch_lightning.loops.dataloader.prediction_loop import PredictionLoop
|
||||
from pytorch_lightning.loops.fit_loop import FitLoop
|
||||
from pytorch_lightning.plugins import DDPSpawnPlugin, PLUGIN_INPUT
|
||||
from pytorch_lightning.plugins import DDPSpawnPlugin, ParallelPlugin, PLUGIN_INPUT, PrecisionPlugin, TrainingTypePlugin
|
||||
from pytorch_lightning.profiler import (
|
||||
AdvancedProfiler,
|
||||
BaseProfiler,
|
||||
|
@ -50,6 +56,7 @@ from pytorch_lightning.trainer.connectors.data_connector import DataConnector
|
|||
from pytorch_lightning.trainer.connectors.debugging_connector import DebuggingConnector
|
||||
from pytorch_lightning.trainer.connectors.env_vars_connector import _defaults_from_env_vars
|
||||
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
|
||||
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
|
||||
from pytorch_lightning.trainer.connectors.model_connector import ModelConnector
|
||||
from pytorch_lightning.trainer.connectors.optimizer_connector import OptimizerConnector
|
||||
from pytorch_lightning.trainer.connectors.signal_connector import SignalConnector
|
||||
|
@ -58,8 +65,7 @@ from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
|
|||
from pytorch_lightning.trainer.deprecated_api import DeprecatedTrainerAttributes
|
||||
from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
|
||||
from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
|
||||
from pytorch_lightning.trainer.properties import TrainerProperties
|
||||
from pytorch_lightning.trainer.states import TrainerFn, TrainerState, TrainerStatus
|
||||
from pytorch_lightning.trainer.states import RunningStage, TrainerFn, TrainerState, TrainerStatus
|
||||
from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus
|
||||
from pytorch_lightning.tuner.lr_finder import _LRFinder
|
||||
from pytorch_lightning.tuner.tuning import Tuner
|
||||
|
@ -68,18 +74,33 @@ from pytorch_lightning.utilities import (
|
|||
_TPU_AVAILABLE,
|
||||
device_parser,
|
||||
DeviceType,
|
||||
DistributedType,
|
||||
parsing,
|
||||
rank_zero_deprecation,
|
||||
rank_zero_info,
|
||||
rank_zero_warn,
|
||||
)
|
||||
from pytorch_lightning.utilities.argparse import (
|
||||
add_argparse_args,
|
||||
from_argparse_args,
|
||||
parse_argparser,
|
||||
parse_env_variables,
|
||||
)
|
||||
from pytorch_lightning.utilities.cloud_io import get_filesystem
|
||||
from pytorch_lightning.utilities.debugging import InternalDebugger
|
||||
from pytorch_lightning.utilities.distributed import distributed_available
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.imports import _fault_tolerant_training
|
||||
from pytorch_lightning.utilities.model_helpers import is_overridden
|
||||
from pytorch_lightning.utilities.seed import reset_seed
|
||||
from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT, EVAL_DATALOADERS, TRAIN_DATALOADERS
|
||||
from pytorch_lightning.utilities.types import (
|
||||
_EVALUATE_OUTPUT,
|
||||
_PATH,
|
||||
_PREDICT_OUTPUT,
|
||||
EVAL_DATALOADERS,
|
||||
LRSchedulerTypeUnion,
|
||||
TRAIN_DATALOADERS,
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
# warnings to ignore in trainer
|
||||
|
@ -89,13 +110,15 @@ warnings.filterwarnings(
|
|||
|
||||
|
||||
class Trainer(
|
||||
TrainerProperties,
|
||||
TrainerCallbackHookMixin,
|
||||
TrainerModelHooksMixin,
|
||||
TrainerOptimizersMixin,
|
||||
TrainerDataLoadingMixin,
|
||||
DeprecatedTrainerAttributes,
|
||||
):
|
||||
# Needed because of LightningOptimizer
|
||||
_lightning_optimizers = None
|
||||
|
||||
@_defaults_from_env_vars
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -1393,3 +1416,579 @@ class Trainer(
|
|||
# save a checkpoint for fault tolerant training. we don't use `log_dir` to minimize the chances of failure.
|
||||
file_path = os.path.join(self.default_root_dir, ".pl_auto_save.ckpt")
|
||||
self.save_checkpoint(file_path)
|
||||
|
||||
"""
|
||||
Accelerator properties
|
||||
"""
|
||||
|
||||
@property
|
||||
def accelerator(self) -> Accelerator:
|
||||
return self.accelerator_connector.accelerator
|
||||
|
||||
@property
|
||||
def distributed_backend(self) -> Optional[str]:
|
||||
# for backward compatibility
|
||||
return self.accelerator_connector.distributed_backend
|
||||
|
||||
@property
|
||||
def training_type_plugin(self) -> TrainingTypePlugin:
|
||||
return self.accelerator.training_type_plugin
|
||||
|
||||
@property
|
||||
def precision_plugin(self) -> PrecisionPlugin:
|
||||
return self.accelerator.precision_plugin
|
||||
|
||||
@property
|
||||
def global_rank(self) -> int:
|
||||
return self.accelerator.training_type_plugin.global_rank
|
||||
|
||||
@property
|
||||
def local_rank(self) -> int:
|
||||
# some training types define a local rank
|
||||
return getattr(self.accelerator.training_type_plugin, "local_rank", 0)
|
||||
|
||||
@property
|
||||
def node_rank(self) -> int:
|
||||
# some training types define a local rank
|
||||
return getattr(self.accelerator.training_type_plugin, "node_rank", 0)
|
||||
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
# some training types define a world size
|
||||
return getattr(self.accelerator.training_type_plugin, "world_size", 1)
|
||||
|
||||
@property
|
||||
def should_rank_save_checkpoint(self) -> bool:
|
||||
return self.accelerator.training_type_plugin.should_rank_save_checkpoint
|
||||
|
||||
@property
|
||||
def _distrib_type(self) -> DistributedType:
|
||||
return self.accelerator_connector._distrib_type
|
||||
|
||||
@property
|
||||
def _device_type(self) -> DeviceType:
|
||||
return self.accelerator_connector._device_type
|
||||
|
||||
@property
|
||||
def num_nodes(self) -> int:
|
||||
return self.accelerator_connector.num_nodes
|
||||
|
||||
@property
|
||||
def num_processes(self) -> int:
|
||||
return self.accelerator_connector.num_processes
|
||||
|
||||
@property
|
||||
def root_gpu(self) -> Optional[int]:
|
||||
return self.accelerator_connector.root_gpu
|
||||
|
||||
@property
|
||||
def tpu_cores(self) -> int:
|
||||
return self.accelerator_connector.tpu_cores
|
||||
|
||||
@property
|
||||
def ipus(self) -> int:
|
||||
return self.accelerator_connector.num_ipus
|
||||
|
||||
@property
|
||||
def num_gpus(self) -> int:
|
||||
return self.accelerator_connector.num_gpus
|
||||
|
||||
@property
|
||||
def devices(self) -> Optional[Union[List[int], str, int]]:
|
||||
return self.accelerator_connector.devices
|
||||
|
||||
@property
|
||||
def data_parallel_device_ids(self) -> Optional[List[int]]:
|
||||
return self.accelerator_connector.parallel_device_ids
|
||||
|
||||
@property
|
||||
def lightning_module(self) -> "pl.LightningModule":
|
||||
return self.accelerator.lightning_module
|
||||
|
||||
@property
|
||||
def optimizers(self) -> List[Optimizer]:
|
||||
return self.accelerator.optimizers
|
||||
|
||||
@optimizers.setter
|
||||
def optimizers(self, new_optims: Optional[List[Optimizer]]) -> None:
|
||||
# Necessary to rewrap optimizers to lightning
|
||||
# They will be re-created when accessing
|
||||
# the `lightning_optimizers` trainer property
|
||||
self._lightning_optimizers = None
|
||||
|
||||
self.accelerator.optimizers = new_optims
|
||||
|
||||
@property
|
||||
def lr_schedulers(self) -> List[LRSchedulerTypeUnion]:
|
||||
return self.accelerator.lr_schedulers
|
||||
|
||||
@lr_schedulers.setter
|
||||
def lr_schedulers(self, new_schedulers: List[LRSchedulerTypeUnion]) -> None:
|
||||
self.accelerator.lr_schedulers = new_schedulers
|
||||
|
||||
@property
|
||||
def optimizer_frequencies(self) -> list:
|
||||
return self.accelerator.optimizer_frequencies
|
||||
|
||||
@optimizer_frequencies.setter
|
||||
def optimizer_frequencies(self, new_freqs: list) -> None:
|
||||
self.accelerator.optimizer_frequencies = new_freqs
|
||||
|
||||
@property
|
||||
def amp_backend(self) -> Optional[str]:
|
||||
return self.accelerator.amp_backend
|
||||
|
||||
@property
|
||||
def precision(self) -> Union[str, int]:
|
||||
return self.accelerator.precision
|
||||
|
||||
@property
|
||||
def scaler(self):
|
||||
return self.accelerator.scaler
|
||||
|
||||
@property
|
||||
def gpus(self) -> Optional[Union[List[int], str, int]]:
|
||||
return self.accelerator_connector.gpus
|
||||
|
||||
@property
|
||||
def model(self) -> torch.nn.Module:
|
||||
"""The LightningModule, but possibly wrapped into DataParallel or DistributedDataParallel.
|
||||
|
||||
To access the pure LightningModule, use
|
||||
:meth:`~pytorch_lightning.trainer.trainer.Trainer.lightning_module` instead.
|
||||
"""
|
||||
return self.accelerator.model
|
||||
|
||||
@model.setter
|
||||
def model(self, model: torch.nn.Module) -> None:
|
||||
"""Setter for the model, pass-through to accelerator and plugin where the model reference is stored. Used
|
||||
by the Tuner to reset the state of Trainer and Accelerator.
|
||||
|
||||
Args:
|
||||
model: The LightningModule, possibly wrapped into DataParallel or DistributedDataParallel, depending
|
||||
on the backend.
|
||||
"""
|
||||
self.accelerator.model = model
|
||||
|
||||
"""
|
||||
General properties
|
||||
"""
|
||||
|
||||
@property
|
||||
def log_dir(self) -> Optional[str]:
|
||||
if self.logger is None:
|
||||
dirpath = self.default_root_dir
|
||||
elif isinstance(self.logger, TensorBoardLogger):
|
||||
dirpath = self.logger.log_dir
|
||||
elif isinstance(self.logger, LoggerCollection):
|
||||
dirpath = self.default_root_dir
|
||||
else:
|
||||
dirpath = self.logger.save_dir
|
||||
|
||||
dirpath = self.accelerator.broadcast(dirpath)
|
||||
return dirpath
|
||||
|
||||
@property
|
||||
def use_amp(self) -> bool:
|
||||
return self.precision == 16
|
||||
|
||||
@property
|
||||
def is_global_zero(self) -> bool:
|
||||
return self.global_rank == 0
|
||||
|
||||
@property
|
||||
def slurm_job_id(self) -> Optional[int]:
|
||||
job_id = os.environ.get("SLURM_JOB_ID")
|
||||
if job_id:
|
||||
try:
|
||||
job_id = int(job_id)
|
||||
except ValueError:
|
||||
job_id = None
|
||||
|
||||
# in interactive mode, don't make logs use the same job id
|
||||
in_slurm_interactive_mode = os.environ.get("SLURM_JOB_NAME") == "bash"
|
||||
if in_slurm_interactive_mode:
|
||||
job_id = None
|
||||
return job_id
|
||||
|
||||
@property
|
||||
def lightning_optimizers(self) -> List[LightningOptimizer]:
|
||||
if self._lightning_optimizers is None:
|
||||
self.convert_to_lightning_optimizers()
|
||||
return self._lightning_optimizers
|
||||
|
||||
@property
|
||||
def distributed_sampler_kwargs(self) -> Optional[dict]:
|
||||
if isinstance(self.training_type_plugin, ParallelPlugin):
|
||||
return self.training_type_plugin.distributed_sampler_kwargs
|
||||
|
||||
@property
|
||||
def data_parallel(self) -> bool:
|
||||
return self._distrib_type in (
|
||||
DistributedType.DP,
|
||||
DistributedType.DDP,
|
||||
DistributedType.DDP_SPAWN,
|
||||
DistributedType.DDP2,
|
||||
)
|
||||
|
||||
@property
|
||||
def progress_bar_callback(self) -> Optional[ProgressBarBase]:
|
||||
return self._progress_bar_callback
|
||||
|
||||
@property
|
||||
def progress_bar_dict(self) -> dict:
|
||||
"""Read-only for progress bar metrics."""
|
||||
rank_zero_deprecation(
|
||||
"`trainer.progress_bar_dict` is deprecated in v1.5 and will be removed in v1.7."
|
||||
" Use `ProgressBarBase.get_metrics` instead."
|
||||
)
|
||||
ref_model = self.lightning_module
|
||||
ref_model = cast(pl.LightningModule, ref_model)
|
||||
if self.progress_bar_callback:
|
||||
return self.progress_bar_callback.get_metrics(self, ref_model)
|
||||
return self.progress_bar_metrics
|
||||
|
||||
@property
|
||||
def _should_reload_dl_epoch(self) -> bool:
|
||||
"""Check if dataloader should be reloaded in the current epoch."""
|
||||
n_epochs = self.reload_dataloaders_every_n_epochs
|
||||
return n_epochs and (not self.current_epoch % n_epochs)
|
||||
|
||||
@property
|
||||
def disable_validation(self) -> bool:
|
||||
"""Check if validation is disabled during training."""
|
||||
rank_zero_deprecation(
|
||||
"`trainer.disable_validation` is deprecated in v1.4 and will be removed in v1.6."
|
||||
" Use `not trainer.enable_validation` instead."
|
||||
)
|
||||
return not self.enable_validation
|
||||
|
||||
@property
|
||||
def enable_validation(self) -> bool:
|
||||
"""Check if we should run validation during training."""
|
||||
model_ref = self.lightning_module
|
||||
val_loop_enabled = is_overridden("validation_step", model_ref) and self.limit_val_batches > 0
|
||||
return val_loop_enabled
|
||||
|
||||
@property
|
||||
def default_root_dir(self) -> str:
|
||||
"""The default location to save artifacts of loggers, checkpoints etc.
|
||||
|
||||
It is used as a fallback if logger or checkpoint callback do not define specific save paths.
|
||||
"""
|
||||
if get_filesystem(self._default_root_dir).protocol == "file":
|
||||
return os.path.normpath(self._default_root_dir)
|
||||
return self._default_root_dir
|
||||
|
||||
@property
|
||||
def weights_save_path(self) -> str:
|
||||
"""
|
||||
The default root location to save weights (checkpoints), e.g., when the
|
||||
:class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` does not define a file path.
|
||||
"""
|
||||
if get_filesystem(self._weights_save_path).protocol == "file":
|
||||
return os.path.normpath(self._weights_save_path)
|
||||
return self._weights_save_path
|
||||
|
||||
@property
|
||||
def early_stopping_callback(self) -> Optional[EarlyStopping]:
|
||||
"""The first :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping` callback in the
|
||||
Trainer.callbacks list, or ``None`` if it doesn't exist."""
|
||||
callbacks = self.early_stopping_callbacks
|
||||
return callbacks[0] if len(callbacks) > 0 else None
|
||||
|
||||
@property
|
||||
def early_stopping_callbacks(self) -> List[EarlyStopping]:
|
||||
"""A list of all instances of :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping` found in
|
||||
the Trainer.callbacks list."""
|
||||
return [c for c in self.callbacks if isinstance(c, EarlyStopping)]
|
||||
|
||||
@property
|
||||
def prediction_writer_callbacks(self) -> List[BasePredictionWriter]:
|
||||
"""A list of all instances of :class:`~pytorch_lightning.callbacks.prediction_writer.BasePredictionWriter`
|
||||
found in the Trainer.callbacks list."""
|
||||
return [cb for cb in self.callbacks if isinstance(cb, BasePredictionWriter)]
|
||||
|
||||
@property
|
||||
def checkpoint_callback(self) -> Optional[ModelCheckpoint]:
|
||||
"""The first :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callback in the
|
||||
Trainer.callbacks list, or ``None`` if it doesn't exist."""
|
||||
callbacks = self.checkpoint_callbacks
|
||||
return callbacks[0] if len(callbacks) > 0 else None
|
||||
|
||||
@property
|
||||
def checkpoint_callbacks(self) -> List[ModelCheckpoint]:
|
||||
"""A list of all instances of :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` found
|
||||
in the Trainer.callbacks list."""
|
||||
return [c for c in self.callbacks if isinstance(c, ModelCheckpoint)]
|
||||
|
||||
@property
|
||||
def resume_from_checkpoint(self) -> Optional[Union[str, Path]]:
|
||||
return self.checkpoint_connector.resume_checkpoint_path
|
||||
|
||||
def save_checkpoint(self, filepath: _PATH, weights_only: bool = False) -> None:
|
||||
self.checkpoint_connector.save_checkpoint(filepath, weights_only)
|
||||
|
||||
"""
|
||||
Parsing properties
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def default_attributes(cls) -> dict:
|
||||
init_signature = inspect.signature(cls)
|
||||
return {k: v.default for k, v in init_signature.parameters.items()}
|
||||
|
||||
@classmethod
|
||||
def get_deprecated_arg_names(cls) -> List:
|
||||
"""Returns a list with deprecated Trainer arguments."""
|
||||
depr_arg_names = []
|
||||
for name, val in cls.__dict__.items():
|
||||
if name.startswith("DEPRECATED") and isinstance(val, (tuple, list)):
|
||||
depr_arg_names.extend(val)
|
||||
return depr_arg_names
|
||||
|
||||
@classmethod
|
||||
def from_argparse_args(cls: Any, args: Union[Namespace, ArgumentParser], **kwargs) -> Any:
|
||||
return from_argparse_args(cls, args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
|
||||
return parse_argparser(cls, arg_parser)
|
||||
|
||||
@classmethod
|
||||
def match_env_arguments(cls) -> Namespace:
|
||||
return parse_env_variables(cls)
|
||||
|
||||
@classmethod
|
||||
def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentParser:
|
||||
return add_argparse_args(cls, parent_parser, **kwargs)
|
||||
|
||||
"""
|
||||
State properties
|
||||
"""
|
||||
|
||||
@property
|
||||
def interrupted(self) -> bool:
|
||||
return self.state.status == TrainerStatus.INTERRUPTED
|
||||
|
||||
@property
|
||||
def training(self) -> bool:
|
||||
return self.state.stage == RunningStage.TRAINING
|
||||
|
||||
@training.setter
|
||||
def training(self, val: bool) -> None:
|
||||
if val:
|
||||
self.state.stage = RunningStage.TRAINING
|
||||
elif self.training:
|
||||
self.state.stage = None
|
||||
|
||||
@property
|
||||
def testing(self) -> bool:
|
||||
return self.state.stage == RunningStage.TESTING
|
||||
|
||||
@testing.setter
|
||||
def testing(self, val: bool) -> None:
|
||||
if val:
|
||||
self.state.stage = RunningStage.TESTING
|
||||
elif self.testing:
|
||||
self.state.stage = None
|
||||
|
||||
@property
|
||||
def predicting(self) -> bool:
|
||||
return self.state.stage == RunningStage.PREDICTING
|
||||
|
||||
@predicting.setter
|
||||
def predicting(self, val: bool) -> None:
|
||||
if val:
|
||||
self.state.stage = RunningStage.PREDICTING
|
||||
elif self.predicting:
|
||||
self.state.stage = None
|
||||
|
||||
@property
|
||||
def tuning(self) -> bool:
|
||||
return self.state.stage == RunningStage.TUNING
|
||||
|
||||
@tuning.setter
|
||||
def tuning(self, val: bool) -> None:
|
||||
if val:
|
||||
self.state.stage = RunningStage.TUNING
|
||||
elif self.tuning:
|
||||
self.state.stage = None
|
||||
|
||||
@property
|
||||
def validating(self) -> bool:
|
||||
return self.state.stage == RunningStage.VALIDATING
|
||||
|
||||
@validating.setter
|
||||
def validating(self, val: bool) -> None:
|
||||
if val:
|
||||
self.state.stage = RunningStage.VALIDATING
|
||||
elif self.validating:
|
||||
self.state.stage = None
|
||||
|
||||
@property
|
||||
def evaluating(self) -> bool:
|
||||
return self.state.stage and self.state.stage.evaluating
|
||||
|
||||
@property
|
||||
def sanity_checking(self) -> bool:
|
||||
return self.state.stage == RunningStage.SANITY_CHECKING
|
||||
|
||||
@sanity_checking.setter
|
||||
def sanity_checking(self, val: bool) -> None:
|
||||
if val:
|
||||
self.state.stage = RunningStage.SANITY_CHECKING
|
||||
elif self.sanity_checking:
|
||||
self.state.stage = None
|
||||
|
||||
"""
|
||||
Loop properties
|
||||
"""
|
||||
|
||||
@property
|
||||
def global_step(self) -> int:
|
||||
return self.fit_loop.global_step
|
||||
|
||||
@property
|
||||
def current_epoch(self) -> int:
|
||||
return self.fit_loop.current_epoch
|
||||
|
||||
@property
|
||||
def max_epochs(self) -> Optional[int]:
|
||||
return self.fit_loop.max_epochs
|
||||
|
||||
@property
|
||||
def min_epochs(self) -> Optional[int]:
|
||||
return self.fit_loop.min_epochs
|
||||
|
||||
@property
|
||||
def max_steps(self) -> Optional[int]:
|
||||
return self.fit_loop.max_steps
|
||||
|
||||
@property
|
||||
def min_steps(self) -> Optional[int]:
|
||||
return self.fit_loop.min_steps
|
||||
|
||||
@property
|
||||
def is_last_batch(self) -> bool:
|
||||
return self.fit_loop.epoch_loop.is_last_batch
|
||||
|
||||
@property
|
||||
def fit_loop(self) -> FitLoop:
|
||||
return self._fit_loop
|
||||
|
||||
@fit_loop.setter
|
||||
def fit_loop(self, loop: FitLoop):
|
||||
"""Attach a custom fit loop to this Trainer.
|
||||
|
||||
It will run with
|
||||
:meth:`~pytorch_lighting.trainer.trainer.Trainer.fit`.
|
||||
"""
|
||||
loop.trainer = self
|
||||
self._fit_loop = loop
|
||||
|
||||
@property
|
||||
def validate_loop(self) -> EvaluationLoop:
|
||||
return self._validate_loop
|
||||
|
||||
@validate_loop.setter
|
||||
def validate_loop(self, loop: EvaluationLoop):
|
||||
"""Attach a custom validation loop to this Trainer.
|
||||
|
||||
It will run with
|
||||
:meth:`~pytorch_lighting.trainer.trainer.Trainer.validate`. Note that this loop is different from the one
|
||||
running during training inside the :meth:`pytorch_lightning.trainer.trainer.Trainer.fit` call.
|
||||
"""
|
||||
loop.trainer = self
|
||||
self._validate_loop = loop
|
||||
|
||||
@property
|
||||
def test_loop(self) -> EvaluationLoop:
|
||||
return self._test_loop
|
||||
|
||||
@test_loop.setter
|
||||
def test_loop(self, loop: EvaluationLoop):
|
||||
"""Attach a custom test loop to this Trainer.
|
||||
|
||||
It will run with
|
||||
:meth:`~pytorch_lightning.trainer.trainer.Trainer.test`.
|
||||
"""
|
||||
loop.trainer = self
|
||||
self._test_loop = loop
|
||||
|
||||
@property
|
||||
def predict_loop(self) -> PredictionLoop:
|
||||
return self._predict_loop
|
||||
|
||||
@predict_loop.setter
|
||||
def predict_loop(self, loop: PredictionLoop):
|
||||
"""Attach a custom prediction loop to this Trainer.
|
||||
|
||||
It will run with
|
||||
:meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`.
|
||||
"""
|
||||
loop.trainer = self
|
||||
self._predict_loop = loop
|
||||
|
||||
@property
|
||||
def _evaluation_loop(self) -> EvaluationLoop:
|
||||
if self.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING):
|
||||
return self.fit_loop.epoch_loop.val_loop
|
||||
if self.state.fn == TrainerFn.VALIDATING:
|
||||
return self.validate_loop
|
||||
if self.state.fn == TrainerFn.TESTING:
|
||||
return self.test_loop
|
||||
raise RuntimeError("The `Trainer._evaluation_loop` property isn't defined. Accessed outside of scope")
|
||||
|
||||
@property
|
||||
def _active_loop(self) -> Optional[Union[FitLoop, EvaluationLoop, PredictionLoop]]:
|
||||
if self.training:
|
||||
return self.fit_loop
|
||||
if self.sanity_checking or self.evaluating:
|
||||
return self._evaluation_loop
|
||||
if self.predicting:
|
||||
return self.predict_loop
|
||||
|
||||
@property
|
||||
def _ckpt_path(self) -> Optional[str]:
|
||||
if self.state.fn == TrainerFn.VALIDATING:
|
||||
return self.validated_ckpt_path
|
||||
if self.state.fn == TrainerFn.TESTING:
|
||||
return self.tested_ckpt_path
|
||||
if self.state.fn == TrainerFn.PREDICTING:
|
||||
return self.predicted_ckpt_path
|
||||
|
||||
"""
|
||||
Logging properties
|
||||
"""
|
||||
|
||||
@property
|
||||
def callback_metrics(self) -> dict:
|
||||
return self.logger_connector.callback_metrics
|
||||
|
||||
@property
|
||||
def logged_metrics(self) -> dict:
|
||||
return self.logger_connector.logged_metrics
|
||||
|
||||
@property
|
||||
def progress_bar_metrics(self) -> dict:
|
||||
return self.logger_connector.progress_bar_metrics
|
||||
|
||||
@property
|
||||
def _results(self) -> Optional[ResultCollection]:
|
||||
active_loop = self._active_loop
|
||||
if active_loop is not None:
|
||||
return active_loop._results
|
||||
|
||||
"""
|
||||
Other
|
||||
"""
|
||||
|
||||
# TODO: refactor this so that it can be done in LightningOptimizer
|
||||
def __getstate__(self):
|
||||
# remove lightning_optimizers
|
||||
self._lightning_optimizers = None
|
||||
return self.__dict__
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.__dict__ = state
|
||||
|
|
|
@ -685,7 +685,6 @@ def test_tested_checkpoint_path(tmpdir, ckpt_path, save_top_k, fn):
|
|||
|
||||
trainer_fn = getattr(trainer, fn)
|
||||
path_attr = f"{fn}{'d' if fn == 'validate' else 'ed'}_ckpt_path"
|
||||
assert getattr(trainer, path_attr) is None
|
||||
|
||||
if ckpt_path == "best":
|
||||
# ckpt_path is 'best', meaning we load the best weights
|
||||
|
|
Loading…
Reference in New Issue