ref: clean config [1/n] add intermediate setters (#4990)

* add intermediate setters

* show inputs

* fix options

* move

* fix

* less talk

* fix

* talk less

* str

* cases

* rename

Co-authored-by: chaton <thomas@grid.ai>
Jirka Borovec 2020-12-09 20:13:57 +01:00 committed by GitHub
parent 068502f07c
commit ce9179591d
8 changed files with 228 additions and 27 deletions


@ -256,14 +256,4 @@ class Accelerator(object):
yield cm
# TODO: allow the user to compare with a string; internally we shall use these Enums to prevent typos...
class BackendType(Enum):
DP = 'dp'
DDP = 'ddp'
DDP2 = 'ddp2'
DDP_SPAWN = 'ddp_spawn'
# decouple distrib and device
DDP_CPU = 'ddp_cpu'
HOROVOD = 'horovod'
# this is rather a device
TPU = 'tpu'
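
A minimal sketch (not part of the diff) of the limitation the TODO above refers to: a plain Enum member never compares equal to its string value, which is what the str-mixin enums introduced later in this commit address. PlainBackend is a hypothetical stand-in for the removed BackendType.

from enum import Enum

class PlainBackend(Enum):  # hypothetical stand-in for the removed BackendType
    DDP = 'ddp'

print(PlainBackend.DDP == 'ddp')        # False - a plain Enum ignores its value on equality
print(PlainBackend.DDP.value == 'ddp')  # True - callers have to unwrap .value explicitly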


@ -335,6 +335,7 @@ class AcceleratorConnector:
self.trainer.use_ddp = True
self.trainer.data_parallel_device_ids = None
self.trainer.on_gpu = False
self.trainer.on_cpu = True
elif self.trainer.distributed_backend == "horovod":
self._set_horovod_backend()


@ -21,6 +21,12 @@ from pytorch_lightning.core.step_result import Result
class LoggerStages(str, Enum):
""" Train/validation/test phase in each training step.
>>> # you can match the type with a string
>>> LoggerStages.TRAIN == 'train'
True
"""
TRAIN = "train"
VAL = "validation"
TEST = "test"
@ -35,7 +41,7 @@ class LoggerStages(str, Enum):
raise RuntimeError(f"Invalid stage {stage_or_testing} of type {type(stage_or_testing)} given")
class ResultStoreType(Enum):
class ResultStoreType(str, Enum):
INSIDE_BATCH_TRAIN_LOOP = "inside_batch_train_loop"
OUTSIDE_BATCH_TRAIN_LOOP = "outside_batch_train_loop"
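
With the str mixin added above, enum members compare equal to their raw string values, so existing string-based checks keep working. A small self-contained sketch, mirroring the changed ResultStoreType definition above:

from enum import Enum

class ResultStoreType(str, Enum):  # mirrors the changed definition above
    INSIDE_BATCH_TRAIN_LOOP = "inside_batch_train_loop"
    OUTSIDE_BATCH_TRAIN_LOOP = "outside_batch_train_loop"

assert ResultStoreType.INSIDE_BATCH_TRAIN_LOOP == "inside_batch_train_loop"  # member equals its value
assert ResultStoreType.INSIDE_BATCH_TRAIN_LOOP != "outside_batch_train_loop"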


@ -0,0 +1,135 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.utilities import rank_zero_warn, DistributedType, DeviceType
class DeprecatedDistDeviceAttributes:
_distrib_type: DistributedType
_device_type: DeviceType
num_gpus: int
@property
def on_cpu(self) -> bool:
# rank_zero_warn("Internal: `on_cpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
return self._device_type and self._device_type == DeviceType.CPU
@on_cpu.setter
def on_cpu(self, val: bool) -> None:
# rank_zero_warn("Internal: `on_cpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
if val:
self._device_type = DeviceType.CPU
@property
def on_tpu(self) -> bool:
# rank_zero_warn("Internal: `on_tpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
return self._device_type and self._device_type == DeviceType.TPU
@on_tpu.setter
def on_tpu(self, val: bool) -> None:
# rank_zero_warn("Internal: `on_tpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
# TODO: add logic so this cannot be set if no TPU is available
if val:
self._device_type = DeviceType.TPU
@property
def use_tpu(self) -> bool:
# rank_zero_warn("Internal: `use_tpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
return self._device_type and self._device_type == DeviceType.TPU
@use_tpu.setter
def use_tpu(self, val: bool) -> None:
# rank_zero_warn("Internal: `use_tpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
# TODO: add logic so this cannot be set if no TPU is available
if val:
self._device_type = DeviceType.TPU
@property
def on_gpu(self) -> bool:
# rank_zero_warn("Internal: `on_gpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
return self._device_type and self._device_type == DeviceType.GPU
@on_gpu.setter
def on_gpu(self, val: bool) -> None:
# rank_zero_warn("Internal: `on_gpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
# TODO: add logic so this cannot be set if no GPU is available
if val:
self._device_type = DeviceType.GPU
@property
def use_dp(self) -> bool:
# rank_zero_warn("Internal: `use_dp` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
return self._device_type and self._distrib_type == DistributedType.DP
@use_dp.setter
def use_dp(self, val: bool) -> None:
# rank_zero_warn("Internal: `use_dp` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
if val:
self._distrib_type = DistributedType.DP
@property
def use_ddp(self) -> bool:
# rank_zero_warn("Internal: `use_ddp` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
return self._device_type and self._distrib_type == DistributedType.DDP
@use_ddp.setter
def use_ddp(self, val: bool) -> None:
# rank_zero_warn("Internal: `use_ddp` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
if val:
self._distrib_type = DistributedType.DDP
@property
def use_ddp2(self) -> bool:
# rank_zero_warn("Internal: `use_ddp2` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
return self._device_type and self._distrib_type == DistributedType.DDP2
@use_ddp2.setter
def use_ddp2(self, val: bool) -> None:
# rank_zero_warn("Internal: `use_ddp2` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
if val:
self._distrib_type = DistributedType.DDP2
@property
def use_horovod(self) -> bool:
# rank_zero_warn(
# "Internal: `use_horovod` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning
# )
return self._device_type and self._distrib_type == DistributedType.HOROVOD
@use_horovod.setter
def use_horovod(self, val: bool) -> None:
# rank_zero_warn(
# "Internal: `use_horovod` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning
# )
if val:
self._distrib_type = DistributedType.HOROVOD
@property
def use_single_gpu(self) -> bool:
# rank_zero_warn(
# "Internal: `use_single_gpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning,
# )
# TODO: why DDP2 is excluded here is not entirely clear, but it comes from the connectors...
return (self._device_type and self._device_type == DeviceType.GPU
and self.num_gpus == 1
and self._distrib_type not in (DistributedType.DDP2, ))
@use_single_gpu.setter
def use_single_gpu(self, val: bool) -> None:
# rank_zero_warn(
# "Internal: `use_single_gpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning,
# )
if val:
self._device_type = DeviceType.GPU
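
A hedged usage sketch (not part of the commit) of how the intermediate setters translate the legacy boolean flags into the new internal enums; _DemoHost is a hypothetical stand-in that provides the attributes the mixin expects:

from pytorch_lightning.trainer.deprecated_api import DeprecatedDistDeviceAttributes
from pytorch_lightning.utilities import DeviceType, DistributedType

class _DemoHost(DeprecatedDistDeviceAttributes):
    def __init__(self):
        self._device_type = DeviceType.CPU
        self._distrib_type = None
        self.num_gpus = 0

host = _DemoHost()
host.use_ddp = True   # setter maps the legacy flag onto DistributedType.DDP
host.on_gpu = True    # setter maps the legacy flag onto DeviceType.GPU
assert host._distrib_type == DistributedType.DDP
assert host._device_type == DeviceType.GPU
assert host.use_ddp   # getters derive the legacy flags back from the enums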


@ -19,9 +19,17 @@ from typing import Callable, Optional
import pytorch_lightning
class TrainerState(Enum):
class TrainerState(str, Enum):
""" State which is set in the :class:`~pytorch_lightning.trainer.trainer.Trainer`
to indicate what is currently or was executed. """
to indicate what is currently or was executed.
>>> # you can match the type with a string
>>> TrainerState.RUNNING == 'RUNNING'
True
>>> # which is case-sensitive
>>> TrainerState.FINISHED == 'finished'
False
"""
INITIALIZING = 'INITIALIZING'
RUNNING = 'RUNNING'
FINISHED = 'FINISHED'


@ -24,11 +24,10 @@ from torch.utils.data import DataLoader
from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.accelerators.accelerator_connector import AcceleratorConnector
from pytorch_lightning.accelerators.cpu_accelerator import CPUAccelerator
from pytorch_lightning.callbacks import Callback, ModelCheckpoint
from pytorch_lightning.trainer.deprecated_api import DeprecatedDistDeviceAttributes
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.core.step_result import EvalResult, Result
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.plugins.plugin_connector import PluginConnector
@ -53,11 +52,11 @@ from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
from pytorch_lightning.trainer.properties import TrainerProperties
from pytorch_lightning.trainer.states import TrainerState, trainer_state
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.trainer.training_loop import TrainLoop
from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
from pytorch_lightning.tuner.tuning import Tuner
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities import rank_zero_warn, DeviceType
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@ -78,6 +77,7 @@ class Trainer(
TrainerLoggingMixin,
TrainerTrainingTricksMixin,
TrainerDataLoadingMixin,
DeprecatedDistDeviceAttributes,
):
@overwrite_by_env_vars
def __init__(
@ -284,6 +284,8 @@ class Trainer(
handle AMP, TPU, accumulated_gradients, etc.
"""
super().__init__()
self._device_type = DeviceType.CPU
self._distrib_type = None
# init connectors
self.dev_debugger = InternalDebugger(self)
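
Assuming a default CPU-only environment at this commit, a freshly constructed Trainer should carry the enum state initialized above while the deprecated boolean attributes remain readable through the new mixin; a hedged sanity-check sketch:

from pytorch_lightning import Trainer
from pytorch_lightning.utilities import DeviceType

trainer = Trainer()
assert trainer._device_type == DeviceType.CPU  # set in __init__ above
assert trainer._distrib_type is None           # no distributed backend was requested
assert trainer.on_cpu                          # deprecated attribute, now derived from _device_type
assert not trainer.on_gpu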


@ -16,6 +16,7 @@ import importlib
import platform
from distutils.version import LooseVersion
from enum import Enum
from typing import Union
import numpy
import torch
@ -66,6 +67,62 @@ FLOAT32_EPSILON = numpy.finfo(numpy.float32).eps
FLOAT64_EPSILON = numpy.finfo(numpy.float64).eps
class AMPType(Enum):
class LightningEnum(str, Enum):
""" Type of any enumerator with allowed comparison to string invariant to cases. """
@classmethod
def from_str(cls, value: str) -> 'LightningEnum':
statuses = [status for status in dir(cls) if not status.startswith('_')]
for st in statuses:
if st.lower() == value.lower():
return getattr(cls, st)
return None
def __eq__(self, other: Union[str, Enum]) -> bool:
other = other.value if isinstance(other, Enum) else str(other)
return self.value.lower() == other.lower()
class AMPType(LightningEnum):
"""Type of Automatic Mixed Precission used for training.
>>> # you can math the type with string
>>> AMPType.APEX == 'apex'
True
"""
APEX = 'apex'
NATIVE = 'native'
class DistributedType(LightningEnum):
""" Define type of ditributed computing.
>>> # you can math the type with string
>>> DistributedType.DDP == 'ddp'
True
>>> # which is case-insensitive
>>> DistributedType.DDP2 == 'DDP2'
True
"""
DP = 'dp'
DDP = 'ddp'
DDP2 = 'ddp2'
DDP_SPAWN = 'ddp_spawn'
HOROVOD = 'horovod'
class DeviceType(LightningEnum):
""" Define Device type byt its nature - acceleatrors.
>>> DeviceType.CPU == DeviceType.from_str('cpu')
True
>>> # you can match the type with a string
>>> DeviceType.GPU == 'GPU'
True
>>> # which is case-insensitive
>>> DeviceType.TPU == 'tpu'
True
"""
CPU = 'CPU'
GPU = 'GPU'
TPU = 'TPU'
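
A brief usage sketch (not part of the diff) of the behaviour shown above: from_str performs a case-insensitive member lookup and falls back to None, while __eq__ ignores case when comparing with strings.

from pytorch_lightning.utilities import AMPType, DeviceType, DistributedType

assert DistributedType.from_str("DDP_spawn") == DistributedType.DDP_SPAWN  # lookup ignores case
assert DistributedType.from_str("unknown") is None                         # unknown values fall through to None
assert AMPType.NATIVE == "Native"                                          # comparison ignores case
assert DeviceType.GPU != "cpu"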


@ -1332,15 +1332,17 @@ def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches):
),
],
)
# TODO: mock the number of GPUs so all these tests can run on any device
# TODO: think about simplifying this so that the expected values are just a list of use_xxx flags which shall be true...
def test_trainer_config(trainer_kwargs, expected):
trainer = Trainer(**trainer_kwargs)
assert trainer.use_dp is expected["use_dp"]
assert trainer.use_ddp is expected["use_ddp"]
assert trainer.use_ddp2 is expected["use_ddp2"]
assert trainer.num_gpus == expected["num_gpus"]
assert trainer.on_gpu is expected["on_gpu"]
assert trainer.use_single_gpu is expected["use_single_gpu"]
assert trainer.num_processes == expected["num_processes"]
assert trainer.use_dp is expected["use_dp"], 'for input: %s' % trainer_kwargs
assert trainer.use_ddp is expected["use_ddp"], 'for input: %s' % trainer_kwargs
assert trainer.use_ddp2 is expected["use_ddp2"], 'for input: %s' % trainer_kwargs
assert trainer.num_gpus == expected["num_gpus"], 'for input: %s' % trainer_kwargs
assert trainer.on_gpu is expected["on_gpu"], 'for input: %s' % trainer_kwargs
assert trainer.use_single_gpu is expected["use_single_gpu"], 'for input: %s' % trainer_kwargs
assert trainer.num_processes == expected["num_processes"], 'for input: %s' % trainer_kwargs
def test_trainer_subclassing():