From 3976db597d6cb86d0f99367554adbc874354b486 Mon Sep 17 00:00:00 2001
From: Jirka Borovec
Date: Fri, 4 Dec 2020 10:26:10 +0100
Subject: [PATCH] refactor imports of optional dependencies (#4859)

* refactor imports of optional dependencies
* fix
* fix
* fix
* fix
* fix
* flake8
* flake8

Co-authored-by: Sean Naren
Co-authored-by: chaton
---
 .../accelerators/accelerator_connector.py    |  9 ++--
 .../accelerators/horovod_accelerator.py      |  8 +---
 pytorch_lightning/core/lightning.py          |  2 +-
 pytorch_lightning/trainer/data_loading.py    | 10 ++---
 pytorch_lightning/utilities/__init__.py      |  1 +
 pytorch_lightning/utilities/debugging.py     |  2 +-
 .../data/horovod/train_default_model.py      | 13 +++---
 tests/models/test_horovod.py                 | 44 +++++++------------
 8 files changed, 33 insertions(+), 56 deletions(-)

diff --git a/pytorch_lightning/accelerators/accelerator_connector.py b/pytorch_lightning/accelerators/accelerator_connector.py
index a22a8fb370..9d36f76876 100644
--- a/pytorch_lightning/accelerators/accelerator_connector.py
+++ b/pytorch_lightning/accelerators/accelerator_connector.py
@@ -15,21 +15,18 @@
 import os
 import torch
 
+from pytorch_lightning.utilities import HOROVOD_AVAILABLE
 from pytorch_lightning import _logger as log
 from pytorch_lightning import accelerators
 from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.cluster_environments.slurm_environment import SLURMEnvironment
 from pytorch_lightning.cluster_environments.torchelastic_environment import TorchElasticEnvironment
-from pytorch_lightning.utilities import XLA_AVAILABLE, device_parser, rank_zero_only, TPU_AVAILABLE
+from pytorch_lightning.utilities import device_parser, rank_zero_only, TPU_AVAILABLE
 from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
-try:
+if HOROVOD_AVAILABLE:
     import horovod.torch as hvd
-except (ModuleNotFoundError, ImportError):
-    HOROVOD_AVAILABLE = False
-else:
-    HOROVOD_AVAILABLE = True
 
 
 class AcceleratorConnector:
diff --git a/pytorch_lightning/accelerators/horovod_accelerator.py b/pytorch_lightning/accelerators/horovod_accelerator.py
index b2cec90617..460f5a83d2 100644
--- a/pytorch_lightning/accelerators/horovod_accelerator.py
+++ b/pytorch_lightning/accelerators/horovod_accelerator.py
@@ -18,15 +18,11 @@
 import torch
 from torch.optim.lr_scheduler import _LRScheduler
 from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
-from pytorch_lightning.utilities import AMPType
+from pytorch_lightning.utilities import AMPType, HOROVOD_AVAILABLE
 from pytorch_lightning.utilities.distributed import rank_zero_only
 
-try:
+if HOROVOD_AVAILABLE:
     import horovod.torch as hvd
-except (ModuleNotFoundError, ImportError):
-    HOROVOD_AVAILABLE = False
-else:
-    HOROVOD_AVAILABLE = True
 
 
 class HorovodAccelerator(Accelerator):
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 78fc740e38..8ea6f9d0c8 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -1670,7 +1670,7 @@ class LightningModule(
                 line = re.sub(r"\s+", "", line, flags=re.UNICODE)
                 if ".hparams=" in line:
                     return line.split("=")[1]
-        except Exception as e:
+        except Exception:
             return "hparams"
 
         return None
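The two accelerator diffs above apply the same refactor (the lightning.py hunk is an unrelated flake8 cleanup dropping an unused exception binding): a local try/except probe that imported horovod.torch and defined HOROVOD_AVAILABLE as a side effect is replaced by a plain guarded import against a single flag owned by pytorch_lightning.utilities. A minimal sketch of the before/after shape, using the hypothetical optional dependency `mylib` and the hypothetical flag module `central_flags`:

# Old shape: each consumer probes the import and defines its own flag.
try:
    import mylib
except (ModuleNotFoundError, ImportError):
    MYLIB_AVAILABLE = False
else:
    MYLIB_AVAILABLE = True

# New shape: the flag is computed once in a central module; consumers only
# guard the import. Note that `mylib` stays undefined when the flag is
# False, so every use site must itself be reachable only when the flag is True.
from central_flags import MYLIB_AVAILABLE  # hypothetical central module

if MYLIB_AVAILABLE:
    import mylib

The trade-off is visible in the hunks below: the flag can no longer drift between modules, but `hvd` becomes an undefined name on systems without Horovod, so call sites must stay behind HOROVOD_AVAILABLE checks.
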
diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py
index a15e9bba2a..4a7b14d0b1 100644
--- a/pytorch_lightning/trainer/data_loading.py
+++ b/pytorch_lightning/trainer/data_loading.py
@@ -16,14 +16,14 @@
 import multiprocessing
 import platform
 from abc import ABC
 from copy import deepcopy
-from typing import Callable, Iterable, List, Optional, Tuple, Union
+from typing import Union, List, Tuple, Callable, Optional, Iterable
 
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
 from torch.utils.data.distributed import DistributedSampler
 from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.core import LightningModule
-from pytorch_lightning.utilities import TPU_AVAILABLE, rank_zero_warn
+from pytorch_lightning.utilities import rank_zero_warn, TPU_AVAILABLE, HOROVOD_AVAILABLE
 from pytorch_lightning.utilities.data import has_iterable_dataset, has_len
 from pytorch_lightning.utilities.debugging import InternalDebugger
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -32,12 +32,8 @@ from pytorch_lightning.utilities.model_utils import is_overridden
 if TPU_AVAILABLE:
     import torch_xla.core.xla_model as xm
 
-try:
+if HOROVOD_AVAILABLE:
     import horovod.torch as hvd
-except (ModuleNotFoundError, ImportError):
-    HOROVOD_AVAILABLE = False
-else:
-    HOROVOD_AVAILABLE = True
 
 
 class TrainerDataLoadingMixin(ABC):
diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py
index 916e434e5f..1e2eeea9f4 100644
--- a/pytorch_lightning/utilities/__init__.py
+++ b/pytorch_lightning/utilities/__init__.py
@@ -47,6 +47,7 @@
 APEX_AVAILABLE = _module_available("apex.amp")
 NATIVE_AMP_AVAILABLE = _module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast")
 OMEGACONF_AVAILABLE = _module_available("omegaconf")
 HYDRA_AVAILABLE = _module_available("hydra")
+HOROVOD_AVAILABLE = _module_available("horovod.torch")
 TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
 FAIRSCALE_AVAILABLE = platform.system() != 'Windows' and _module_available('fairscale.nn.data_parallel')
diff --git a/pytorch_lightning/utilities/debugging.py b/pytorch_lightning/utilities/debugging.py
index f7b9e79b7f..9264e2a498 100644
--- a/pytorch_lightning/utilities/debugging.py
+++ b/pytorch_lightning/utilities/debugging.py
@@ -87,7 +87,7 @@ class InternalDebugger(object):
         for dl in dataloaders:
             try:
                 length = len(dl)
-            except Exception as e:
+            except Exception:
                 length = -1
             lengths.append(length)
 
diff --git a/tests/models/data/horovod/train_default_model.py b/tests/models/data/horovod/train_default_model.py
index 5b31c67817..94daaedb4f 100644
--- a/tests/models/data/horovod/train_default_model.py
+++ b/tests/models/data/horovod/train_default_model.py
@@ -21,18 +21,19 @@
 import json
 import os
 import sys
-
-try:
-    import horovod.torch as hvd
-except ImportError:
-    print('You requested to import Horovod which is missing or not supported for your OS.')
-
 PATH_HERE = os.path.abspath(os.path.dirname(__file__))
 PATH_ROOT = os.path.abspath(os.path.join(PATH_HERE, '..', '..', '..', '..'))
 sys.path.insert(0, os.path.abspath(PATH_ROOT))
 
 from pytorch_lightning import Trainer  # noqa: E402
 from pytorch_lightning.callbacks import ModelCheckpoint  # noqa: E402
+from pytorch_lightning.utilities import HOROVOD_AVAILABLE  # noqa: E402
+
+if HOROVOD_AVAILABLE:
+    import horovod.torch as hvd  # noqa: E402
+else:
+    print('You requested to import Horovod which is missing or not supported for your OS.')
+
 
 # Move project root to the front of the search path, as some imports may have reordered things
 idx = sys.path.index(PATH_ROOT)
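HOROVOD_AVAILABLE now comes from `_module_available("horovod.torch")`, the same helper already used for apex, omegaconf and hydra a few lines earlier in utilities/__init__.py. The helper's implementation is not part of this patch; a rough sketch of such a check via importlib, which may differ from the real helper in pytorch_lightning.utilities:

from importlib.util import find_spec


def _module_available(module_path: str) -> bool:
    """Check whether a dotted module path, e.g. "horovod.torch", resolves."""
    try:
        # find_spec() returns None when the final module is missing and
        # raises ModuleNotFoundError when a parent package (here "horovod")
        # is not installed; parent packages are imported to resolve dotted names.
        return find_spec(module_path) is not None
    except ModuleNotFoundError:
        return False

A check along these lines lets the flag be computed at import time without executing horovod.torch itself, so a broken or missing installation cannot crash `import pytorch_lightning`.
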
diff --git a/tests/models/test_horovod.py b/tests/models/test_horovod.py
index 0fc68a226e..1a38b12d37 100644
--- a/tests/models/test_horovod.py
+++ b/tests/models/test_horovod.py
@@ -29,33 +29,25 @@ from pytorch_lightning import Trainer
 from pytorch_lightning.accelerators.horovod_accelerator import HorovodAccelerator
 from pytorch_lightning.core.step_result import EvalResult, Result, TrainResult
 from pytorch_lightning.metrics.classification.accuracy import Accuracy
-from pytorch_lightning.utilities import APEX_AVAILABLE, NATIVE_AMP_AVAILABLE
+from pytorch_lightning.utilities import APEX_AVAILABLE, NATIVE_AMP_AVAILABLE, HOROVOD_AVAILABLE, _module_available
 from tests.base import EvalModelTemplate
+from tests.base.boring_model import BoringModel
 from tests.base.models import BasicGAN
 
-try:
+if HOROVOD_AVAILABLE:
     import horovod
-    from horovod.common.util import nccl_built
-except ImportError:
-    HOROVOD_AVAILABLE = False
-else:
-    HOROVOD_AVAILABLE = True
-
+    import horovod.torch as hvd
 
 # This script will run the actual test model training in parallel
 TEST_SCRIPT = os.path.join(os.path.dirname(__file__), 'data', 'horovod', 'train_default_model.py')
 
-
-def _nccl_available():
-    if not HOROVOD_AVAILABLE:
-        return False
-
-    try:
-        return nccl_built()
-    except AttributeError:
-        # Horovod 0.19.1 nccl_built() does not yet work with Python 3.8:
-        # See: https://github.com/horovod/horovod/issues/1891
-        return False
+try:
+    from horovod.common.util import nccl_built
+    nccl_built()
+except (ImportError, ModuleNotFoundError, AttributeError):
+    HOROVOD_NCCL_AVAILABLE = False
+else:
+    HOROVOD_NCCL_AVAILABLE = True
 
 
 def _run_horovod(trainer_options, on_gpu=False):
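A subtlety in the NCCL probe above: the success flag must be assigned in an `else:` branch, which runs only when the `try` body completed without raising. A `finally:` branch would run on the exception path as well, overwriting the `False` just set by `except` and reporting NCCL as available everywhere. Calling `nccl_built()` inside the `try` also keeps the old AttributeError workaround for Horovod 0.19.1 under Python 3.8 intact. A self-contained demonstration of the difference between the two clauses, with the hypothetical name FLAG:

# finally: runs unconditionally, clobbering the value set by except.
try:
    raise ImportError("pretend horovod is missing")
except ImportError:
    FLAG = False
finally:
    FLAG = True
assert FLAG is True  # wrong: the failure was silently masked

# else: runs only when the try body raised nothing.
try:
    raise ImportError("pretend horovod is missing")
except ImportError:
    FLAG = False
else:
    FLAG = True
assert FLAG is False  # correct: the failure is recorded
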
@@ -114,7 +106,7 @@ def test_horovod_cpu_implicit(enable_pl_optimizer, tmpdir):
 
 
 @pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
-@pytest.mark.skipif(not _nccl_available(), reason="test requires Horovod with NCCL support")
+@pytest.mark.skipif(not HOROVOD_NCCL_AVAILABLE, reason="test requires Horovod with NCCL support")
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 def test_horovod_multi_gpu(tmpdir):
     """Test Horovod with multi-GPU support."""
@@ -134,7 +126,7 @@ def test_horovod_multi_gpu(tmpdir):
 
 
 @pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
-@pytest.mark.skipif(not _nccl_available(), reason="test requires Horovod with NCCL support")
+@pytest.mark.skipif(not HOROVOD_NCCL_AVAILABLE, reason="test requires Horovod with NCCL support")
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 @pytest.mark.skipif(not APEX_AVAILABLE, reason="test requires apex")
 def test_horovod_apex(tmpdir):
@@ -158,7 +150,7 @@ def test_horovod_apex(tmpdir):
 
 @pytest.mark.skip(reason="Skip till Horovod fixes integration with Native torch.cuda.amp")
 @pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
-@pytest.mark.skipif(not _nccl_available(), reason="test requires Horovod with NCCL support")
+@pytest.mark.skipif(not HOROVOD_NCCL_AVAILABLE, reason="test requires Horovod with NCCL support")
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 @pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="test requires torch.cuda.amp")
 def test_horovod_amp(tmpdir):
@@ -181,7 +173,7 @@ def test_horovod_amp(tmpdir):
 
 
 @pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
-@pytest.mark.skipif(not _nccl_available(), reason="test requires Horovod with NCCL support")
+@pytest.mark.skipif(not HOROVOD_NCCL_AVAILABLE, reason="test requires Horovod with NCCL support")
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
 def test_horovod_transfer_batch_to_gpu(tmpdir):
     class TestTrainingStepModel(EvalModelTemplate):
@@ -263,10 +255,6 @@ def test_result_reduce_horovod(enable_pl_optimizer, tmpdir):
     path_root = os.path.abspath(os.path.join(path_here, '..', '..'))
     sys.path.insert(0, os.path.abspath(path_root))
 
-    import horovod.torch as hvd
-
-    from tests.base.boring_model import BoringModel
-
     class TestModel(BoringModel):
         def training_step(self, batch, batch_idx):
             self.training_step_called = True
@@ -318,8 +306,6 @@ def test_accuracy_metric_horovod():
     target = torch.randint(high=2, size=(num_batches, batch_size))
 
     def _compute_batch():
-        import horovod.torch as hvd
-
         trainer = Trainer(
             fast_dev_run=True,
             distributed_backend='horovod',
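With the import hoisted to module scope behind HOROVOD_AVAILABLE, the test functions above no longer need their local `import horovod.torch as hvd` statements; they reference the module-level `hvd` and rely on skip markers to keep Horovod-only code unreachable when the library is absent. A minimal sketch of the resulting test shape; the test name and body here are illustrative, not taken from this patch:

import platform

import pytest
import torch

from pytorch_lightning.utilities import HOROVOD_AVAILABLE

if HOROVOD_AVAILABLE:
    import horovod.torch as hvd


@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
@pytest.mark.skipif(not HOROVOD_AVAILABLE, reason="test requires Horovod")
def test_horovod_allreduce_smoke():
    # The skipif markers guarantee this body only runs when the guarded
    # module-level import actually executed, so `hvd` is defined here.
    hvd.init()
    averaged = hvd.allreduce(torch.tensor([1.0]))  # averages across workers
    assert averaged.item() == 1.0  # a single local worker averages to itself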