ref: separate properties (#3432)

* ref: separate properties

* ref: separate properties

* ref: separate properties

* ref: separate properties
This commit is contained in:
William Falcon 2020-09-09 20:03:18 -04:00 committed by GitHub
parent 9696484153
commit 435e479bbd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 159 additions and 123 deletions

View File

@ -50,3 +50,8 @@ class ModelConnector:
m.precision = self.trainer.precision m.precision = self.trainer.precision
m.global_rank = self.trainer.global_rank m.global_rank = self.trainer.global_rank
m.local_rank = self.trainer.local_rank m.local_rank = self.trainer.local_rank
def get_model(self):
is_dp_module = isinstance(self.trainer.model, (LightningDistributedDataParallel, LightningDataParallel))
model = self.trainer.model.module if is_dp_module else self.trainer.model
return model

View File

@ -0,0 +1,152 @@
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.trainer.logger_connector import LoggerConnector
from pytorch_lightning.trainer.states import TrainerState
from typing import List, Optional, Union
from pytorch_lightning.utilities import argparse_utils
from argparse import ArgumentParser, Namespace
from abc import ABC
import inspect
import os
from pytorch_lightning.utilities.model_utils import is_overridden
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.callbacks import ProgressBarBase
from pytorch_lightning.trainer.model_connector import ModelConnector
class TrainerProperties(ABC):
precision: int
logger_connector: LoggerConnector
_state: TrainerState
global_rank: int
fast_dev_run: bool
use_dp: bool
use_ddp: bool
use_ddp2: bool
model: LightningModule
data_parallel_device_ids: Optional[List[int]]
_progress_bar_callback: ProgressBarBase
limit_val_batches: int
_default_root_dir: str
_weights_save_path: str
model_connector: ModelConnector
@property
def use_amp(self) -> bool:
return self.precision == 16
@property
def callback_metrics(self):
return self.logger_connector.callback_metrics
@callback_metrics.setter
def callback_metrics(self, x):
self.logger_connector.callback_metrics = x
@property
def state(self) -> TrainerState:
return self._state
@property
def is_global_zero(self) -> bool:
return self.global_rank == 0
@property
def slurm_job_id(self) -> Optional[int]:
try:
job_id = os.environ['SLURM_JOB_ID']
job_id = int(job_id)
# in interactive mode, don't make logs use the same job id
in_slurm_interactive_mode = os.environ['SLURM_JOB_NAME'] == 'bash'
if in_slurm_interactive_mode:
job_id = None
except Exception:
job_id = None
return job_id
@classmethod
def default_attributes(cls):
init_signature = inspect.signature(cls)
args = {}
for param_name in init_signature.parameters:
value = init_signature.parameters[param_name].default
args[param_name] = value
return args
@classmethod
def get_deprecated_arg_names(cls) -> List:
"""Returns a list with deprecated Trainer arguments."""
depr_arg_names = []
for name, val in cls.__dict__.items():
if name.startswith('DEPRECATED') and isinstance(val, (tuple, list)):
depr_arg_names.extend(val)
return depr_arg_names
@classmethod
def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
return argparse_utils.from_argparse_args(cls, args, **kwargs)
@classmethod
def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
return argparse_utils.parse_argparser(cls, arg_parser)
@classmethod
def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
return argparse_utils.add_argparse_args(cls, parent_parser)
@property
def num_gpus(self) -> int:
gpus = self.data_parallel_device_ids
if gpus is None:
return 0
return len(gpus)
@property
def data_parallel(self) -> bool:
return self.use_dp or self.use_ddp or self.use_ddp2
@property
def progress_bar_callback(self):
return self._progress_bar_callback
@property
def progress_bar_dict(self) -> dict:
""" Read-only for progress bar metrics. """
ref_model = self.model if not self.data_parallel else self.model.module
return dict(**ref_model.get_progress_bar_dict(), **self.logger_connector.progress_bar_metrics)
@property
def disable_validation(self) -> bool:
""" Check if validation is disabled during training. """
return not self.enable_validation
@property
def enable_validation(self) -> bool:
""" Check if we should run validation during training. """
model_ref = self.model_connector.get_model()
val_loop_enabled = is_overridden('validation_step', model_ref) and self.limit_val_batches > 0
return val_loop_enabled or self.fast_dev_run
@property
def default_root_dir(self) -> str:
"""
The default location to save artifacts of loggers, checkpoints etc.
It is used as a fallback if logger or checkpoint callback do not define specific save paths.
"""
if get_filesystem(self._default_root_dir).protocol == "file":
return os.path.normpath(self._default_root_dir)
return self._default_root_dir
@property
def weights_save_path(self) -> str:
"""
The default root location to save weights (checkpoints), e.g., when the
:class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` does not define a file path.
"""
if get_filesystem(self._weights_save_path).protocol == "file":
return os.path.normpath(self._weights_save_path)
return self._weights_save_path

View File

@ -12,10 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import inspect
import os import os
import warnings import warnings
from argparse import ArgumentParser, Namespace
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import torch import torch
@ -47,12 +45,10 @@ from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
from pytorch_lightning.utilities import parsing, rank_zero_info, rank_zero_only, rank_zero_warn, AMPType from pytorch_lightning.utilities import parsing, rank_zero_info, rank_zero_only, rank_zero_warn, AMPType
from pytorch_lightning.utilities.debugging import InternalDebugger from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.trainer.evaluation_loop import EvaluationLoop from pytorch_lightning.trainer.evaluation_loop import EvaluationLoop
from pytorch_lightning.trainer.training_loop import TrainLoop from pytorch_lightning.trainer.training_loop import TrainLoop
from pytorch_lightning.trainer.data_connector import DataConnector from pytorch_lightning.trainer.data_connector import DataConnector
from pytorch_lightning.accelerators.accelerator_connector import AcceleratorConnector from pytorch_lightning.accelerators.accelerator_connector import AcceleratorConnector
from pytorch_lightning.utilities import argparse_utils
from pytorch_lightning.trainer.logger_connector import LoggerConnector from pytorch_lightning.trainer.logger_connector import LoggerConnector
from pytorch_lightning.trainer.lr_scheduler_connector import LRSchedulerConnector from pytorch_lightning.trainer.lr_scheduler_connector import LRSchedulerConnector
from pytorch_lightning.trainer.model_connector import ModelConnector from pytorch_lightning.trainer.model_connector import ModelConnector
@ -61,6 +57,7 @@ from pytorch_lightning.tuner.tuning import Tuner
from pytorch_lightning.trainer.initializer import Initializer from pytorch_lightning.trainer.initializer import Initializer
from pytorch_lightning.utilities.model_utils import is_overridden from pytorch_lightning.utilities.model_utils import is_overridden
from pytorch_lightning.trainer import docstrings from pytorch_lightning.trainer import docstrings
from pytorch_lightning.trainer.properties import TrainerProperties
# warnings to ignore in trainer # warnings to ignore in trainer
warnings.filterwarnings( warnings.filterwarnings(
@ -90,6 +87,7 @@ else:
class Trainer( class Trainer(
TrainerProperties,
TrainerIOMixin, TrainerIOMixin,
TrainerCallbackHookMixin, TrainerCallbackHookMixin,
TrainerModelHooksMixin, TrainerModelHooksMixin,
@ -423,125 +421,6 @@ class Trainer(
# Callback system # Callback system
self.on_init_end() self.on_init_end()
@property
def use_amp(self) -> bool:
return self.precision == 16
@property
def callback_metrics(self):
return self.logger_connector.callback_metrics
@callback_metrics.setter
def callback_metrics(self, x):
self.logger_connector.callback_metrics = x
@property
def state(self) -> TrainerState:
return self._state
@property
def is_global_zero(self) -> bool:
return self.global_rank == 0
@property
def slurm_job_id(self) -> Optional[int]:
try:
job_id = os.environ['SLURM_JOB_ID']
job_id = int(job_id)
# in interactive mode, don't make logs use the same job id
in_slurm_interactive_mode = os.environ['SLURM_JOB_NAME'] == 'bash'
if in_slurm_interactive_mode:
job_id = None
except Exception:
job_id = None
return job_id
@classmethod
def default_attributes(cls):
init_signature = inspect.signature(Trainer)
args = {}
for param_name in init_signature.parameters:
value = init_signature.parameters[param_name].default
args[param_name] = value
return args
@classmethod
def get_deprecated_arg_names(cls) -> List:
"""Returns a list with deprecated Trainer arguments."""
depr_arg_names = []
for name, val in cls.__dict__.items():
if name.startswith('DEPRECATED') and isinstance(val, (tuple, list)):
depr_arg_names.extend(val)
return depr_arg_names
@classmethod
def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs) -> 'Trainer':
return argparse_utils.from_argparse_args(cls, args, **kwargs)
@classmethod
def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
return argparse_utils.parse_argparser(cls, arg_parser)
@classmethod
def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
return argparse_utils.add_argparse_args(cls, parent_parser)
@property
def num_gpus(self) -> int:
gpus = self.data_parallel_device_ids
if gpus is None:
return 0
return len(gpus)
@property
def data_parallel(self) -> bool:
return self.use_dp or self.use_ddp or self.use_ddp2
@property
def progress_bar_callback(self):
return self._progress_bar_callback
@property
def progress_bar_dict(self) -> dict:
""" Read-only for progress bar metrics. """
ref_model = self.model if not self.data_parallel else self.model.module
return dict(**ref_model.get_progress_bar_dict(), **self.logger_connector.progress_bar_metrics)
@property
def disable_validation(self) -> bool:
""" Check if validation is disabled during training. """
return not self.enable_validation
@property
def enable_validation(self) -> bool:
""" Check if we should run validation during training. """
val_loop_enabled = is_overridden('validation_step', self.get_model()) and self.limit_val_batches > 0
return val_loop_enabled or self.fast_dev_run
@property
def default_root_dir(self) -> str:
"""
The default location to save artifacts of loggers, checkpoints etc.
It is used as a fallback if logger or checkpoint callback do not define specific save paths.
"""
if get_filesystem(self._default_root_dir).protocol == "file":
return os.path.normpath(self._default_root_dir)
return self._default_root_dir
@property
def weights_save_path(self) -> str:
"""
The default root location to save weights (checkpoints), e.g., when the
:class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` does not define a file path.
"""
if get_filesystem(self._weights_save_path).protocol == "file":
return os.path.normpath(self._weights_save_path)
return self._weights_save_path
def tune( def tune(
self, self,
model: LightningModule, model: LightningModule,