From 810643bca2cf74a83d51e3cc02ef97cd1656ee5e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Tue, 20 Sep 2022 12:19:51 +0200 Subject: [PATCH] Surface Neptune installation problems to the user (#14715) --- src/pytorch_lightning/loggers/neptune.py | 7 ++----- tests/tests_pytorch/loggers/test_all.py | 5 ++++- tests/tests_pytorch/loggers/test_neptune.py | 5 +++++ 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/pytorch_lightning/loggers/neptune.py b/src/pytorch_lightning/loggers/neptune.py index a5df22d8be..6c4c78d37d 100644 --- a/src/pytorch_lightning/loggers/neptune.py +++ b/src/pytorch_lightning/loggers/neptune.py @@ -231,11 +231,8 @@ class NeptuneLogger(Logger): agg_default_func: Optional[Callable[[Sequence[float]], float]] = None, **neptune_run_kwargs: Any, ): - if neptune is None: - raise ModuleNotFoundError( - "You want to use the `Neptune` logger which is not installed yet, install it with" - " `pip install neptune-client`." - ) + if not _NEPTUNE_AVAILABLE: + raise ModuleNotFoundError(str(_NEPTUNE_AVAILABLE)) # verify if user passed proper init arguments self._verify_input_arguments(api_key, project, name, run, neptune_run_kwargs) super().__init__(agg_key_funcs=agg_key_funcs, agg_default_func=agg_default_func) diff --git a/tests/tests_pytorch/loggers/test_all.py b/tests/tests_pytorch/loggers/test_all.py index 279a1aeab7..8d79442e68 100644 --- a/tests/tests_pytorch/loggers/test_all.py +++ b/tests/tests_pytorch/loggers/test_all.py @@ -43,6 +43,7 @@ LOGGER_CTX_MANAGERS = ( mock.patch("pytorch_lightning.loggers.mlflow.mlflow"), mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient"), mock.patch("pytorch_lightning.loggers.neptune.neptune", new_callable=create_neptune_mock), + mock.patch("pytorch_lightning.loggers.neptune._NEPTUNE_AVAILABLE", return_value=True), mock.patch("pytorch_lightning.loggers.wandb.wandb"), mock.patch("pytorch_lightning.loggers.wandb.Run", new=mock.Mock), ) @@ -290,7 +291,9 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch): logger.experiment.log_metric.assert_called_once_with(ANY, "tmp-test", 1.0, ANY, 0) # Neptune - with mock.patch("pytorch_lightning.loggers.neptune.neptune"): + with mock.patch("pytorch_lightning.loggers.neptune.neptune"), mock.patch( + "pytorch_lightning.loggers.neptune._NEPTUNE_AVAILABLE", return_value=True + ): logger = _instantiate_logger(NeptuneLogger, api_key="test", project="project", save_dir=tmpdir, prefix=prefix) assert logger.experiment.__getitem__.call_count == 2 logger.log_metrics({"test": 1.0}, step=0) diff --git a/tests/tests_pytorch/loggers/test_neptune.py b/tests/tests_pytorch/loggers/test_neptune.py index dd7fa60fd9..0dc0347e75 100644 --- a/tests/tests_pytorch/loggers/test_neptune.py +++ b/tests/tests_pytorch/loggers/test_neptune.py @@ -15,6 +15,7 @@ import os import pickle import unittest from collections import namedtuple +from unittest import mock from unittest.mock import call, MagicMock, patch import pytest @@ -78,6 +79,10 @@ def tmpdir_unittest_fixture(request, tmpdir): @patch("pytorch_lightning.loggers.neptune.neptune", new_callable=create_neptune_mock) class TestNeptuneLogger(unittest.TestCase): + def run(self, *args, **kwargs): + with mock.patch("pytorch_lightning.loggers.neptune._NEPTUNE_AVAILABLE", return_value=True): + super().run(*args, **kwargs) + def test_neptune_online(self, neptune): logger = NeptuneLogger(api_key="test", project="project") created_run_mock = logger.run