add missing typing to trainer properties (#5974)
* add typing
* clean up
* isort
* fix typing in log_dir
parent aa60c08641
commit 6dba26666a
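The pattern applied throughout the diff below: Trainer properties that merely delegate to an accelerator or connector object gain explicit return annotations. A minimal, runnable sketch of that shape (the class and attribute names here are illustrative stand-ins, not Lightning's API):

    from typing import Optional

    class _Connector:
        """Illustrative stand-in for the connector objects a trainer delegates to."""

        def __init__(self) -> None:
            self.num_nodes = 1
            self.root_gpu: Optional[int] = None  # None when no GPU is used

    class _Trainer:
        def __init__(self) -> None:
            self.connector = _Connector()

        @property
        def num_nodes(self) -> int:
            # The annotation documents the type and lets a checker
            # verify every call site that reads the property.
            return self.connector.num_nodes

        @property
        def root_gpu(self) -> Optional[int]:
            # Optional[...] where the delegate may legitimately hold None.
            return self.connector.root_gpu

    trainer = _Trainer()
    assert trainer.num_nodes == 1 and trainer.root_gpu is None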
@@ -15,16 +15,21 @@ import inspect
 import os
 from abc import ABC
 from argparse import ArgumentParser, Namespace
-from typing import Any, cast, List, Optional, Type, TypeVar, Union
+from typing import cast, List, Optional, Type, TypeVar, Union
 
+import torch
 from torch.optim import Optimizer
 
 from pytorch_lightning.accelerators import Accelerator
+from pytorch_lightning.accelerators.accelerator_connector import BackendConnector
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, ProgressBarBase
+from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.core.lightning import LightningModule
+from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.loggers import LightningLoggerBase
 from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
-from pytorch_lightning.plugins import ParallelPlugin
+from pytorch_lightning.plugins import ParallelPlugin, PrecisionPlugin, TrainingTypePlugin
+from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
 from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
 from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.utilities import DeviceType, DistributedType, rank_zero_warn
@@ -40,84 +45,78 @@ from pytorch_lightning.utilities.model_helpers import is_overridden
 
 class TrainerProperties(ABC):
 
-    precision: int
-    logger_connector: LoggerConnector
-    _state: TrainerState
-    global_rank: int
-    fast_dev_run: Union[int, bool]
-    _device_type: DeviceType
-    _distrib_type: DistributedType
-    model: LightningModule
-    data_parallel_device_ids: Optional[List[int]]
-    _progress_bar_callback: ProgressBarBase
-    limit_val_batches: int
     _default_root_dir: str
-    _weights_save_path: str
-    accelerator_backend: Accelerator
-    num_nodes: int
-    num_processes: int
-    accelerator_connector: BackendConnector
     _lightning_optimizers = None
+    _progress_bar_callback: ProgressBarBase
+    _state: TrainerState
+    _weights_save_path: str
+
+    accelerator_connector: BackendConnector
+    callbacks: List[Callback]
+    checkpoint_connector: CheckpointConnector
+    limit_val_batches: int
+    logger: LightningLoggerBase
+    logger_connector: LoggerConnector
 
     @property
-    def accelerator(self):
+    def accelerator(self) -> Accelerator:
         return self.accelerator_connector.accelerator
 
     @property
-    def accelerator_backend(self):
+    def accelerator_backend(self) -> Accelerator:
         # for backward compatibility
         return self.accelerator
 
     @property
-    def distributed_backend(self):
+    def distributed_backend(self) -> Optional[str]:
         # for backward compatibility
         return self.accelerator_connector.distributed_backend
 
     @property
-    def training_type_plugin(self):
+    def training_type_plugin(self) -> TrainingTypePlugin:
         return self.accelerator.training_type_plugin
 
     @property
-    def precision_plugin(self):
+    def precision_plugin(self) -> PrecisionPlugin:
         return self.accelerator.precision_plugin
 
     @property
-    def global_rank(self):
+    def global_rank(self) -> int:
         return self.accelerator.training_type_plugin.global_rank
 
     @property
-    def local_rank(self):
+    def local_rank(self) -> int:
         # some training types define a local rank
         return getattr(self.accelerator.training_type_plugin, "local_rank", 0)
 
     @property
-    def node_rank(self):
+    def node_rank(self) -> int:
         # some training types define a node rank
         return getattr(self.accelerator.training_type_plugin, "node_rank", 0)
 
     @property
-    def world_size(self):
+    def world_size(self) -> int:
         # some training types define a world size
         return getattr(self.accelerator.training_type_plugin, "world_size", 1)
 
     @property
-    def _distrib_type(self):
+    def _distrib_type(self) -> DistributedType:
         return self.accelerator_connector._distrib_type
 
     @property
-    def _device_type(self):
+    def _device_type(self) -> DeviceType:
         return self.accelerator_connector._device_type
 
     @property
-    def num_nodes(self):
+    def num_nodes(self) -> int:
         return self.accelerator_connector.num_nodes
 
     @property
-    def num_processes(self):
+    def num_processes(self) -> int:
         return self.accelerator_connector.num_processes
 
     @property
-    def root_gpu(self):
+    def root_gpu(self) -> Optional[int]:
         return self.accelerator_connector.root_gpu
 
     @property
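Note how the rank and world-size properties above read their value with getattr(..., default): training-type plugins that never define a rank still satisfy the annotated int return. A small self-contained sketch of that idiom, with hypothetical plugin classes:

    class _SingleDevicePlugin:
        """Hypothetical plugin that defines no rank attributes."""

    class _DDPPlugin:
        """Hypothetical distributed plugin with explicit ranks."""
        local_rank = 2
        world_size = 4

    def local_rank_of(plugin: object) -> int:
        # getattr with an int default keeps the annotation truthful even
        # for plugins that never define the attribute.
        return getattr(plugin, "local_rank", 0)

    assert local_rank_of(_SingleDevicePlugin()) == 0
    assert local_rank_of(_DDPPlugin()) == 2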
@@ -129,11 +128,11 @@ class TrainerProperties(ABC):
         return self.accelerator_connector.num_gpus
 
     @property
-    def data_parallel_device_ids(self):
+    def data_parallel_device_ids(self) -> Optional[List[int]]:
         return self.accelerator_connector.parallel_device_ids
 
     @property
-    def log_dir(self):
+    def log_dir(self) -> Optional[str]:
         if self.logger is None:
             dirpath = self.default_root_dir
         else:
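The "fix typing in log_dir" bullet from the commit message lands here: the property is annotated Optional[str] because the resolved directory can come from a logger, and a logger may not provide one. A simplified sketch of why Optional is the honest annotation; the fallback logic below is an assumption for illustration, not the real implementation:

    from typing import Optional

    class _Logger:
        """Hypothetical logger whose save directory may be unset."""

        def __init__(self, save_dir: Optional[str] = None) -> None:
            self.save_dir = save_dir

    def log_dir(logger: Optional[_Logger], default_root_dir: str) -> Optional[str]:
        # Without a logger, fall back to the trainer default; with one,
        # defer to its save_dir, which itself may be None.
        if logger is None:
            return default_root_dir
        return logger.save_dir

    assert log_dir(None, "/tmp/runs") == "/tmp/runs"
    assert log_dir(_Logger(), "/tmp/runs") is None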
@@ -147,27 +146,27 @@ class TrainerProperties(ABC):
         return self.precision == 16
 
     @property
-    def callback_metrics(self):
+    def callback_metrics(self) -> dict:
         return self.logger_connector.callback_metrics
 
     @callback_metrics.setter
-    def callback_metrics(self, x):
+    def callback_metrics(self, x: dict) -> None:
         self.logger_connector.callback_metrics = x
 
     @property
-    def logged_metrics(self):
+    def logged_metrics(self) -> dict:
         return self.logger_connector.logged_metrics
 
     @logged_metrics.setter
-    def logged_metrics(self, x):
+    def logged_metrics(self, x: dict) -> None:
         self.logger_connector.logged_metrics = x
 
     @property
-    def progress_bar_metrics(self):
+    def progress_bar_metrics(self) -> dict:
         return self.logger_connector.progress_bar_metrics
 
     @progress_bar_metrics.setter
-    def progress_bar_metrics(self, x):
+    def progress_bar_metrics(self, x: dict) -> None:
         self.logger_connector.progress_bar_metrics = x
 
     @property
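Each metrics property above is a getter/setter pair, and the typing touches both halves: the getter's return type, and the setter's parameter plus its explicit None return. The same shape in a self-contained sketch:

    class _MetricsProxy:
        """Illustrative holder mirroring the getter/setter pattern in the diff."""

        def __init__(self) -> None:
            self._callback_metrics: dict = {}

        @property
        def callback_metrics(self) -> dict:
            return self._callback_metrics

        @callback_metrics.setter
        def callback_metrics(self, x: dict) -> None:
            # Property setters always return None; annotating the parameter
            # lets a type checker flag assignments of the wrong type.
            self._callback_metrics = x

    proxy = _MetricsProxy()
    proxy.callback_metrics = {"loss": 0.25}
    assert proxy.callback_metrics["loss"] == 0.25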
@@ -194,7 +193,7 @@ class TrainerProperties(ABC):
         return job_id
 
     @classmethod
-    def default_attributes(cls):
+    def default_attributes(cls) -> dict:
         init_signature = inspect.signature(cls)
 
         args = {}
@@ -240,7 +239,7 @@ class TrainerProperties(ABC):
         )
 
     @property
-    def progress_bar_callback(self):
+    def progress_bar_callback(self) -> Optional[ProgressBarBase]:
         return self._progress_bar_callback
 
     @property
@@ -329,11 +328,11 @@ class TrainerProperties(ABC):
         """
         return [c for c in self.callbacks if isinstance(c, ModelCheckpoint)]
 
-    def save_checkpoint(self, filepath, weights_only: bool = False):
+    def save_checkpoint(self, filepath, weights_only: bool = False) -> None:
         self.checkpoint_connector.save_checkpoint(filepath, weights_only)
 
     @property
-    def model(self) -> Any:
+    def model(self) -> torch.nn.Module:
         """
         The LightningModule, but possibly wrapped into DataParallel or DistributedDataParallel.
         To access the pure LightningModule, use
@@ -342,7 +341,7 @@ class TrainerProperties(ABC):
         return self.accelerator.model
 
     @model.setter
-    def model(self, model: torch.nn.Module):
+    def model(self, model: torch.nn.Module) -> None:
         """
         Setter for the model, pass-through to accelerator and plugin where the model reference is stored.
         Used by the Tuner to reset the state of Trainer and Accelerator.
@@ -353,51 +352,51 @@ class TrainerProperties(ABC):
         """
         self.accelerator.model = model
 
-    def get_model(self):
+    def get_model(self) -> LightningModule:
         # TODO: rename this to lightning_module (see training type plugin)
         # backward compatible
         return self.lightning_module
 
     @property
-    def lightning_optimizers(self):
+    def lightning_optimizers(self) -> List[LightningOptimizer]:
         if self._lightning_optimizers is None:
             self.convert_to_lightning_optimizers()
         return self._lightning_optimizers
 
     @property
-    def lightning_module(self):
+    def lightning_module(self) -> LightningModule:
         return self.training_type_plugin.lightning_module
 
     @property
-    def optimizers(self):
+    def optimizers(self) -> Optional[List[Optimizer]]:
         return self.accelerator.optimizers
 
     @optimizers.setter
-    def optimizers(self, new_optims):
+    def optimizers(self, new_optims: Optional[List[Optimizer]]) -> None:
         self.accelerator.optimizers = new_optims
 
     @property
-    def lr_schedulers(self):
+    def lr_schedulers(self) -> Optional[list]:
         return self.accelerator.lr_schedulers
 
     @lr_schedulers.setter
-    def lr_schedulers(self, new_schedulers):
+    def lr_schedulers(self, new_schedulers: Optional[list]) -> None:
         self.accelerator.lr_schedulers = new_schedulers
 
     @property
-    def optimizer_frequencies(self):
+    def optimizer_frequencies(self) -> list:
         return self.accelerator.optimizer_frequencies
 
     @optimizer_frequencies.setter
-    def optimizer_frequencies(self, new_freqs):
+    def optimizer_frequencies(self, new_freqs: list) -> None:
         self.accelerator.optimizer_frequencies = new_freqs
 
     @property
-    def amp_backend(self):
+    def amp_backend(self) -> Optional[str]:
         return self.accelerator.amp_backend
 
     @property
-    def precision(self):
+    def precision(self) -> Union[str, int]:
         return self.accelerator.precision
 
     @property
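The practical payoff of these annotations is static checking of code that consumes the trainer. A sketch of the kind of misuse a checker such as mypy can now flag, using a stand-in class with the same annotations rather than the real Trainer:

    from typing import Optional, Union

    class _TypedTrainer:
        """Stand-in exposing the same annotations the diff adds."""

        @property
        def precision(self) -> Union[str, int]:
            return 16

        @property
        def amp_backend(self) -> Optional[str]:
            return None

    trainer = _TypedTrainer()

    # The Optional annotation forces a None check before use.
    if trainer.amp_backend is not None:
        print(trainer.amp_backend.upper())

    # A checker would reject the line below without narrowing,
    # since a Union[str, int] value has no .upper() method:
    # trainer.precision.upper()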