From ae9956f997cb8e66d8c28951d47c7106fca54c1b Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 29 Dec 2020 19:02:18 +0100 Subject: [PATCH] clean avail. imports & enums (#5256) * clean avail. imports * enums * fix missing --- pytorch_lightning/utilities/__init__.py | 123 ++++------------------ pytorch_lightning/utilities/apply_func.py | 10 +- pytorch_lightning/utilities/enums.py | 77 ++++++++++++++ pytorch_lightning/utilities/imports.py | 55 ++++++++++ pytorch_lightning/utilities/xla_device.py | 3 +- 5 files changed, 160 insertions(+), 108 deletions(-) create mode 100644 pytorch_lightning/utilities/enums.py create mode 100644 pytorch_lightning/utilities/imports.py diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index 4be6195007..adcf996c66 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -12,14 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """General utilities""" -import importlib -import platform -from distutils.version import LooseVersion -from enum import Enum -from typing import Union import numpy -import torch from pytorch_lightning.utilities.apply_func import move_data_to_device # noqa: F401 from pytorch_lightning.utilities.distributed import ( # noqa: F401 @@ -28,106 +22,33 @@ from pytorch_lightning.utilities.distributed import ( # noqa: F401 rank_zero_only, rank_zero_warn, ) +from pytorch_lightning.utilities.enums import ( # noqa: F401 + LightningEnum, + AMPType, + DistributedType, + DeviceType, +) +from pytorch_lightning.utilities.imports import ( # noqa: F401 + _APEX_AVAILABLE, + _NATIVE_AMP_AVAILABLE, + _XLA_AVAILABLE, + _OMEGACONF_AVAILABLE, + _HYDRA_AVAILABLE, + _HOROVOD_AVAILABLE, + _TORCHTEXT_AVAILABLE, + _FAIRSCALE_AVAILABLE, + _RPC_AVAILABLE, + _GROUP_AVAILABLE, + _FAIRSCALE_PIPE_AVAILABLE, + _BOLTS_AVAILABLE, + _module_available, +) from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable # noqa: F401 -from pytorch_lightning.utilities.xla_device import _XLA_AVAILABLE, XLADeviceUtils # noqa: F401 +from pytorch_lightning.utilities.xla_device import XLADeviceUtils # noqa: F401 -def _module_available(module_path: str) -> bool: - """Testing if given module is avalaible in your env - - >>> _module_available('os') - True - >>> _module_available('bla.bla') - False - """ - # todo: find a better way than try / except - try: - mods = module_path.split('.') - assert mods, 'nothing given to test' - # it has to be tested as per partets - for i in range(len(mods)): - module_path = '.'.join(mods[:i + 1]) - if importlib.util.find_spec(module_path) is None: - return False - return True - except AttributeError: - return False - - -_APEX_AVAILABLE = _module_available("apex.amp") -_NATIVE_AMP_AVAILABLE = _module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast") -_OMEGACONF_AVAILABLE = _module_available("omegaconf") -_HYDRA_AVAILABLE = _module_available("hydra") -_HOROVOD_AVAILABLE = _module_available("horovod.torch") - _TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists() -_FAIRSCALE_AVAILABLE = platform.system() != 'Windows' and _module_available('fairscale.nn.data_parallel') -_RPC_AVAILABLE = platform.system() != 'Windows' and _module_available('torch.distributed.rpc') -_GROUP_AVAILABLE = platform.system() != 'Windows' and _module_available('torch.distributed.group') -_FAIRSCALE_PIPE_AVAILABLE = _FAIRSCALE_AVAILABLE and LooseVersion(torch.__version__) >= LooseVersion("1.6.0") -_BOLTS_AVAILABLE = _module_available('pl_bolts') FLOAT16_EPSILON = numpy.finfo(numpy.float16).eps FLOAT32_EPSILON = numpy.finfo(numpy.float32).eps FLOAT64_EPSILON = numpy.finfo(numpy.float64).eps - - -class LightningEnum(str, Enum): - """ Type of any enumerator with allowed comparison to string invariant to cases. """ - - @classmethod - def from_str(cls, value: str) -> 'LightningEnum': - statuses = [status for status in dir(cls) if not status.startswith('_')] - for st in statuses: - if st.lower() == value.lower(): - return getattr(cls, st) - return None - - def __eq__(self, other: Union[str, Enum]) -> bool: - other = other.value if isinstance(other, Enum) else str(other) - return self.value.lower() == other.lower() - - -class AMPType(LightningEnum): - """Type of Automatic Mixed Precission used for training. - - >>> # you can math the type with string - >>> AMPType.APEX == 'apex' - True - """ - APEX = 'apex' - NATIVE = 'native' - - -class DistributedType(LightningEnum): - """ Define type of ditributed computing. - - >>> # you can math the type with string - >>> DistributedType.DDP == 'ddp' - True - >>> # which is case invariant - >>> DistributedType.DDP2 == 'DDP2' - True - """ - DP = 'dp' - DDP = 'ddp' - DDP2 = 'ddp2' - DDP_SPAWN = 'ddp_spawn' - HOROVOD = 'horovod' - - -class DeviceType(LightningEnum): - """ Define Device type byt its nature - acceleatrors. - - >>> DeviceType.CPU == DeviceType.from_str('cpu') - True - >>> # you can math the type with string - >>> DeviceType.GPU == 'GPU' - True - >>> # which is case invariant - >>> DeviceType.TPU == 'tpu' - True - """ - CPU = 'CPU' - GPU = 'GPU' - TPU = 'TPU' diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index 775c22dbbf..afcc4dd36f 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import importlib from abc import ABC from collections.abc import Mapping, Sequence from copy import copy @@ -20,8 +19,9 @@ from typing import Any, Callable, Union import torch -TORCHTEXT_AVAILABLE = importlib.util.find_spec("torchtext") is not None -if TORCHTEXT_AVAILABLE: +from pytorch_lightning.utilities.imports import _TORCHTEXT_AVAILABLE + +if _TORCHTEXT_AVAILABLE: from torchtext.data import Batch else: Batch = type(None) @@ -109,7 +109,7 @@ def move_data_to_device(batch: Any, device: torch.device): def batch_to(data): # try to move torchtext data first - if TORCHTEXT_AVAILABLE and isinstance(data, Batch): + if _TORCHTEXT_AVAILABLE and isinstance(data, Batch): # Shallow copy because each Batch has a reference to Dataset which contains all examples device_data = copy(data) @@ -123,5 +123,5 @@ def move_data_to_device(batch: Any, device: torch.device): kwargs = dict(non_blocking=True) if isinstance(data, torch.Tensor) else {} return data.to(device, **kwargs) - dtype = (TransferableDataType, Batch) if TORCHTEXT_AVAILABLE else TransferableDataType + dtype = (TransferableDataType, Batch) if _TORCHTEXT_AVAILABLE else TransferableDataType return apply_to_collection(batch, dtype=dtype, function=batch_to) diff --git a/pytorch_lightning/utilities/enums.py b/pytorch_lightning/utilities/enums.py new file mode 100644 index 0000000000..bbcb83b6ee --- /dev/null +++ b/pytorch_lightning/utilities/enums.py @@ -0,0 +1,77 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Enumerated utilities""" +from enum import Enum +from typing import Union + + +class LightningEnum(str, Enum): + """ Type of any enumerator with allowed comparison to string invariant to cases. """ + + @classmethod + def from_str(cls, value: str) -> 'LightningEnum': + statuses = [status for status in dir(cls) if not status.startswith('_')] + for st in statuses: + if st.lower() == value.lower(): + return getattr(cls, st) + return None + + def __eq__(self, other: Union[str, Enum]) -> bool: + other = other.value if isinstance(other, Enum) else str(other) + return self.value.lower() == other.lower() + + +class AMPType(LightningEnum): + """Type of Automatic Mixed Precission used for training. + + >>> # you can math the type with string + >>> AMPType.APEX == 'apex' + True + """ + APEX = 'apex' + NATIVE = 'native' + + +class DistributedType(LightningEnum): + """ Define type of ditributed computing. + + >>> # you can math the type with string + >>> DistributedType.DDP == 'ddp' + True + >>> # which is case invariant + >>> DistributedType.DDP2 == 'DDP2' + True + """ + DP = 'dp' + DDP = 'ddp' + DDP2 = 'ddp2' + DDP_SPAWN = 'ddp_spawn' + HOROVOD = 'horovod' + + +class DeviceType(LightningEnum): + """ Define Device type byt its nature - acceleatrors. + + >>> DeviceType.CPU == DeviceType.from_str('cpu') + True + >>> # you can math the type with string + >>> DeviceType.GPU == 'GPU' + True + >>> # which is case invariant + >>> DeviceType.TPU == 'tpu' + True + """ + CPU = 'CPU' + GPU = 'GPU' + TPU = 'TPU' diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py new file mode 100644 index 0000000000..acdebfbf23 --- /dev/null +++ b/pytorch_lightning/utilities/imports.py @@ -0,0 +1,55 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""General utilities""" +import importlib +import platform +from distutils.version import LooseVersion + +import torch + + +def _module_available(module_path: str) -> bool: + """Testing if given module is avalaible in your env + + >>> _module_available('os') + True + >>> _module_available('bla.bla') + False + """ + # todo: find a better way than try / except + try: + mods = module_path.split('.') + assert mods, 'nothing given to test' + # it has to be tested as per partets + for i in range(len(mods)): + module_path = '.'.join(mods[:i + 1]) + if importlib.util.find_spec(module_path) is None: + return False + return True + except AttributeError: + return False + + +_APEX_AVAILABLE = _module_available("apex.amp") +_NATIVE_AMP_AVAILABLE = _module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast") +_OMEGACONF_AVAILABLE = _module_available("omegaconf") +_HYDRA_AVAILABLE = _module_available("hydra") +_HOROVOD_AVAILABLE = _module_available("horovod.torch") +_TORCHTEXT_AVAILABLE = _module_available("torchtext") +_XLA_AVAILABLE = _module_available("torch_xla") +_FAIRSCALE_AVAILABLE = platform.system() != 'Windows' and _module_available('fairscale.nn.data_parallel') +_RPC_AVAILABLE = platform.system() != 'Windows' and _module_available('torch.distributed.rpc') +_GROUP_AVAILABLE = platform.system() != 'Windows' and _module_available('torch.distributed.group') +_FAIRSCALE_PIPE_AVAILABLE = _FAIRSCALE_AVAILABLE and LooseVersion(torch.__version__) >= LooseVersion("1.6.0") +_BOLTS_AVAILABLE = _module_available('pl_bolts') diff --git a/pytorch_lightning/utilities/xla_device.py b/pytorch_lightning/utilities/xla_device.py index d7702aef33..f220ea6414 100644 --- a/pytorch_lightning/utilities/xla_device.py +++ b/pytorch_lightning/utilities/xla_device.py @@ -12,14 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. import functools -import importlib import queue as q import traceback from multiprocessing import Process, Queue import torch -_XLA_AVAILABLE = importlib.util.find_spec("torch_xla") is not None +from pytorch_lightning.utilities.imports import _XLA_AVAILABLE if _XLA_AVAILABLE: import torch_xla.core.xla_model as xm