diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index 5a10f21d21..b5301dd686 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -256,14 +256,4 @@ class Accelerator(object):
         yield cm
 
 
-# TODO: allow user to compare with string even internaly we shall use these Enum to prevent typos...
-class BackendType(Enum):
-    DP = 'dp'
-    DDP = 'ddp'
-    DDP2 = 'ddp2'
-    DDP_SPAWN = 'ddp_spawn'
-    # decuple distrib and device
-    DDP_CPU = 'ddp_cpu'
-    HOROVOD = 'horovod'
-    # this is rather device
-    TPU = 'tpu'
+
diff --git a/pytorch_lightning/accelerators/accelerator_connector.py b/pytorch_lightning/accelerators/accelerator_connector.py
index 1436e37dbd..4d899da2b0 100644
--- a/pytorch_lightning/accelerators/accelerator_connector.py
+++ b/pytorch_lightning/accelerators/accelerator_connector.py
@@ -335,6 +335,7 @@ class AcceleratorConnector:
             self.trainer.use_ddp = True
             self.trainer.data_parallel_device_ids = None
             self.trainer.on_gpu = False
+            self.trainer.on_cpu = True
 
         elif self.trainer.distributed_backend == "horovod":
             self._set_horovod_backend()
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py
index 8dc993df7d..2802585981 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py
@@ -21,6 +21,12 @@ from pytorch_lightning.core.step_result import Result
 
 
 class LoggerStages(str, Enum):
+    """ Train/validation/test phase of the training loop.
+
+    >>> # you can match the type with a string
+    >>> LoggerStages.TRAIN == 'train'
+    True
+    """
     TRAIN = "train"
     VAL = "validation"
     TEST = "test"
@@ -35,7 +41,7 @@ class LoggerStages(str, Enum):
         raise RuntimeError(f"Invalid stage {stage_or_testing} of type {type(stage_or_testing)} given")
 
 
-class ResultStoreType(Enum):
+class ResultStoreType(str, Enum):
     INSIDE_BATCH_TRAIN_LOOP = "inside_batch_train_loop"
     OUTSIDE_BATCH_TRAIN_LOOP = "outside_batch_train_loop"
 
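
The `str` mixin that `ResultStoreType` gains here (and that `LoggerStages` already uses) is what makes the doctest above hold: a plain `Enum` member never compares equal to its raw string value, while a `(str, Enum)` member does. A small self-contained illustration, not part of this diff:

```python
from enum import Enum


class Plain(Enum):
    TRAIN = "train"


class Mixed(str, Enum):
    TRAIN = "train"


assert (Plain.TRAIN == "train") is False  # plain Enum members only equal other members
assert (Mixed.TRAIN == "train") is True   # the str mixin makes members behave as strings
assert Mixed.TRAIN.value == Plain.TRAIN.value == "train"
```
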
diff --git a/pytorch_lightning/trainer/deprecated_api.py b/pytorch_lightning/trainer/deprecated_api.py
new file mode 100644
index 0000000000..b7e9fe1526
--- /dev/null
+++ b/pytorch_lightning/trainer/deprecated_api.py
@@ -0,0 +1,135 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pytorch_lightning.utilities import rank_zero_warn, DistributedType, DeviceType
+
+
+class DeprecatedDistDeviceAttributes:
+
+    _distrib_type: DistributedType
+    _device_type: DeviceType
+    num_gpus: int
+
+    @property
+    def on_cpu(self) -> bool:
+        # rank_zero_warn("Internal: `on_cpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
+        return self._device_type and self._device_type == DeviceType.CPU
+
+    @on_cpu.setter
+    def on_cpu(self, val: bool) -> None:
+        # rank_zero_warn("Internal: `on_cpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
+        if val:
+            self._device_type = DeviceType.CPU
+
+    @property
+    def on_tpu(self) -> bool:
+        # rank_zero_warn("Internal: `on_tpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
+        return self._device_type and self._device_type == DeviceType.TPU
+
+    @on_tpu.setter
+    def on_tpu(self, val: bool) -> None:
+        # rank_zero_warn("Internal: `on_tpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
+        # todo add logic that it cannot be set if TPU is missing
+        if val:
+            self._device_type = DeviceType.TPU
+
+    @property
+    def use_tpu(self) -> bool:
+        # rank_zero_warn("Internal: `use_tpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
+        return self._device_type and self._device_type == DeviceType.TPU
+
+    @use_tpu.setter
+    def use_tpu(self, val: bool) -> None:
+        # rank_zero_warn("Internal: `use_tpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
+        # todo add logic that it cannot be set if TPU is missing
+        if val:
+            self._device_type = DeviceType.TPU
+
+    @property
+    def on_gpu(self) -> bool:
+        # rank_zero_warn("Internal: `on_gpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
+        return self._device_type and self._device_type == DeviceType.GPU
+
+    @on_gpu.setter
+    def on_gpu(self, val: bool) -> None:
+        # rank_zero_warn("Internal: `on_gpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
+        # todo add logic that it cannot be set if GPU is missing
+        if val:
+            self._device_type = DeviceType.GPU
+
+    @property
+    def use_dp(self) -> bool:
+        # rank_zero_warn("Internal: `use_dp` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
+        return self._device_type and self._distrib_type == DistributedType.DP
+
+    @use_dp.setter
+    def use_dp(self, val: bool) -> None:
+        # rank_zero_warn("Internal: `use_dp` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
+        if val:
+            self._distrib_type = DistributedType.DP
+
+    @property
+    def use_ddp(self) -> bool:
+        # rank_zero_warn("Internal: `use_ddp` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
+        return self._device_type and self._distrib_type == DistributedType.DDP
+
+    @use_ddp.setter
+    def use_ddp(self, val: bool) -> None:
+        # rank_zero_warn("Internal: `use_ddp` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
+        if val:
+            self._distrib_type = DistributedType.DDP
+
+    @property
+    def use_ddp2(self) -> bool:
+        # rank_zero_warn("Internal: `use_ddp2` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
+        return self._device_type and self._distrib_type == DistributedType.DDP2
+
+    @use_ddp2.setter
+    def use_ddp2(self, val: bool) -> None:
+        # rank_zero_warn("Internal: `use_ddp2` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
+        if val:
+            self._distrib_type = DistributedType.DDP2
+
+    @property
+    def use_horovod(self) -> bool:
+        # rank_zero_warn(
+        #     "Internal: `use_horovod` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning
+        # )
+        return self._device_type and self._distrib_type == DistributedType.HOROVOD
+
+    @use_horovod.setter
+    def use_horovod(self, val: bool) -> None:
+        # rank_zero_warn(
+        #     "Internal: `use_horovod` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning
+        # )
+        if val:
+            self._distrib_type = DistributedType.HOROVOD
+
+    @property
+    def use_single_gpu(self) -> bool:
+        # rank_zero_warn(
+        #     "Internal: `use_single_gpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning,
+        # )
+        # todo, limiting to exclude DDP2 is not clear but it comes from connectors...
+        return (self._device_type and self._device_type == DeviceType.GPU
+                and self.num_gpus == 1
+                and self._distrib_type not in (DistributedType.DDP2, ))
+
+    @use_single_gpu.setter
+    def use_single_gpu(self, val: bool) -> None:
+        # rank_zero_warn(
+        #     "Internal: `use_single_gpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning,
+        # )
+        if val:
+            self._device_type = DeviceType.GPU
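
A minimal sketch of how these legacy boolean flags now round-trip through the new `_distrib_type` / `_device_type` fields (not part of the diff, assumes a build with this patch applied; the `_Flags` class is hypothetical scaffolding standing in for the `Trainer`):

```python
from pytorch_lightning.trainer.deprecated_api import DeprecatedDistDeviceAttributes
from pytorch_lightning.utilities import DeviceType, DistributedType


class _Flags(DeprecatedDistDeviceAttributes):
    """Hypothetical stand-in for the Trainer, holding only the new internal state."""

    def __init__(self):
        self._device_type = DeviceType.CPU
        self._distrib_type = None
        self.num_gpus = 0


flags = _Flags()
flags.use_ddp = True                               # legacy setter...
assert flags._distrib_type == DistributedType.DDP  # ...now writes the enum
assert flags.use_ddp                               # legacy getter is derived from the enum
assert not flags.use_ddp2                          # other legacy flags stay False
```
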
diff --git a/pytorch_lightning/trainer/states.py b/pytorch_lightning/trainer/states.py
index c99c6b3644..6557e6f870 100644
--- a/pytorch_lightning/trainer/states.py
+++ b/pytorch_lightning/trainer/states.py
@@ -19,9 +19,17 @@ from typing import Callable, Optional
 import pytorch_lightning
 
 
-class TrainerState(Enum):
+class TrainerState(str, Enum):
     """ State which is set in the :class:`~pytorch_lightning.trainer.trainer.Trainer`
-    to indicate what is currently or was executed. """
+    to indicate what is currently or was executed.
+
+    >>> # you can match the type with a string
+    >>> TrainerState.RUNNING == 'RUNNING'
+    True
+    >>> # which is case sensitive
+    >>> TrainerState.FINISHED == 'finished'
+    False
+    """
     INITIALIZING = 'INITIALIZING'
     RUNNING = 'RUNNING'
     FINISHED = 'FINISHED'
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index ccb9f9418c..31a64d00cc 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -24,11 +24,10 @@ from torch.utils.data import DataLoader
 from pytorch_lightning import _logger as log
 from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.accelerators.accelerator_connector import AcceleratorConnector
-from pytorch_lightning.accelerators.cpu_accelerator import CPUAccelerator
-from pytorch_lightning.callbacks import Callback, ModelCheckpoint
+from pytorch_lightning.trainer.deprecated_api import DeprecatedDistDeviceAttributes
+from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.core.datamodule import LightningDataModule
 from pytorch_lightning.core.lightning import LightningModule
-from pytorch_lightning.core.memory import ModelSummary
 from pytorch_lightning.core.step_result import EvalResult, Result
 from pytorch_lightning.loggers import LightningLoggerBase
 from pytorch_lightning.plugins.plugin_connector import PluginConnector
@@ -53,11 +52,11 @@ from pytorch_lightning.trainer.logging import TrainerLoggingMixin
 from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
 from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
 from pytorch_lightning.trainer.properties import TrainerProperties
-from pytorch_lightning.trainer.states import TrainerState, trainer_state
+from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.trainer.training_loop import TrainLoop
 from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
 from pytorch_lightning.tuner.tuning import Tuner
-from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities import rank_zero_warn, DeviceType
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.debugging import InternalDebugger
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -78,6 +77,7 @@ class Trainer(
     TrainerLoggingMixin,
     TrainerTrainingTricksMixin,
     TrainerDataLoadingMixin,
+    DeprecatedDistDeviceAttributes,
 ):
     @overwrite_by_env_vars
     def __init__(
@@ -284,6 +284,8 @@ class Trainer(
             handle AMP, TPU, accumulated_gradients, etc..
         """
         super().__init__()
+        self._device_type = DeviceType.CPU
+        self._distrib_type = None
 
         # init connectors
         self.dev_debugger = InternalDebugger(self)
diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py
index 90d2ca0acc..a11862b400 100644
--- a/pytorch_lightning/utilities/__init__.py
+++ b/pytorch_lightning/utilities/__init__.py
@@ -16,6 +16,7 @@ import importlib
 import platform
 from distutils.version import LooseVersion
 from enum import Enum
+from typing import Union
 
 import numpy
 import torch
@@ -66,6 +67,62 @@ FLOAT32_EPSILON = numpy.finfo(numpy.float32).eps
 FLOAT64_EPSILON = numpy.finfo(numpy.float64).eps
 
 
-class AMPType(Enum):
+class LightningEnum(str, Enum):
+    """ Base Enum type allowing case-insensitive comparison with strings. """
+
+    @classmethod
+    def from_str(cls, value: str) -> 'LightningEnum':
+        statuses = [status for status in dir(cls) if not status.startswith('_')]
+        for st in statuses:
+            if st.lower() == value.lower():
+                return getattr(cls, st)
+        return None
+
+    def __eq__(self, other: Union[str, Enum]) -> bool:
+        other = other.value if isinstance(other, Enum) else str(other)
+        return self.value.lower() == other.lower()
+
+
+class AMPType(LightningEnum):
+    """Type of Automatic Mixed Precision used for training.
+
+    >>> # you can match the type with a string
+    >>> AMPType.APEX == 'apex'
+    True
+    """
     APEX = 'apex'
     NATIVE = 'native'
+
+
+class DistributedType(LightningEnum):
+    """ Define the type of distributed computing.
+
+    >>> # you can match the type with a string
+    >>> DistributedType.DDP == 'ddp'
+    True
+    >>> # which is case-insensitive
+    >>> DistributedType.DDP2 == 'DDP2'
+    True
+    """
+    DP = 'dp'
+    DDP = 'ddp'
+    DDP2 = 'ddp2'
+    DDP_SPAWN = 'ddp_spawn'
+    HOROVOD = 'horovod'
+
+
+class DeviceType(LightningEnum):
+    """ Define the device type by its nature - accelerators.
+
+    >>> DeviceType.CPU == DeviceType.from_str('cpu')
+    True
+    >>> # you can match the type with a string
+    >>> DeviceType.GPU == 'GPU'
+    True
+    >>> # which is case-insensitive
+    >>> DeviceType.TPU == 'tpu'
+    True
+    """
+    CPU = 'CPU'
+    GPU = 'GPU'
+    TPU = 'TPU'
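
A quick illustration of the behaviour these enums provide (a minimal sketch, not part of the diff, assuming a build with this patch applied):

```python
from pytorch_lightning.utilities import AMPType, DeviceType, DistributedType

# equality against plain strings is case-insensitive thanks to LightningEnum.__eq__
assert DistributedType.DDP_SPAWN == "ddp_spawn"
assert DistributedType.DDP_SPAWN == "DDP_Spawn"

# from_str() looks members up by name, also ignoring case; unknown values yield None
assert DeviceType.from_str("gpu") is DeviceType.GPU
assert AMPType.from_str("unknown") is None
```
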
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index b7db30e398..427956fb89 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -1332,15 +1332,17 @@ def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches):
         ),
     ],
 )
+# TODO: mock the number of GPUs so all these tests can run on any device
+# TODO: think about simplification, so that `expected` is just a list of the use_xxx flags which shall be true...
 def test_trainer_config(trainer_kwargs, expected):
     trainer = Trainer(**trainer_kwargs)
-    assert trainer.use_dp is expected["use_dp"]
-    assert trainer.use_ddp is expected["use_ddp"]
-    assert trainer.use_ddp2 is expected["use_ddp2"]
-    assert trainer.num_gpus == expected["num_gpus"]
-    assert trainer.on_gpu is expected["on_gpu"]
-    assert trainer.use_single_gpu is expected["use_single_gpu"]
-    assert trainer.num_processes == expected["num_processes"]
+    assert trainer.use_dp is expected["use_dp"], 'for input: %s' % trainer_kwargs
+    assert trainer.use_ddp is expected["use_ddp"], 'for input: %s' % trainer_kwargs
+    assert trainer.use_ddp2 is expected["use_ddp2"], 'for input: %s' % trainer_kwargs
+    assert trainer.num_gpus == expected["num_gpus"], 'for input: %s' % trainer_kwargs
+    assert trainer.on_gpu is expected["on_gpu"], 'for input: %s' % trainer_kwargs
+    assert trainer.use_single_gpu is expected["use_single_gpu"], 'for input: %s' % trainer_kwargs
+    assert trainer.num_processes == expected["num_processes"], 'for input: %s' % trainer_kwargs
 
 
 def test_trainer_subclassing():