From ecd3678a1fdf234ac47dba3b56784a3c7467d83d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Thu, 11 Feb 2021 00:32:47 +0100 Subject: [PATCH] Refactor utilities/imports.py (#5874) Co-authored-by: Jirka Borovec Co-authored-by: Nicki Skafte --- pytorch_lightning/utilities/imports.py | 49 ++++++++++++++------------ 1 file changed, 26 insertions(+), 23 deletions(-) diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index eb9b924c70..aa86d560b6 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -12,50 +12,53 @@ # See the License for the specific language governing permissions and # limitations under the License. """General utilities""" -import importlib import platform from distutils.version import LooseVersion +from importlib.util import find_spec import pkg_resources import torch def _module_available(module_path: str) -> bool: - """Testing if given module is avalaible in your env + """ + Check if a path is available in your environment >>> _module_available('os') True >>> _module_available('bla.bla') False """ - # todo: find a better way than try / except try: - mods = module_path.split('.') - assert mods, 'nothing given to test' - # it has to be tested as per partets - for i in range(len(mods)): - module_path = '.'.join(mods[:i + 1]) - if importlib.util.find_spec(module_path) is None: - return False - return True + return find_spec(module_path) is not None except AttributeError: + # Python 3.6 + return False + except ModuleNotFoundError: + # Python 3.7+ return False +def _get_version(package: str) -> LooseVersion: + return LooseVersion(pkg_resources.get_distribution(package).version) + + +_IS_WINDOWS = platform.system() == "Windows" +_TORCH_GREATER_EQUAL_1_6 = _get_version("torch") >= LooseVersion("1.6.0") + _APEX_AVAILABLE = _module_available("apex.amp") -_NATIVE_AMP_AVAILABLE = _module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast") -_OMEGACONF_AVAILABLE = _module_available("omegaconf") +_BOLTS_AVAILABLE = _module_available('pl_bolts') +_FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _module_available('fairscale.nn.data_parallel') +_FAIRSCALE_PIPE_AVAILABLE = ( + _FAIRSCALE_AVAILABLE and _TORCH_GREATER_EQUAL_1_6 and _get_version('fairscale') <= LooseVersion("0.1.3") +) +_GROUP_AVAILABLE = not _IS_WINDOWS and _module_available('torch.distributed.group') +_HOROVOD_AVAILABLE = _module_available("horovod.torch") _HYDRA_AVAILABLE = _module_available("hydra") _HYDRA_EXPERIMENTAL_AVAILABLE = _module_available("hydra.experimental") -_HOROVOD_AVAILABLE = _module_available("horovod.torch") +_NATIVE_AMP_AVAILABLE = _module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast") +_OMEGACONF_AVAILABLE = _module_available("omegaconf") +_RPC_AVAILABLE = not _IS_WINDOWS and _module_available('torch.distributed.rpc') _TORCHTEXT_AVAILABLE = _module_available("torchtext") -_XLA_AVAILABLE = _module_available("torch_xla") -_FAIRSCALE_AVAILABLE = platform.system() != 'Windows' and _module_available('fairscale.nn.data_parallel') -_RPC_AVAILABLE = platform.system() != 'Windows' and _module_available('torch.distributed.rpc') -_GROUP_AVAILABLE = platform.system() != 'Windows' and _module_available('torch.distributed.group') -_FAIRSCALE_PIPE_AVAILABLE = _FAIRSCALE_AVAILABLE and LooseVersion( - torch.__version__ -) >= LooseVersion("1.6.0") and LooseVersion(pkg_resources.get_distribution('fairscale').version - ) <= LooseVersion("0.1.3") -_BOLTS_AVAILABLE = _module_available('pl_bolts') _TORCHVISION_AVAILABLE = _module_available('torchvision') +_XLA_AVAILABLE = _module_available("torch_xla")