# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import os
from abc import ABC
from argparse import ArgumentParser, Namespace
from typing import cast, List, Optional, Type, TypeVar, Union

import torch
from torch.optim import Optimizer

from pytorch_lightning.accelerators import Accelerator
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, ProgressBarBase
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning.plugins import ParallelPlugin, PrecisionPlugin, TrainingTypePlugin
from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
from pytorch_lightning.trainer.states import RunningStage, TrainerState
from pytorch_lightning.utilities import DeviceType, DistributedType, rank_zero_warn
from pytorch_lightning.utilities.argparse import (
    add_argparse_args,
    from_argparse_args,
    parse_argparser,
    parse_env_variables,
)
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.model_helpers import is_overridden


class TrainerProperties(ABC):

    _default_root_dir: str
    _lightning_optimizers = None
    _progress_bar_callback: ProgressBarBase
    _running_stage: Optional[RunningStage] = None
    _state: TrainerState
    _weights_save_path: str

    accelerator_connector: AcceleratorConnector
    callbacks: List[Callback]
    checkpoint_connector: CheckpointConnector
    limit_val_batches: int
    logger: LightningLoggerBase
    logger_connector: LoggerConnector

    @property
    def accelerator(self) -> Accelerator:
        return self.accelerator_connector.accelerator

    @property
    def distributed_backend(self) -> Optional[str]:
        # for backward compatibility
        return self.accelerator_connector.distributed_backend

    @property
    def training_type_plugin(self) -> TrainingTypePlugin:
        return self.accelerator.training_type_plugin

    @property
    def precision_plugin(self) -> PrecisionPlugin:
        return self.accelerator.precision_plugin

    @property
    def global_rank(self) -> int:
        return self.accelerator.training_type_plugin.global_rank

    @property
    def local_rank(self) -> int:
        # some training types define a local rank
        return getattr(self.accelerator.training_type_plugin, "local_rank", 0)

    @property
    def node_rank(self) -> int:
        # some training types define a node rank
        return getattr(self.accelerator.training_type_plugin, "node_rank", 0)

    @property
    def world_size(self) -> int:
        # some training types define a world size
        return getattr(self.accelerator.training_type_plugin, "world_size", 1)

    @property
    def _distrib_type(self) -> DistributedType:
        return self.accelerator_connector._distrib_type

    @property
    def _device_type(self) -> DeviceType:
        return self.accelerator_connector._device_type

    @property
    def num_nodes(self) -> int:
        return self.accelerator_connector.num_nodes

    @property
    def num_processes(self) -> int:
        return self.accelerator_connector.num_processes

    @property
    def root_gpu(self) -> Optional[int]:
        return self.accelerator_connector.root_gpu

    @property
    def tpu_cores(self) -> int:
        return self.accelerator_connector.tpu_cores

    @property
    def num_gpus(self) -> int:
        return self.accelerator_connector.num_gpus

    @property
    def data_parallel_device_ids(self) -> Optional[List[int]]:
        return self.accelerator_connector.parallel_device_ids

    @property
    def log_dir(self) -> Optional[str]:
        if self.logger is None:
            dirpath = self.default_root_dir
        else:
            dirpath = getattr(self.logger, 'log_dir' if isinstance(self.logger, TensorBoardLogger) else 'save_dir')

        dirpath = self.accelerator.broadcast(dirpath)
        return dirpath

    @property
    def use_amp(self) -> bool:
        return self.precision == 16

    @property
    def callback_metrics(self) -> dict:
        return self.logger_connector.callback_metrics

    @callback_metrics.setter
    def callback_metrics(self, x: dict) -> None:
        self.logger_connector.callback_metrics = x

    @property
    def logged_metrics(self) -> dict:
        return self.logger_connector.logged_metrics

    @logged_metrics.setter
    def logged_metrics(self, x: dict) -> None:
        self.logger_connector.logged_metrics = x

    @property
    def progress_bar_metrics(self) -> dict:
        return self.logger_connector.progress_bar_metrics

    @progress_bar_metrics.setter
    def progress_bar_metrics(self, x: dict) -> None:
        self.logger_connector.progress_bar_metrics = x

    @property
    def state(self) -> TrainerState:
        return self._state

    @state.setter
    def state(self, state: TrainerState) -> None:
        self._state = state

    @property
    def interrupted(self) -> bool:
        return self._state == TrainerState.INTERRUPTED

    @property
    def is_global_zero(self) -> bool:
        return self.global_rank == 0

    @property
    def slurm_job_id(self) -> Optional[int]:
        job_id = os.environ.get('SLURM_JOB_ID')
        if job_id:
            try:
                job_id = int(job_id)
            except ValueError:
                job_id = None

        # in interactive mode, don't make logs use the same job id
        in_slurm_interactive_mode = os.environ.get('SLURM_JOB_NAME') == 'bash'
        if in_slurm_interactive_mode:
            job_id = None
        return job_id
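
    # Illustrative sketch of ``slurm_job_id`` behavior (the environment values
    # below are hypothetical, not taken from a real cluster):
    #
    #     SLURM_JOB_ID=12345                 -> trainer.slurm_job_id == 12345
    #     SLURM_JOB_ID=not-a-number          -> trainer.slurm_job_id is None
    #     SLURM_JOB_NAME=bash (interactive)  -> trainer.slurm_job_id is None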

    @classmethod
    def default_attributes(cls) -> dict:
        init_signature = inspect.signature(cls)
        return {k: v.default for k, v in init_signature.parameters.items()}

    @classmethod
    def get_deprecated_arg_names(cls) -> List:
        """Returns a list with deprecated Trainer arguments."""
        depr_arg_names = []
        for name, val in cls.__dict__.items():
            if name.startswith('DEPRECATED') and isinstance(val, (tuple, list)):
                depr_arg_names.extend(val)
        return depr_arg_names

    @classmethod
    def from_argparse_args(cls: Type['_T'], args: Union[Namespace, ArgumentParser], **kwargs) -> '_T':
        return from_argparse_args(cls, args, **kwargs)

    @classmethod
    def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
        return parse_argparser(cls, arg_parser)

    @classmethod
    def match_env_arguments(cls) -> Namespace:
        return parse_env_variables(cls)

    @classmethod
    def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentParser:
        return add_argparse_args(cls, parent_parser, **kwargs)
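
    # Usage sketch for the argparse helpers above, assuming the concrete
    # ``Trainer`` subclass; ``--max_epochs`` stands in for any Trainer argument:
    #
    #     parser = ArgumentParser(add_help=False)
    #     parser = Trainer.add_argparse_args(parser)
    #     args = parser.parse_args(['--max_epochs', '3'])
    #     trainer = Trainer.from_argparse_args(args)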

    @property
    def gpus(self) -> Optional[Union[List[int], str, int]]:
        return self.accelerator_connector.gpus

    @property
    def data_parallel(self) -> bool:
        return self._distrib_type in (
            DistributedType.DP, DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2
        )

    @property
    def progress_bar_callback(self) -> Optional[ProgressBarBase]:
        return self._progress_bar_callback

    @property
    def progress_bar_dict(self) -> dict:
        """ Read-only for progress bar metrics. """
        ref_model = self.lightning_module
        ref_model = cast(LightningModule, ref_model)

        standard_metrics = ref_model.get_progress_bar_dict()
        logged_metrics = self.progress_bar_metrics
        duplicates = list(standard_metrics.keys() & logged_metrics.keys())
        if duplicates:
            rank_zero_warn(
                f"The progress bar already tracks a metric with the name(s) '{', '.join(duplicates)}' and"
                f" `self.log('{duplicates[0]}', ..., prog_bar=True)` will overwrite this value."
                f" If this is undesired, change the name or override `get_progress_bar_dict()`"
                f" in `LightningModule`.", UserWarning
            )
        all_metrics = dict(**standard_metrics)
        all_metrics.update(**logged_metrics)
        return all_metrics
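
    # Merge semantics of ``progress_bar_dict``, sketched with hypothetical values:
    # if ``get_progress_bar_dict()`` yields {'loss': 0.5, 'v_num': 3} and the user
    # also called ``self.log('loss', 0.4, prog_bar=True)``, the logged value wins
    # (after emitting the duplicate-name warning above):
    #
    #     trainer.progress_bar_dict  # -> {'loss': 0.4, 'v_num': 3}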

    @property
    def disable_validation(self) -> bool:
        """ Check if validation is disabled during training. """
        return not self.enable_validation

    @property
    def enable_validation(self) -> bool:
        """ Check if we should run validation during training. """
        model_ref = self.lightning_module
        val_loop_enabled = is_overridden('validation_step', model_ref) and self.limit_val_batches > 0
        return val_loop_enabled

    @property
    def default_root_dir(self) -> str:
        """
        The default location to save artifacts of loggers, checkpoints etc.
        It is used as a fallback if logger or checkpoint callback do not define specific save paths.
        """
        if get_filesystem(self._default_root_dir).protocol == "file":
            return os.path.normpath(self._default_root_dir)
        return self._default_root_dir

    @property
    def weights_save_path(self) -> str:
        """
        The default root location to save weights (checkpoints), e.g., when the
        :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` does not define a file path.
        """
        if get_filesystem(self._weights_save_path).protocol == "file":
            return os.path.normpath(self._weights_save_path)
        return self._weights_save_path
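
    # Path-handling sketch for the two properties above (illustrative values):
    # local paths are normalized, remote filesystem URLs pass through untouched.
    #
    #     _default_root_dir = './runs/../runs'    -> default_root_dir == 'runs'
    #     _default_root_dir = 's3://bucket/runs'  -> default_root_dir == 's3://bucket/runs'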

    @property
    def early_stopping_callback(self) -> Optional[EarlyStopping]:
        """
        The first :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping`
        callback in the Trainer.callbacks list, or ``None`` if it doesn't exist.
        """
        callbacks = self.early_stopping_callbacks
        return callbacks[0] if len(callbacks) > 0 else None

    @property
    def early_stopping_callbacks(self) -> List[EarlyStopping]:
        """
        A list of all instances of :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping`
        found in the Trainer.callbacks list.
        """
        return [c for c in self.callbacks if isinstance(c, EarlyStopping)]

    @property
    def prediction_writer_callbacks(self) -> List[BasePredictionWriter]:
        """
        A list of all instances of :class:`~pytorch_lightning.callbacks.prediction_writer.BasePredictionWriter`
        found in the Trainer.callbacks list.
        """
        return [cb for cb in self.callbacks if isinstance(cb, BasePredictionWriter)]

    @property
    def checkpoint_callback(self) -> Optional[ModelCheckpoint]:
        """
        The first :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint`
        callback in the Trainer.callbacks list, or ``None`` if it doesn't exist.
        """
        callbacks = self.checkpoint_callbacks
        return callbacks[0] if len(callbacks) > 0 else None

    @property
    def checkpoint_callbacks(self) -> List[ModelCheckpoint]:
        """
        A list of all instances of :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint`
        found in the Trainer.callbacks list.
        """
        return [c for c in self.callbacks if isinstance(c, ModelCheckpoint)]
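
    # Retrieval sketch for the callback accessors above (hedged; assumes a
    # concrete ``Trainer`` constructed with user callbacks):
    #
    #     trainer = Trainer(callbacks=[EarlyStopping(monitor='val_loss')])
    #     trainer.early_stopping_callback  # -> the EarlyStopping instance
    #     trainer.checkpoint_callbacks     # -> all ModelCheckpoint instances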

    def save_checkpoint(self, filepath, weights_only: bool = False) -> None:
        self.checkpoint_connector.save_checkpoint(filepath, weights_only)

    @property
    def model(self) -> torch.nn.Module:
        """
        The LightningModule, but possibly wrapped into DataParallel or DistributedDataParallel.
        To access the pure LightningModule, use
        :attr:`~pytorch_lightning.trainer.trainer.Trainer.lightning_module` instead.
        """
        return self.accelerator.model

    @model.setter
    def model(self, model: torch.nn.Module) -> None:
        """
        Setter for the model, pass-through to accelerator and plugin where the model reference is stored.
        Used by the Tuner to reset the state of Trainer and Accelerator.

        Args:
            model: The LightningModule, possibly wrapped into DataParallel or DistributedDataParallel,
                depending on the backend.
        """
        self.accelerator.model = model

    @property
    def lightning_optimizers(self) -> List[LightningOptimizer]:
        if self._lightning_optimizers is None:
            self.convert_to_lightning_optimizers()
        return self._lightning_optimizers

    @property
    def lightning_module(self) -> LightningModule:
        return self.accelerator.lightning_module

    @property
    def optimizers(self) -> Optional[List[Optimizer]]:
        return self.accelerator.optimizers

    @optimizers.setter
    def optimizers(self, new_optims: Optional[List[Optimizer]]) -> None:
        # Assigning new optimizers invalidates the cached LightningOptimizer wrappers;
        # they are re-created lazily on the next access of the `lightning_optimizers`
        # trainer property.
        self._lightning_optimizers = None

        self.accelerator.optimizers = new_optims
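
    # Cache-invalidation sketch for the setter above (``model`` is a hypothetical
    # LightningModule already attached to the trainer):
    #
    #     trainer.optimizers = [torch.optim.SGD(model.parameters(), lr=0.1)]
    #     trainer.lightning_optimizers  # wrappers rebuilt lazily on this access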

    @property
    def lr_schedulers(self) -> Optional[list]:
        return self.accelerator.lr_schedulers

    @lr_schedulers.setter
    def lr_schedulers(self, new_schedulers: Optional[list]) -> None:
        self.accelerator.lr_schedulers = new_schedulers

    @property
    def optimizer_frequencies(self) -> list:
        return self.accelerator.optimizer_frequencies

    @optimizer_frequencies.setter
    def optimizer_frequencies(self, new_freqs: list) -> None:
        self.accelerator.optimizer_frequencies = new_freqs

    @property
    def amp_backend(self) -> Optional[str]:
        return self.accelerator.amp_backend

    @property
    def precision(self) -> Union[str, int]:
        return self.accelerator.precision

    @property
    def scaler(self):
        return self.accelerator.scaler

    # TODO: refactor this so that it can be done in LightningOptimizer
    def __getstate__(self):
        # remove lightning_optimizers
        self._lightning_optimizers = None
        return self.__dict__

    def __setstate__(self, state):
        self.__dict__ = state

    @property
    def distributed_sampler_kwargs(self) -> Optional[dict]:
        if isinstance(self.training_type_plugin, ParallelPlugin):
            return self.training_type_plugin.distributed_sampler_kwargs

    @property
    def training(self) -> bool:
        return self._running_stage == RunningStage.TRAINING

    @training.setter
    def training(self, val: bool) -> None:
        if val:
            self._running_stage = RunningStage.TRAINING
        elif self.training:
            self._running_stage = None

    @property
    def testing(self) -> bool:
        return self._running_stage == RunningStage.TESTING

    @testing.setter
    def testing(self, val: bool) -> None:
        if val:
            self._running_stage = RunningStage.TESTING
        elif self.testing:
            self._running_stage = None

    @property
    def predicting(self) -> bool:
        return self._running_stage == RunningStage.PREDICTING

    @predicting.setter
    def predicting(self, val: bool) -> None:
        if val:
            self._running_stage = RunningStage.PREDICTING
        elif self.predicting:
            self._running_stage = None

    @property
    def tuning(self) -> bool:
        return self._running_stage == RunningStage.TUNING

    @tuning.setter
    def tuning(self, val: bool) -> None:
        if val:
            self._running_stage = RunningStage.TUNING
        elif self.tuning:
            self._running_stage = None

    @property
    def validating(self) -> bool:
        return self._running_stage == RunningStage.VALIDATING

    @validating.setter
    def validating(self, val: bool) -> None:
        if val:
            self._running_stage = RunningStage.VALIDATING
        elif self.validating:
            self._running_stage = None

    @property
    def evaluating(self) -> bool:
        return bool(self._running_stage and self._running_stage.evaluating)

    @property
    def sanity_checking(self) -> bool:
        return self._running_stage == RunningStage.SANITY_CHECKING

    @sanity_checking.setter
    def sanity_checking(self, val: bool) -> None:
        if val:
            self._running_stage = RunningStage.SANITY_CHECKING
        elif self.sanity_checking:
            self._running_stage = None
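
    # Stage-flag semantics, sketched: the boolean stage properties above are all
    # views on the single ``_running_stage`` field, so activating one stage
    # implicitly deactivates the others (hypothetical sequence):
    #
    #     trainer.training = True    # _running_stage == RunningStage.TRAINING
    #     trainer.validating = True  # _running_stage == RunningStage.VALIDATING
    #     trainer.training           # -> False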

    @property
    def _setup_state(self) -> TrainerState:
        # 'fit' is used for `trainer.tune()` as well, since there are no dedicated "tune" dataloaders
        return TrainerState.FITTING if self.state == TrainerState.TUNING else self.state

    @property
    def _teardown_state(self) -> Optional[TrainerState]:
        if self.state.running:
            return self._setup_state


# Used to represent the concrete type TrainerProperties class methods are called on.
_T = TypeVar('_T', bound=TrainerProperties)