ref: clean config [1/n] add intermediate setters (#4990)
* add intermediate setters
* show inputs
* fix options
* move
* fix
* less talk
* fix
* talk less
* str
* cases
* rename

Co-authored-by: chaton <thomas@grid.ai>

parent 068502f07c
commit ce9179591d

@@ -256,14 +256,4 @@ class Accelerator(object):
        yield cm


-# TODO: allow user to compare with string even internaly we shall use these Enum to prevent typos...
-class BackendType(Enum):
-    DP = 'dp'
-    DDP = 'ddp'
-    DDP2 = 'ddp2'
-    DDP_SPAWN = 'ddp_spawn'
-    # decuple distrib and device
-    DDP_CPU = 'ddp_cpu'
-    HOROVOD = 'horovod'
-    # this is rather device
-    TPU = 'tpu'

@@ -335,6 +335,7 @@ class AcceleratorConnector:
            self.trainer.use_ddp = True
            self.trainer.data_parallel_device_ids = None
            self.trainer.on_gpu = False
+            self.trainer.on_cpu = True
        elif self.trainer.distributed_backend == "horovod":
            self._set_horovod_backend()

@@ -21,6 +21,12 @@ from pytorch_lightning.core.step_result import Result


class LoggerStages(str, Enum):
    """ Train/validation/test phase in each training step.

    >>> # you can match the type with a string
    >>> LoggerStages.TRAIN == 'train'
    True
    """
    TRAIN = "train"
    VAL = "validation"
    TEST = "test"

@@ -35,7 +41,7 @@ class LoggerStages(str, Enum):
        raise RuntimeError(f"Invalid stage {stage_or_testing} of type {type(stage_or_testing)} given")


-class ResultStoreType(Enum):
+class ResultStoreType(str, Enum):
    INSIDE_BATCH_TRAIN_LOOP = "inside_batch_train_loop"
    OUTSIDE_BATCH_TRAIN_LOOP = "outside_batch_train_loop"
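
Because LoggerStages (and now ResultStoreType) subclasses both str and Enum, its members behave like the plain strings the logging code already passes around. A minimal, self-contained sketch of what that buys; the metrics_by_stage dict is illustrative and not part of the commit:

    from enum import Enum

    class LoggerStages(str, Enum):
        TRAIN = "train"
        VAL = "validation"
        TEST = "test"

    # members compare equal to the raw strings used elsewhere
    assert LoggerStages.TRAIN == "train"

    # and they drop in wherever a string key or value is expected,
    # e.g. an illustrative per-stage metrics buffer
    metrics_by_stage = {stage.value: [] for stage in LoggerStages}
    assert "validation" in metrics_by_stage
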
@@ -0,0 +1,135 @@ pytorch_lightning/trainer/deprecated_api.py (new file)
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pytorch_lightning.utilities import rank_zero_warn, DistributedType, DeviceType


class DeprecatedDistDeviceAttributes:

    _distrib_type: DistributedType
    _device_type: DeviceType
    num_gpus: int

    @property
    def on_cpu(self) -> bool:
        # rank_zero_warn("Internal: `on_cpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
        return self._device_type and self._device_type == DeviceType.CPU

    @on_cpu.setter
    def on_cpu(self, val: bool) -> None:
        # rank_zero_warn("Internal: `on_cpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
        if val:
            self._device_type = DeviceType.CPU

    @property
    def on_tpu(self) -> bool:
        # rank_zero_warn("Internal: `on_tpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
        return self._device_type and self._device_type == DeviceType.TPU

    @on_tpu.setter
    def on_tpu(self, val: bool) -> None:
        # rank_zero_warn("Internal: `on_tpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
        # todo add logic that it cannot be set if TPU is missing
        if val:
            self._device_type = DeviceType.TPU

    @property
    def use_tpu(self) -> bool:
        # rank_zero_warn("Internal: `use_tpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
        return self._device_type and self._device_type == DeviceType.TPU

    @use_tpu.setter
    def use_tpu(self, val: bool) -> None:
        # rank_zero_warn("Internal: `use_tpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
        # todo add logic that it cannot be set if TPU is missing
        if val:
            self._device_type = DeviceType.TPU

    @property
    def on_gpu(self) -> bool:
        # rank_zero_warn("Internal: `on_gpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
        return self._device_type and self._device_type == DeviceType.GPU

    @on_gpu.setter
    def on_gpu(self, val: bool) -> None:
        # rank_zero_warn("Internal: `on_gpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
        # todo add logic that it cannot be set if GPU is missing
        if val:
            self._device_type = DeviceType.GPU

    @property
    def use_dp(self) -> bool:
        # rank_zero_warn("Internal: `use_dp` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
        return self._device_type and self._distrib_type == DistributedType.DP

    @use_dp.setter
    def use_dp(self, val: bool) -> None:
        # rank_zero_warn("Internal: `use_dp` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
        if val:
            self._distrib_type = DistributedType.DP

    @property
    def use_ddp(self) -> bool:
        # rank_zero_warn("Internal: `use_ddp` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
        return self._device_type and self._distrib_type == DistributedType.DDP

    @use_ddp.setter
    def use_ddp(self, val: bool) -> None:
        # rank_zero_warn("Internal: `use_ddp` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
        if val:
            self._distrib_type = DistributedType.DDP

    @property
    def use_ddp2(self) -> bool:
        # rank_zero_warn("Internal: `use_ddp2` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
        return self._device_type and self._distrib_type == DistributedType.DDP2

    @use_ddp2.setter
    def use_ddp2(self, val: bool) -> None:
        # rank_zero_warn("Internal: `use_ddp2` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
        if val:
            self._distrib_type = DistributedType.DDP2

    @property
    def use_horovod(self) -> bool:
        # rank_zero_warn(
        #     "Internal: `use_horovod` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning
        # )
        return self._device_type and self._distrib_type == DistributedType.HOROVOD

    @use_horovod.setter
    def use_horovod(self, val: bool) -> None:
        # rank_zero_warn(
        #     "Internal: `use_horovod` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning
        # )
        if val:
            self._distrib_type = DistributedType.HOROVOD

    @property
    def use_single_gpu(self) -> bool:
        # rank_zero_warn(
        #     "Internal: `use_single_gpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning,
        # )
        # todo, limiting to exclude DDP2 is not clear but it comes from connectors...
        return (self._device_type and self._device_type == DeviceType.GPU
                and self.num_gpus == 1
                and self._distrib_type not in (DistributedType.DDP2, ))

    @use_single_gpu.setter
    def use_single_gpu(self, val: bool) -> None:
        # rank_zero_warn(
        #     "Internal: `use_single_gpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning,
        # )
        if val:
            self._device_type = DeviceType.GPU
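
The new mixin keeps the old boolean-style attributes alive while the actual state lives in the two enum fields. A small stand-alone sketch of the same read/write pattern; the LegacyFlags class is illustrative only and not part of the commit:

    from enum import Enum

    class DeviceType(str, Enum):
        CPU = 'CPU'
        GPU = 'GPU'
        TPU = 'TPU'

    class LegacyFlags:
        """Illustrative stand-in for Trainer + DeprecatedDistDeviceAttributes."""

        def __init__(self) -> None:
            self._device_type = DeviceType.CPU

        @property
        def on_gpu(self) -> bool:
            # reading the legacy flag just inspects the enum
            return self._device_type == DeviceType.GPU

        @on_gpu.setter
        def on_gpu(self, val: bool) -> None:
            # writing True routes into the enum; writing False is a no-op, mirroring the setters above
            if val:
                self._device_type = DeviceType.GPU

    flags = LegacyFlags()
    assert not flags.on_gpu
    flags.on_gpu = True                           # legacy call sites keep working
    assert flags._device_type is DeviceType.GPU   # while the enum stays the single source of truth
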
@@ -19,9 +19,17 @@ from typing import Callable, Optional
import pytorch_lightning


-class TrainerState(Enum):
+class TrainerState(str, Enum):
    """ State which is set in the :class:`~pytorch_lightning.trainer.trainer.Trainer`
-    to indicate what is currently or was executed. """
+    to indicate what is currently or was executed.
+
+    >>> # you can match the type with a string
+    >>> TrainerState.RUNNING == 'RUNNING'
+    True
+    >>> # which is case sensitive
+    >>> TrainerState.FINISHED == 'finished'
+    False
+    """
    INITIALIZING = 'INITIALIZING'
    RUNNING = 'RUNNING'
    FINISHED = 'FINISHED'
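
Unlike the LightningEnum helper introduced further down, TrainerState compares through its plain str value, so the matching shown in its doctest is exact-case only. A tiny stand-alone illustration:

    from enum import Enum

    class TrainerState(str, Enum):
        INITIALIZING = 'INITIALIZING'
        RUNNING = 'RUNNING'
        FINISHED = 'FINISHED'

    state = TrainerState.FINISHED
    assert state == 'FINISHED'    # matches via the str base, exact casing
    assert state != 'finished'    # no case-insensitive __eq__ here
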
@@ -24,11 +24,10 @@ from torch.utils.data import DataLoader
from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.accelerators.accelerator_connector import AcceleratorConnector
from pytorch_lightning.accelerators.cpu_accelerator import CPUAccelerator
-from pytorch_lightning.callbacks import Callback, ModelCheckpoint
+from pytorch_lightning.trainer.deprecated_api import DeprecatedDistDeviceAttributes
+from pytorch_lightning.callbacks import Callback
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.core.step_result import EvalResult, Result
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.plugins.plugin_connector import PluginConnector

@@ -53,11 +52,11 @@ from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
from pytorch_lightning.trainer.properties import TrainerProperties
-from pytorch_lightning.trainer.states import TrainerState, trainer_state
+from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.trainer.training_loop import TrainLoop
from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
from pytorch_lightning.tuner.tuning import Tuner
-from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities import rank_zero_warn, DeviceType
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.exceptions import MisconfigurationException

@@ -78,6 +77,7 @@ class Trainer(
    TrainerLoggingMixin,
    TrainerTrainingTricksMixin,
    TrainerDataLoadingMixin,
+    DeprecatedDistDeviceAttributes,
):
    @overwrite_by_env_vars
    def __init__(

@@ -284,6 +284,8 @@ class Trainer(
        handle AMP, TPU, accumulated_gradients, etc..
        """
        super().__init__()
+        self._device_type = DeviceType.CPU
+        self._distrib_type = None

        # init connectors
        self.dev_debugger = InternalDebugger(self)
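
Assuming the deprecated properties above, these two defaults mean a freshly constructed, CPU-only Trainer already answers the legacy flag queries without extra bookkeeping. Roughly (a sketch, not a test from this commit):

    from pytorch_lightning import Trainer

    trainer = Trainer()
    assert trainer.on_cpu          # _device_type defaults to DeviceType.CPU
    assert not trainer.on_gpu      # no GPU device type was ever set
    assert not trainer.use_ddp     # _distrib_type is still None
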
@@ -16,6 +16,7 @@ import importlib
import platform
from distutils.version import LooseVersion
from enum import Enum
+from typing import Union

import numpy
import torch

@@ -66,6 +67,62 @@ FLOAT32_EPSILON = numpy.finfo(numpy.float32).eps
FLOAT64_EPSILON = numpy.finfo(numpy.float64).eps


-class AMPType(Enum):
+class LightningEnum(str, Enum):
+    """ Type of any enumerator with allowed case-insensitive comparison to string. """
+
+    @classmethod
+    def from_str(cls, value: str) -> 'LightningEnum':
+        statuses = [status for status in dir(cls) if not status.startswith('_')]
+        for st in statuses:
+            if st.lower() == value.lower():
+                return getattr(cls, st)
+        return None
+
+    def __eq__(self, other: Union[str, Enum]) -> bool:
+        other = other.value if isinstance(other, Enum) else str(other)
+        return self.value.lower() == other.lower()
+
+
+class AMPType(LightningEnum):
+    """Type of Automatic Mixed Precision used for training.
+
+    >>> # you can match the type with a string
+    >>> AMPType.APEX == 'apex'
+    True
+    """
    APEX = 'apex'
    NATIVE = 'native'
+
+
+class DistributedType(LightningEnum):
+    """ Define the type of distributed computing.
+
+    >>> # you can match the type with a string
+    >>> DistributedType.DDP == 'ddp'
+    True
+    >>> # which is case invariant
+    >>> DistributedType.DDP2 == 'DDP2'
+    True
+    """
+    DP = 'dp'
+    DDP = 'ddp'
+    DDP2 = 'ddp2'
+    DDP_SPAWN = 'ddp_spawn'
+    HOROVOD = 'horovod'
+
+
+class DeviceType(LightningEnum):
+    """ Define the device type by its nature - accelerators.
+
+    >>> DeviceType.CPU == DeviceType.from_str('cpu')
+    True
+    >>> # you can match the type with a string
+    >>> DeviceType.GPU == 'GPU'
+    True
+    >>> # which is case invariant
+    >>> DeviceType.TPU == 'tpu'
+    True
+    """
+    CPU = 'CPU'
+    GPU = 'GPU'
+    TPU = 'TPU'
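
With LightningEnum in place, backend and device flags can be matched against user-supplied strings in any casing. A short sketch of how calling code can consume the new enums; the backend variable is illustrative:

    from pytorch_lightning.utilities import DeviceType, DistributedType

    backend = 'DDP_Spawn'   # e.g. a user-supplied flag, arbitrary casing

    # case-insensitive equality via LightningEnum.__eq__
    assert DistributedType.DDP_SPAWN == backend

    # from_str() maps a string onto the matching member, or returns None
    assert DistributedType.from_str(backend) is DistributedType.DDP_SPAWN
    assert DeviceType.from_str('gpu') is DeviceType.GPU
    assert DistributedType.from_str('no-such-backend') is None
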
@@ -1332,15 +1332,17 @@ def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches):
        ),
    ],
)
+# Todo: mock nb Gpus so all these tests can run on any device
+# todo: think about simplification, that the expected will be just a list use_xxx which shall be true...
def test_trainer_config(trainer_kwargs, expected):
    trainer = Trainer(**trainer_kwargs)
-    assert trainer.use_dp is expected["use_dp"]
-    assert trainer.use_ddp is expected["use_ddp"]
-    assert trainer.use_ddp2 is expected["use_ddp2"]
-    assert trainer.num_gpus == expected["num_gpus"]
-    assert trainer.on_gpu is expected["on_gpu"]
-    assert trainer.use_single_gpu is expected["use_single_gpu"]
-    assert trainer.num_processes == expected["num_processes"]
+    assert trainer.use_dp is expected["use_dp"], 'for input: %s' % trainer_kwargs
+    assert trainer.use_ddp is expected["use_ddp"], 'for input: %s' % trainer_kwargs
+    assert trainer.use_ddp2 is expected["use_ddp2"], 'for input: %s' % trainer_kwargs
+    assert trainer.num_gpus == expected["num_gpus"], 'for input: %s' % trainer_kwargs
+    assert trainer.on_gpu is expected["on_gpu"], 'for input: %s' % trainer_kwargs
+    assert trainer.use_single_gpu is expected["use_single_gpu"], 'for input: %s' % trainer_kwargs
+    assert trainer.num_processes == expected["num_processes"], 'for input: %s' % trainer_kwargs


def test_trainer_subclassing():
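
The second TODO in the hunk above hints at where this is heading: once everything funnels through the two enums, a config test can compare _distrib_type and _device_type directly instead of a dict of independent booleans. A hedged sketch of that simplification (not part of this commit; the expected values follow the ddp_cpu connector branch shown earlier):

    import pytest

    from pytorch_lightning import Trainer
    from pytorch_lightning.utilities import DeviceType, DistributedType

    @pytest.mark.parametrize("trainer_kwargs,expected_device,expected_distrib", [
        # plain CPU run: device stays CPU, no distributed mode selected
        (dict(), DeviceType.CPU, None),
        # ddp_cpu: the connector sets use_ddp / on_cpu, i.e. DDP over CPU processes
        (dict(distributed_backend="ddp_cpu", num_processes=2), DeviceType.CPU, DistributedType.DDP),
    ])
    def test_trainer_config_enums(trainer_kwargs, expected_device, expected_distrib):
        trainer = Trainer(**trainer_kwargs)
        assert trainer._device_type == expected_device, 'for input: %s' % trainer_kwargs
        assert trainer._distrib_type == expected_distrib, 'for input: %s' % trainer_kwargs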