From 435e479bbd4f59021682516b735d2410676f9ba9 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Wed, 9 Sep 2020 20:03:18 -0400
Subject: [PATCH] ref: separate properties (#3432)

* ref: separate properties

* ref: separate properties

* ref: separate properties

* ref: separate properties
---
 pytorch_lightning/trainer/model_connector.py |   5 +
 pytorch_lightning/trainer/properties.py      | 152 +++++++++++++++++++
 pytorch_lightning/trainer/trainer.py         | 125 +--------------
 3 files changed, 159 insertions(+), 123 deletions(-)
 create mode 100644 pytorch_lightning/trainer/properties.py

diff --git a/pytorch_lightning/trainer/model_connector.py b/pytorch_lightning/trainer/model_connector.py
index d57a40cf4c..ca7fce6e2d 100644
--- a/pytorch_lightning/trainer/model_connector.py
+++ b/pytorch_lightning/trainer/model_connector.py
@@ -50,3 +50,8 @@ class ModelConnector:
             m.precision = self.trainer.precision
             m.global_rank = self.trainer.global_rank
             m.local_rank = self.trainer.local_rank
+
+    def get_model(self):
+        is_dp_module = isinstance(self.trainer.model, (LightningDistributedDataParallel, LightningDataParallel))
+        model = self.trainer.model.module if is_dp_module else self.trainer.model
+        return model
diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py
new file mode 100644
index 0000000000..0b0d429e7b
--- /dev/null
+++ b/pytorch_lightning/trainer/properties.py
@@ -0,0 +1,152 @@
+from pytorch_lightning.utilities.cloud_io import get_filesystem
+from pytorch_lightning.trainer.logger_connector import LoggerConnector
+from pytorch_lightning.trainer.states import TrainerState
+from typing import List, Optional, Union
+from pytorch_lightning.utilities import argparse_utils
+from argparse import ArgumentParser, Namespace
+from abc import ABC
+import inspect
+import os
+from pytorch_lightning.utilities.model_utils import is_overridden
+from pytorch_lightning.core.lightning import LightningModule
+from pytorch_lightning.callbacks import ProgressBarBase
+from pytorch_lightning.trainer.model_connector import ModelConnector
+
+
+class TrainerProperties(ABC):
+
+    precision: int
+    logger_connector: LoggerConnector
+    _state: TrainerState
+    global_rank: int
+    fast_dev_run: bool
+    use_dp: bool
+    use_ddp: bool
+    use_ddp2: bool
+    model: LightningModule
+    data_parallel_device_ids: Optional[List[int]]
+    _progress_bar_callback: ProgressBarBase
+    limit_val_batches: int
+    _default_root_dir: str
+    _weights_save_path: str
+    model_connector: ModelConnector
+
+    @property
+    def use_amp(self) -> bool:
+        return self.precision == 16
+
+    @property
+    def callback_metrics(self):
+        return self.logger_connector.callback_metrics
+
+    @callback_metrics.setter
+    def callback_metrics(self, x):
+        self.logger_connector.callback_metrics = x
+
+    @property
+    def state(self) -> TrainerState:
+        return self._state
+
+    @property
+    def is_global_zero(self) -> bool:
+        return self.global_rank == 0
+
+    @property
+    def slurm_job_id(self) -> Optional[int]:
+        try:
+            job_id = os.environ['SLURM_JOB_ID']
+            job_id = int(job_id)
+
+            # in interactive mode, don't make logs use the same job id
+            in_slurm_interactive_mode = os.environ['SLURM_JOB_NAME'] == 'bash'
+            if in_slurm_interactive_mode:
+                job_id = None
+
+        except Exception:
+            job_id = None
+        return job_id
+
+    @classmethod
+    def default_attributes(cls):
+        init_signature = inspect.signature(cls)
+
+        args = {}
+        for param_name in init_signature.parameters:
+            value = init_signature.parameters[param_name].default
+            args[param_name] = value
+
+        return args
+
+    @classmethod
+    def get_deprecated_arg_names(cls) -> List:
+        """Returns a list with deprecated Trainer arguments."""
+        depr_arg_names = []
+        for name, val in cls.__dict__.items():
+            if name.startswith('DEPRECATED') and isinstance(val, (tuple, list)):
+                depr_arg_names.extend(val)
+        return depr_arg_names
+
+    @classmethod
+    def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
+        return argparse_utils.from_argparse_args(cls, args, **kwargs)
+
+    @classmethod
+    def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
+        return argparse_utils.parse_argparser(cls, arg_parser)
+
+    @classmethod
+    def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
+        return argparse_utils.add_argparse_args(cls, parent_parser)
+
+    @property
+    def num_gpus(self) -> int:
+        gpus = self.data_parallel_device_ids
+        if gpus is None:
+            return 0
+        return len(gpus)
+
+    @property
+    def data_parallel(self) -> bool:
+        return self.use_dp or self.use_ddp or self.use_ddp2
+
+    @property
+    def progress_bar_callback(self):
+        return self._progress_bar_callback
+
+    @property
+    def progress_bar_dict(self) -> dict:
+        """ Read-only for progress bar metrics. """
+        ref_model = self.model if not self.data_parallel else self.model.module
+        return dict(**ref_model.get_progress_bar_dict(), **self.logger_connector.progress_bar_metrics)
+
+    @property
+    def disable_validation(self) -> bool:
+        """ Check if validation is disabled during training. """
+        return not self.enable_validation
+
+    @property
+    def enable_validation(self) -> bool:
+        """ Check if we should run validation during training. """
+        model_ref = self.model_connector.get_model()
+        val_loop_enabled = is_overridden('validation_step', model_ref) and self.limit_val_batches > 0
+        return val_loop_enabled or self.fast_dev_run
+
+    @property
+    def default_root_dir(self) -> str:
+        """
+        The default location to save artifacts of loggers, checkpoints etc.
+        It is used as a fallback if logger or checkpoint callback do not define specific save paths.
+        """
+        if get_filesystem(self._default_root_dir).protocol == "file":
+            return os.path.normpath(self._default_root_dir)
+        return self._default_root_dir
+
+    @property
+    def weights_save_path(self) -> str:
+        """
+        The default root location to save weights (checkpoints), e.g., when the
+        :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` does not define a file path.
+        """
+        if get_filesystem(self._weights_save_path).protocol == "file":
+            return os.path.normpath(self._weights_save_path)
+        return self._weights_save_path
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 12f2064d6c..ff42fae26e 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -12,10 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import inspect
 import os
 import warnings
-from argparse import ArgumentParser, Namespace
 from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
@@ -47,12 +45,10 @@ from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
 from pytorch_lightning.utilities import parsing, rank_zero_info, rank_zero_only, rank_zero_warn, AMPType
 from pytorch_lightning.utilities.debugging import InternalDebugger
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.trainer.evaluation_loop import EvaluationLoop
 from pytorch_lightning.trainer.training_loop import TrainLoop
 from pytorch_lightning.trainer.data_connector import DataConnector
 from pytorch_lightning.accelerators.accelerator_connector import AcceleratorConnector
-from pytorch_lightning.utilities import argparse_utils
 from pytorch_lightning.trainer.logger_connector import LoggerConnector
 from pytorch_lightning.trainer.lr_scheduler_connector import LRSchedulerConnector
 from pytorch_lightning.trainer.model_connector import ModelConnector
@@ -61,6 +57,7 @@ from pytorch_lightning.tuner.tuning import Tuner
 from pytorch_lightning.trainer.initializer import Initializer
 from pytorch_lightning.utilities.model_utils import is_overridden
 from pytorch_lightning.trainer import docstrings
+from pytorch_lightning.trainer.properties import TrainerProperties
 
 # warnings to ignore in trainer
 warnings.filterwarnings(
@@ -90,6 +87,7 @@ else:
 
 
 class Trainer(
+    TrainerProperties,
     TrainerIOMixin,
     TrainerCallbackHookMixin,
     TrainerModelHooksMixin,
@@ -423,125 +421,6 @@ class Trainer(
         # Callback system
         self.on_init_end()
 
-    @property
-    def use_amp(self) -> bool:
-        return self.precision == 16
-
-    @property
-    def callback_metrics(self):
-        return self.logger_connector.callback_metrics
-
-    @callback_metrics.setter
-    def callback_metrics(self, x):
-        self.logger_connector.callback_metrics = x
-
-    @property
-    def state(self) -> TrainerState:
-        return self._state
-
-    @property
-    def is_global_zero(self) -> bool:
-        return self.global_rank == 0
-
-    @property
-    def slurm_job_id(self) -> Optional[int]:
-        try:
-            job_id = os.environ['SLURM_JOB_ID']
-            job_id = int(job_id)
-
-            # in interactive mode, don't make logs use the same job id
-            in_slurm_interactive_mode = os.environ['SLURM_JOB_NAME'] == 'bash'
-            if in_slurm_interactive_mode:
-                job_id = None
-
-        except Exception:
-            job_id = None
-        return job_id
-
-    @classmethod
-    def default_attributes(cls):
-        init_signature = inspect.signature(Trainer)
-
-        args = {}
-        for param_name in init_signature.parameters:
-            value = init_signature.parameters[param_name].default
-            args[param_name] = value
-
-        return args
-
-    @classmethod
-    def get_deprecated_arg_names(cls) -> List:
-        """Returns a list with deprecated Trainer arguments."""
-        depr_arg_names = []
-        for name, val in cls.__dict__.items():
-            if name.startswith('DEPRECATED') and isinstance(val, (tuple, list)):
-                depr_arg_names.extend(val)
-        return depr_arg_names
-
-    @classmethod
-    def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs) -> 'Trainer':
-        return argparse_utils.from_argparse_args(cls, args, **kwargs)
-
-    @classmethod
-    def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
-        return argparse_utils.parse_argparser(cls, arg_parser)
-
-    @classmethod
-    def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
-        return argparse_utils.add_argparse_args(cls, parent_parser)
-
-    @property
-    def num_gpus(self) -> int:
-        gpus = self.data_parallel_device_ids
-        if gpus is None:
-            return 0
-        return len(gpus)
-
-    @property
-    def data_parallel(self) -> bool:
-        return self.use_dp or self.use_ddp or self.use_ddp2
-
-    @property
-    def progress_bar_callback(self):
-        return self._progress_bar_callback
-
-    @property
-    def progress_bar_dict(self) -> dict:
-        """ Read-only for progress bar metrics. """
-        ref_model = self.model if not self.data_parallel else self.model.module
-        return dict(**ref_model.get_progress_bar_dict(), **self.logger_connector.progress_bar_metrics)
-
-    @property
-    def disable_validation(self) -> bool:
-        """ Check if validation is disabled during training. """
-        return not self.enable_validation
-
-    @property
-    def enable_validation(self) -> bool:
-        """ Check if we should run validation during training. """
-        val_loop_enabled = is_overridden('validation_step', self.get_model()) and self.limit_val_batches > 0
-        return val_loop_enabled or self.fast_dev_run
-
-    @property
-    def default_root_dir(self) -> str:
-        """
-        The default location to save artifacts of loggers, checkpoints etc.
-        It is used as a fallback if logger or checkpoint callback do not define specific save paths.
-        """
-        if get_filesystem(self._default_root_dir).protocol == "file":
-            return os.path.normpath(self._default_root_dir)
-        return self._default_root_dir
-
-    @property
-    def weights_save_path(self) -> str:
-        """
-        The default root location to save weights (checkpoints), e.g., when the
-        :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` does not define a file path.
-        """
-        if get_filesystem(self._weights_save_path).protocol == "file":
-            return os.path.normpath(self._weights_save_path)
-        return self._weights_save_path
-
     def tune(
         self,
         model: LightningModule,