add missing typing to trainer properties (#5974)

* add typing

* clean up

* isort

* fix typing in log_dir
This commit is contained in:
Adrian Wälchli 2021-02-16 00:54:12 +01:00 committed by GitHub
parent aa60c08641
commit 6dba26666a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 55 additions and 56 deletions

View File

@ -15,16 +15,21 @@ import inspect
import os
from abc import ABC
from argparse import ArgumentParser, Namespace
from typing import Any, cast, List, Optional, Type, TypeVar, Union
from typing import cast, List, Optional, Type, TypeVar, Union
import torch
from torch.optim import Optimizer
from pytorch_lightning.accelerators import Accelerator
from pytorch_lightning.accelerators.accelerator_connector import BackendConnector
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, ProgressBarBase
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning.plugins import ParallelPlugin
from pytorch_lightning.plugins import ParallelPlugin, PrecisionPlugin, TrainingTypePlugin
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import DeviceType, DistributedType, rank_zero_warn
@ -40,84 +45,78 @@ from pytorch_lightning.utilities.model_helpers import is_overridden
class TrainerProperties(ABC):
precision: int
logger_connector: LoggerConnector
_state: TrainerState
global_rank: int
fast_dev_run: Union[int, bool]
_device_type: DeviceType
_distrib_type: DistributedType
model: LightningModule
data_parallel_device_ids: Optional[List[int]]
_progress_bar_callback: ProgressBarBase
limit_val_batches: int
_default_root_dir: str
_weights_save_path: str
accelerator_backend: Accelerator
num_nodes: int
num_processes: int
accelerator_connector: BackendConnector
_lightning_optimizers = None
_progress_bar_callback: ProgressBarBase
_state: TrainerState
_weights_save_path: str
accelerator_connector: BackendConnector
callbacks: List[Callback]
checkpoint_connector: CheckpointConnector
limit_val_batches: int
logger: LightningLoggerBase
logger_connector: LoggerConnector
@property
def accelerator(self):
def accelerator(self) -> Accelerator:
return self.accelerator_connector.accelerator
@property
def accelerator_backend(self):
def accelerator_backend(self) -> Accelerator:
# for backward compatibility
return self.accelerator
@property
def distributed_backend(self):
def distributed_backend(self) -> Optional[str]:
# for backward compatibility
return self.accelerator_connector.distributed_backend
@property
def training_type_plugin(self):
def training_type_plugin(self) -> TrainingTypePlugin:
return self.accelerator.training_type_plugin
@property
def precision_plugin(self):
def precision_plugin(self) -> PrecisionPlugin:
return self.accelerator.precision_plugin
@property
def global_rank(self):
def global_rank(self) -> int:
return self.accelerator.training_type_plugin.global_rank
@property
def local_rank(self):
def local_rank(self) -> int:
# some training types define a local rank
return getattr(self.accelerator.training_type_plugin, "local_rank", 0)
@property
def node_rank(self):
def node_rank(self) -> int:
# some training types define a local rank
return getattr(self.accelerator.training_type_plugin, "node_rank", 0)
@property
def world_size(self):
def world_size(self) -> int:
# some training types define a world size
return getattr(self.accelerator.training_type_plugin, "world_size", 1)
@property
def _distrib_type(self):
def _distrib_type(self) -> DistributedType:
return self.accelerator_connector._distrib_type
@property
def _device_type(self):
def _device_type(self) -> DeviceType:
return self.accelerator_connector._device_type
@property
def num_nodes(self):
def num_nodes(self) -> int:
return self.accelerator_connector.num_nodes
@property
def num_processes(self):
def num_processes(self) -> int:
return self.accelerator_connector.num_processes
@property
def root_gpu(self):
def root_gpu(self) -> Optional[int]:
return self.accelerator_connector.root_gpu
@property
@ -129,11 +128,11 @@ class TrainerProperties(ABC):
return self.accelerator_connector.num_gpus
@property
def data_parallel_device_ids(self):
def data_parallel_device_ids(self) -> Optional[List[int]]:
return self.accelerator_connector.parallel_device_ids
@property
def log_dir(self):
def log_dir(self) -> Optional[str]:
if self.logger is None:
dirpath = self.default_root_dir
else:
@ -147,27 +146,27 @@ class TrainerProperties(ABC):
return self.precision == 16
@property
def callback_metrics(self):
def callback_metrics(self) -> dict:
return self.logger_connector.callback_metrics
@callback_metrics.setter
def callback_metrics(self, x):
def callback_metrics(self, x: dict) -> None:
self.logger_connector.callback_metrics = x
@property
def logged_metrics(self):
def logged_metrics(self) -> dict:
return self.logger_connector.logged_metrics
@logged_metrics.setter
def logged_metrics(self, x):
def logged_metrics(self, x: dict) -> None:
self.logger_connector.logged_metrics = x
@property
def progress_bar_metrics(self):
def progress_bar_metrics(self) -> dict:
return self.logger_connector.progress_bar_metrics
@progress_bar_metrics.setter
def progress_bar_metrics(self, x):
def progress_bar_metrics(self, x: dict) -> None:
self.logger_connector.progress_bar_metrics = x
@property
@ -194,7 +193,7 @@ class TrainerProperties(ABC):
return job_id
@classmethod
def default_attributes(cls):
def default_attributes(cls) -> dict:
init_signature = inspect.signature(cls)
args = {}
@ -240,7 +239,7 @@ class TrainerProperties(ABC):
)
@property
def progress_bar_callback(self):
def progress_bar_callback(self) -> Optional[ProgressBarBase]:
return self._progress_bar_callback
@property
@ -329,11 +328,11 @@ class TrainerProperties(ABC):
"""
return [c for c in self.callbacks if isinstance(c, ModelCheckpoint)]
def save_checkpoint(self, filepath, weights_only: bool = False):
def save_checkpoint(self, filepath, weights_only: bool = False) -> None:
self.checkpoint_connector.save_checkpoint(filepath, weights_only)
@property
def model(self) -> Any:
def model(self) -> torch.nn.Module:
"""
The LightningModule, but possibly wrapped into DataParallel or DistributedDataParallel.
To access the pure LightningModule, use
@ -342,7 +341,7 @@ class TrainerProperties(ABC):
return self.accelerator.model
@model.setter
def model(self, model: torch.nn.Module):
def model(self, model: torch.nn.Module) -> None:
"""
Setter for the model, pass-through to accelerator and plugin where the model reference is stored.
Used by the Tuner to reset the state of Trainer and Accelerator.
@ -353,51 +352,51 @@ class TrainerProperties(ABC):
"""
self.accelerator.model = model
def get_model(self):
def get_model(self) -> LightningModule:
# TODO: rename this to lightning_module (see training type plugin)
# backward compatible
return self.lightning_module
@property
def lightning_optimizers(self):
def lightning_optimizers(self) -> List[LightningOptimizer]:
if self._lightning_optimizers is None:
self.convert_to_lightning_optimizers()
return self._lightning_optimizers
@property
def lightning_module(self):
def lightning_module(self) -> LightningModule:
return self.training_type_plugin.lightning_module
@property
def optimizers(self):
def optimizers(self) -> Optional[List[Optimizer]]:
return self.accelerator.optimizers
@optimizers.setter
def optimizers(self, new_optims):
def optimizers(self, new_optims: Optional[List[Optimizer]]) -> None:
self.accelerator.optimizers = new_optims
@property
def lr_schedulers(self):
def lr_schedulers(self) -> Optional[list]:
return self.accelerator.lr_schedulers
@lr_schedulers.setter
def lr_schedulers(self, new_schedulers):
def lr_schedulers(self, new_schedulers: Optional[list]) -> None:
self.accelerator.lr_schedulers = new_schedulers
@property
def optimizer_frequencies(self):
def optimizer_frequencies(self) -> list:
return self.accelerator.optimizer_frequencies
@optimizer_frequencies.setter
def optimizer_frequencies(self, new_freqs):
def optimizer_frequencies(self, new_freqs: list) -> None:
self.accelerator.optimizer_frequencies = new_freqs
@property
def amp_backend(self):
def amp_backend(self) -> Optional[str]:
return self.accelerator.amp_backend
@property
def precision(self):
def precision(self) -> Union[str, int]:
return self.accelerator.precision
@property