From 6dba26666aa564db414eb238d99a4213006d8220 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 16 Feb 2021 00:54:12 +0100 Subject: [PATCH] add missing typing to trainer properties (#5974) * add typing * clean up * isort * fix typing in log_dir --- pytorch_lightning/trainer/properties.py | 111 ++++++++++++------------ 1 file changed, 55 insertions(+), 56 deletions(-) diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 0a678f07e0..1f0cc52870 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -15,16 +15,21 @@ import inspect import os from abc import ABC from argparse import ArgumentParser, Namespace -from typing import Any, cast, List, Optional, Type, TypeVar, Union +from typing import cast, List, Optional, Type, TypeVar, Union import torch +from torch.optim import Optimizer from pytorch_lightning.accelerators import Accelerator from pytorch_lightning.accelerators.accelerator_connector import BackendConnector from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, ProgressBarBase +from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.core.optimizer import LightningOptimizer +from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.loggers.tensorboard import TensorBoardLogger -from pytorch_lightning.plugins import ParallelPlugin +from pytorch_lightning.plugins import ParallelPlugin, PrecisionPlugin, TrainingTypePlugin +from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities import DeviceType, DistributedType, rank_zero_warn @@ -40,84 +45,78 @@ from pytorch_lightning.utilities.model_helpers import is_overridden class TrainerProperties(ABC): - precision: int - logger_connector: LoggerConnector - _state: TrainerState - global_rank: int - fast_dev_run: Union[int, bool] - _device_type: DeviceType - _distrib_type: DistributedType - model: LightningModule - data_parallel_device_ids: Optional[List[int]] - _progress_bar_callback: ProgressBarBase - limit_val_batches: int _default_root_dir: str - _weights_save_path: str - accelerator_backend: Accelerator - num_nodes: int - num_processes: int - accelerator_connector: BackendConnector _lightning_optimizers = None + _progress_bar_callback: ProgressBarBase + _state: TrainerState + _weights_save_path: str + + accelerator_connector: BackendConnector + callbacks: List[Callback] + checkpoint_connector: CheckpointConnector + limit_val_batches: int + logger: LightningLoggerBase + logger_connector: LoggerConnector @property - def accelerator(self): + def accelerator(self) -> Accelerator: return self.accelerator_connector.accelerator @property - def accelerator_backend(self): + def accelerator_backend(self) -> Accelerator: # for backward compatibility return self.accelerator @property - def distributed_backend(self): + def distributed_backend(self) -> Optional[str]: # for backward compatibility return self.accelerator_connector.distributed_backend @property - def training_type_plugin(self): + def training_type_plugin(self) -> TrainingTypePlugin: return self.accelerator.training_type_plugin @property - def precision_plugin(self): + def precision_plugin(self) -> PrecisionPlugin: return self.accelerator.precision_plugin @property - def global_rank(self): + def global_rank(self) -> int: return self.accelerator.training_type_plugin.global_rank @property - def local_rank(self): + def local_rank(self) -> int: # some training types define a local rank return getattr(self.accelerator.training_type_plugin, "local_rank", 0) @property - def node_rank(self): + def node_rank(self) -> int: # some training types define a local rank return getattr(self.accelerator.training_type_plugin, "node_rank", 0) @property - def world_size(self): + def world_size(self) -> int: # some training types define a world size return getattr(self.accelerator.training_type_plugin, "world_size", 1) @property - def _distrib_type(self): + def _distrib_type(self) -> DistributedType: return self.accelerator_connector._distrib_type @property - def _device_type(self): + def _device_type(self) -> DeviceType: return self.accelerator_connector._device_type @property - def num_nodes(self): + def num_nodes(self) -> int: return self.accelerator_connector.num_nodes @property - def num_processes(self): + def num_processes(self) -> int: return self.accelerator_connector.num_processes @property - def root_gpu(self): + def root_gpu(self) -> Optional[int]: return self.accelerator_connector.root_gpu @property @@ -129,11 +128,11 @@ class TrainerProperties(ABC): return self.accelerator_connector.num_gpus @property - def data_parallel_device_ids(self): + def data_parallel_device_ids(self) -> Optional[List[int]]: return self.accelerator_connector.parallel_device_ids @property - def log_dir(self): + def log_dir(self) -> Optional[str]: if self.logger is None: dirpath = self.default_root_dir else: @@ -147,27 +146,27 @@ class TrainerProperties(ABC): return self.precision == 16 @property - def callback_metrics(self): + def callback_metrics(self) -> dict: return self.logger_connector.callback_metrics @callback_metrics.setter - def callback_metrics(self, x): + def callback_metrics(self, x: dict) -> None: self.logger_connector.callback_metrics = x @property - def logged_metrics(self): + def logged_metrics(self) -> dict: return self.logger_connector.logged_metrics @logged_metrics.setter - def logged_metrics(self, x): + def logged_metrics(self, x: dict) -> None: self.logger_connector.logged_metrics = x @property - def progress_bar_metrics(self): + def progress_bar_metrics(self) -> dict: return self.logger_connector.progress_bar_metrics @progress_bar_metrics.setter - def progress_bar_metrics(self, x): + def progress_bar_metrics(self, x: dict) -> None: self.logger_connector.progress_bar_metrics = x @property @@ -194,7 +193,7 @@ class TrainerProperties(ABC): return job_id @classmethod - def default_attributes(cls): + def default_attributes(cls) -> dict: init_signature = inspect.signature(cls) args = {} @@ -240,7 +239,7 @@ class TrainerProperties(ABC): ) @property - def progress_bar_callback(self): + def progress_bar_callback(self) -> Optional[ProgressBarBase]: return self._progress_bar_callback @property @@ -329,11 +328,11 @@ class TrainerProperties(ABC): """ return [c for c in self.callbacks if isinstance(c, ModelCheckpoint)] - def save_checkpoint(self, filepath, weights_only: bool = False): + def save_checkpoint(self, filepath, weights_only: bool = False) -> None: self.checkpoint_connector.save_checkpoint(filepath, weights_only) @property - def model(self) -> Any: + def model(self) -> torch.nn.Module: """ The LightningModule, but possibly wrapped into DataParallel or DistributedDataParallel. To access the pure LightningModule, use @@ -342,7 +341,7 @@ class TrainerProperties(ABC): return self.accelerator.model @model.setter - def model(self, model: torch.nn.Module): + def model(self, model: torch.nn.Module) -> None: """ Setter for the model, pass-through to accelerator and plugin where the model reference is stored. Used by the Tuner to reset the state of Trainer and Accelerator. @@ -353,51 +352,51 @@ class TrainerProperties(ABC): """ self.accelerator.model = model - def get_model(self): + def get_model(self) -> LightningModule: # TODO: rename this to lightning_module (see training type plugin) # backward compatible return self.lightning_module @property - def lightning_optimizers(self): + def lightning_optimizers(self) -> List[LightningOptimizer]: if self._lightning_optimizers is None: self.convert_to_lightning_optimizers() return self._lightning_optimizers @property - def lightning_module(self): + def lightning_module(self) -> LightningModule: return self.training_type_plugin.lightning_module @property - def optimizers(self): + def optimizers(self) -> Optional[List[Optimizer]]: return self.accelerator.optimizers @optimizers.setter - def optimizers(self, new_optims): + def optimizers(self, new_optims: Optional[List[Optimizer]]) -> None: self.accelerator.optimizers = new_optims @property - def lr_schedulers(self): + def lr_schedulers(self) -> Optional[list]: return self.accelerator.lr_schedulers @lr_schedulers.setter - def lr_schedulers(self, new_schedulers): + def lr_schedulers(self, new_schedulers: Optional[list]) -> None: self.accelerator.lr_schedulers = new_schedulers @property - def optimizer_frequencies(self): + def optimizer_frequencies(self) -> list: return self.accelerator.optimizer_frequencies @optimizer_frequencies.setter - def optimizer_frequencies(self, new_freqs): + def optimizer_frequencies(self, new_freqs: list) -> None: self.accelerator.optimizer_frequencies = new_freqs @property - def amp_backend(self): + def amp_backend(self) -> Optional[str]: return self.accelerator.amp_backend @property - def precision(self): + def precision(self) -> Union[str, int]: return self.accelerator.precision @property