ref: separate properties (#3432)

* ref: separate properties

* ref: separate properties

* ref: separate properties

* ref: separate properties
This commit is contained in:
William Falcon 2020-09-09 20:03:18 -04:00 committed by GitHub
parent 9696484153
commit 435e479bbd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 159 additions and 123 deletions

View File

@ -50,3 +50,8 @@ class ModelConnector:
m.precision = self.trainer.precision
m.global_rank = self.trainer.global_rank
m.local_rank = self.trainer.local_rank
def get_model(self):
is_dp_module = isinstance(self.trainer.model, (LightningDistributedDataParallel, LightningDataParallel))
model = self.trainer.model.module if is_dp_module else self.trainer.model
return model

View File

@ -0,0 +1,152 @@
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.trainer.logger_connector import LoggerConnector
from pytorch_lightning.trainer.states import TrainerState
from typing import List, Optional, Union
from pytorch_lightning.utilities import argparse_utils
from argparse import ArgumentParser, Namespace
from abc import ABC
import inspect
import os
from pytorch_lightning.utilities.model_utils import is_overridden
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.callbacks import ProgressBarBase
from pytorch_lightning.trainer.model_connector import ModelConnector
class TrainerProperties(ABC):
precision: int
logger_connector: LoggerConnector
_state: TrainerState
global_rank: int
fast_dev_run: bool
use_dp: bool
use_ddp: bool
use_ddp2: bool
model: LightningModule
data_parallel_device_ids: Optional[List[int]]
_progress_bar_callback: ProgressBarBase
limit_val_batches: int
_default_root_dir: str
_weights_save_path: str
model_connector: ModelConnector
@property
def use_amp(self) -> bool:
return self.precision == 16
@property
def callback_metrics(self):
return self.logger_connector.callback_metrics
@callback_metrics.setter
def callback_metrics(self, x):
self.logger_connector.callback_metrics = x
@property
def state(self) -> TrainerState:
return self._state
@property
def is_global_zero(self) -> bool:
return self.global_rank == 0
@property
def slurm_job_id(self) -> Optional[int]:
try:
job_id = os.environ['SLURM_JOB_ID']
job_id = int(job_id)
# in interactive mode, don't make logs use the same job id
in_slurm_interactive_mode = os.environ['SLURM_JOB_NAME'] == 'bash'
if in_slurm_interactive_mode:
job_id = None
except Exception:
job_id = None
return job_id
@classmethod
def default_attributes(cls):
init_signature = inspect.signature(cls)
args = {}
for param_name in init_signature.parameters:
value = init_signature.parameters[param_name].default
args[param_name] = value
return args
@classmethod
def get_deprecated_arg_names(cls) -> List:
"""Returns a list with deprecated Trainer arguments."""
depr_arg_names = []
for name, val in cls.__dict__.items():
if name.startswith('DEPRECATED') and isinstance(val, (tuple, list)):
depr_arg_names.extend(val)
return depr_arg_names
@classmethod
def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
return argparse_utils.from_argparse_args(cls, args, **kwargs)
@classmethod
def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
return argparse_utils.parse_argparser(cls, arg_parser)
@classmethod
def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
return argparse_utils.add_argparse_args(cls, parent_parser)
@property
def num_gpus(self) -> int:
gpus = self.data_parallel_device_ids
if gpus is None:
return 0
return len(gpus)
@property
def data_parallel(self) -> bool:
return self.use_dp or self.use_ddp or self.use_ddp2
@property
def progress_bar_callback(self):
return self._progress_bar_callback
@property
def progress_bar_dict(self) -> dict:
""" Read-only for progress bar metrics. """
ref_model = self.model if not self.data_parallel else self.model.module
return dict(**ref_model.get_progress_bar_dict(), **self.logger_connector.progress_bar_metrics)
@property
def disable_validation(self) -> bool:
""" Check if validation is disabled during training. """
return not self.enable_validation
@property
def enable_validation(self) -> bool:
""" Check if we should run validation during training. """
model_ref = self.model_connector.get_model()
val_loop_enabled = is_overridden('validation_step', model_ref) and self.limit_val_batches > 0
return val_loop_enabled or self.fast_dev_run
@property
def default_root_dir(self) -> str:
"""
The default location to save artifacts of loggers, checkpoints etc.
It is used as a fallback if logger or checkpoint callback do not define specific save paths.
"""
if get_filesystem(self._default_root_dir).protocol == "file":
return os.path.normpath(self._default_root_dir)
return self._default_root_dir
@property
def weights_save_path(self) -> str:
"""
The default root location to save weights (checkpoints), e.g., when the
:class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` does not define a file path.
"""
if get_filesystem(self._weights_save_path).protocol == "file":
return os.path.normpath(self._weights_save_path)
return self._weights_save_path

View File

@ -12,10 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import os
import warnings
from argparse import ArgumentParser, Namespace
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import torch
@ -47,12 +45,10 @@ from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
from pytorch_lightning.utilities import parsing, rank_zero_info, rank_zero_only, rank_zero_warn, AMPType
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.trainer.evaluation_loop import EvaluationLoop
from pytorch_lightning.trainer.training_loop import TrainLoop
from pytorch_lightning.trainer.data_connector import DataConnector
from pytorch_lightning.accelerators.accelerator_connector import AcceleratorConnector
from pytorch_lightning.utilities import argparse_utils
from pytorch_lightning.trainer.logger_connector import LoggerConnector
from pytorch_lightning.trainer.lr_scheduler_connector import LRSchedulerConnector
from pytorch_lightning.trainer.model_connector import ModelConnector
@ -61,6 +57,7 @@ from pytorch_lightning.tuner.tuning import Tuner
from pytorch_lightning.trainer.initializer import Initializer
from pytorch_lightning.utilities.model_utils import is_overridden
from pytorch_lightning.trainer import docstrings
from pytorch_lightning.trainer.properties import TrainerProperties
# warnings to ignore in trainer
warnings.filterwarnings(
@ -90,6 +87,7 @@ else:
class Trainer(
TrainerProperties,
TrainerIOMixin,
TrainerCallbackHookMixin,
TrainerModelHooksMixin,
@ -423,125 +421,6 @@ class Trainer(
# Callback system
self.on_init_end()
@property
def use_amp(self) -> bool:
return self.precision == 16
@property
def callback_metrics(self):
return self.logger_connector.callback_metrics
@callback_metrics.setter
def callback_metrics(self, x):
self.logger_connector.callback_metrics = x
@property
def state(self) -> TrainerState:
return self._state
@property
def is_global_zero(self) -> bool:
return self.global_rank == 0
@property
def slurm_job_id(self) -> Optional[int]:
try:
job_id = os.environ['SLURM_JOB_ID']
job_id = int(job_id)
# in interactive mode, don't make logs use the same job id
in_slurm_interactive_mode = os.environ['SLURM_JOB_NAME'] == 'bash'
if in_slurm_interactive_mode:
job_id = None
except Exception:
job_id = None
return job_id
@classmethod
def default_attributes(cls):
init_signature = inspect.signature(Trainer)
args = {}
for param_name in init_signature.parameters:
value = init_signature.parameters[param_name].default
args[param_name] = value
return args
@classmethod
def get_deprecated_arg_names(cls) -> List:
"""Returns a list with deprecated Trainer arguments."""
depr_arg_names = []
for name, val in cls.__dict__.items():
if name.startswith('DEPRECATED') and isinstance(val, (tuple, list)):
depr_arg_names.extend(val)
return depr_arg_names
@classmethod
def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs) -> 'Trainer':
return argparse_utils.from_argparse_args(cls, args, **kwargs)
@classmethod
def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
return argparse_utils.parse_argparser(cls, arg_parser)
@classmethod
def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
return argparse_utils.add_argparse_args(cls, parent_parser)
@property
def num_gpus(self) -> int:
gpus = self.data_parallel_device_ids
if gpus is None:
return 0
return len(gpus)
@property
def data_parallel(self) -> bool:
return self.use_dp or self.use_ddp or self.use_ddp2
@property
def progress_bar_callback(self):
return self._progress_bar_callback
@property
def progress_bar_dict(self) -> dict:
""" Read-only for progress bar metrics. """
ref_model = self.model if not self.data_parallel else self.model.module
return dict(**ref_model.get_progress_bar_dict(), **self.logger_connector.progress_bar_metrics)
@property
def disable_validation(self) -> bool:
""" Check if validation is disabled during training. """
return not self.enable_validation
@property
def enable_validation(self) -> bool:
""" Check if we should run validation during training. """
val_loop_enabled = is_overridden('validation_step', self.get_model()) and self.limit_val_batches > 0
return val_loop_enabled or self.fast_dev_run
@property
def default_root_dir(self) -> str:
"""
The default location to save artifacts of loggers, checkpoints etc.
It is used as a fallback if logger or checkpoint callback do not define specific save paths.
"""
if get_filesystem(self._default_root_dir).protocol == "file":
return os.path.normpath(self._default_root_dir)
return self._default_root_dir
@property
def weights_save_path(self) -> str:
"""
The default root location to save weights (checkpoints), e.g., when the
:class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` does not define a file path.
"""
if get_filesystem(self._weights_save_path).protocol == "file":
return os.path.normpath(self._weights_save_path)
return self._weights_save_path
def tune(
self,
model: LightningModule,