# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pickle
import unittest
from collections import namedtuple
from unittest.mock import call, MagicMock, patch

import pytest
import torch

from pytorch_lightning import __version__, Trainer
from pytorch_lightning.loggers import NeptuneLogger
from tests.helpers import BoringModel


def create_neptune_mock():
    """Build a stand-in for the ``neptune`` module with plain-string run name and id.

    ``init()`` returns a run mock whose item lookup yields an attribute that fetches
    ``"Run test name"`` and whose ``_short_id`` is ``"TEST-1"``, so the tests see nice
    ``logger.name`` and ``logger.version`` values.

    Mostly due to the fact that Windows tests were failing with MagicMock based strings,
    which were used to create local directories in the file system.
    """
    run_attribute = MagicMock(fetch=MagicMock(return_value="Run test name"))
    run = MagicMock(
        __getitem__=MagicMock(return_value=run_attribute),
        _short_id="TEST-1",
    )
    return MagicMock(init=MagicMock(return_value=run))


class Run:
    """Minimal hand-written run stub that validates how it is accessed and refuses to be pickled."""

    # Identifiers expected back in the ``neptune.init`` call of ``test_neptune_pickling``.
    _short_id = "TEST-42"
    _project_name = "test-project"

    def __setitem__(self, key, value):
        # called once
        assert key == "source_code/integrations/pytorch-lightning"
        assert value == __version__

    def wait(self):
        # for test purposes
        pass

    def __getitem__(self, item):
        # called once
        assert item == "sys/name"
        return MagicMock(fetch=MagicMock(return_value="Test name"))

    def __getstate__(self):
        raise pickle.PicklingError("Runs are unpickleable")


@pytest.fixture
def tmpdir_unittest_fixture(request, tmpdir):
    """Proxy for pytest `tmpdir` fixture between pytest and unittest.

    Stores the pytest ``tmpdir`` on the requesting test class as ``tmpdir``.

    Resources:
     * https://docs.pytest.org/en/6.2.x/tmpdir.html#the-tmpdir-fixture
     * https://towardsdatascience.com/mixing-pytest-fixture-and-unittest-testcase-for-selenium-test-9162218e8c8e
    """
    request.cls.tmpdir = tmpdir


@patch("pytorch_lightning.loggers.neptune.neptune", new_callable=create_neptune_mock)
class TestNeptuneLogger(unittest.TestCase):
    """Tests of ``NeptuneLogger`` run against a mocked ``neptune`` module.

    The class-level patch passes the mock created by ``create_neptune_mock`` to every
    test method as the ``neptune`` argument.
    """

    def test_neptune_online(self, neptune):
        """A logger built from credentials initializes exactly one run and reads its name and id."""
        logger = NeptuneLogger(api_key="test", project="project")
        created_run_mock = logger._run_instance

        # The stored run has to be the very object handed back by the mocked ``neptune.init``.
        self.assertIs(created_run_mock, neptune.init.return_value)
        self.assertEqual(logger.name, "Run test name")
        self.assertEqual(logger.version, "TEST-1")
        self.assertEqual(neptune.init.call_count, 1)
        self.assertEqual(created_run_mock.__getitem__.call_count, 1)
        self.assertEqual(created_run_mock.__setitem__.call_count, 1)
        created_run_mock.__getitem__.assert_called_once_with(
            "sys/name",
        )
        created_run_mock.__setitem__.assert_called_once_with("source_code/integrations/pytorch-lightning", __version__)

    @patch("pytorch_lightning.loggers.neptune.Run", Run)
    def test_online_with_custom_run(self, neptune):
        """A user-supplied run is used as-is and ``neptune.init`` is never called."""
        created_run = Run()
        logger = NeptuneLogger(run=created_run)

        # ``Run`` defines no ``__eq__``, so identity is exactly what equality checked before.
        self.assertIs(logger._run_instance, created_run)
        self.assertEqual(logger.version, created_run._short_id)
        self.assertEqual(neptune.init.call_count, 0)

    @patch("pytorch_lightning.loggers.neptune.Run", Run)
    def test_neptune_pickling(self, neptune):
        """The logger survives a pickle round trip even though its run cannot be pickled."""
        unpickleable_run = Run()
        logger = NeptuneLogger(run=unpickleable_run)
        self.assertEqual(0, neptune.init.call_count)

        pickled_logger = pickle.dumps(logger)
        unpickled = pickle.loads(pickled_logger)

        # After the round trip the run is re-initialized from the stub's project name and short id.
        neptune.init.assert_called_once_with(project="test-project", api_token=None, run="TEST-42")
        self.assertIsNotNone(unpickled.experiment)

    @patch("pytorch_lightning.loggers.neptune.Run", Run)
    def test_online_with_wrong_kwargs(self, neptune):
        """Tests combinations of kwargs together with `run` kwarg which makes some of other parameters unavailable
        in init."""
        with self.assertRaises(ValueError):
            NeptuneLogger(run="some string")

        with self.assertRaises(ValueError):
            NeptuneLogger(run=Run(), project="redundant project")

        with self.assertRaises(ValueError):
            NeptuneLogger(run=Run(), api_key="redundant api key")

        with self.assertRaises(ValueError):
            NeptuneLogger(run=Run(), name="redundant api name")

        with self.assertRaises(ValueError):
            NeptuneLogger(run=Run(), foo="random **kwarg")

        # this should work
        NeptuneLogger(run=Run())
        NeptuneLogger(project="foo")
        NeptuneLogger(foo="bar")

    @staticmethod
    def _get_logger_with_mocks(**kwargs):
        """Create a logger whose run is replaced by fresh mocks.

        Args:
            **kwargs: forwarded unchanged to ``NeptuneLogger``.

        Returns:
            Tuple ``(logger, run_instance_mock, run_attr_mock)`` where ``run_instance_mock`` is the
            logger's ``_run_instance`` and ``run_attr_mock`` is what ``run_instance_mock[...]`` returns.
        """
        logger = NeptuneLogger(**kwargs)

        # Configure the attribute mock *before* wiring it in: stubbing ``fetch`` through
        # ``__getitem__.return_value`` and replacing that return value afterwards would
        # silently throw the stub away.
        run_attr_mock = MagicMock()
        run_attr_mock.fetch.return_value = "exp-name"

        run_instance_mock = MagicMock()
        run_instance_mock.__getitem__.return_value = run_attr_mock
        logger._run_instance = run_instance_mock

        return logger, run_instance_mock, run_attr_mock

    def test_neptune_additional_methods(self, neptune):
        """Item access on ``logger.experiment`` is delegated to the underlying run."""
        logger, run_instance_mock, _ = self._get_logger_with_mocks(api_key="test", project="project")

        logger.experiment["key1"].log(torch.ones(1))
        run_instance_mock.__getitem__.assert_called_once_with("key1")
        run_instance_mock.__getitem__().log.assert_called_once_with(torch.ones(1))

    def _fit_and_test(self, logger, model):
        """Fit and test ``model`` with ``logger``, checking the trainer's log dir before and after."""
        trainer = Trainer(default_root_dir=self.tmpdir, max_epochs=1, limit_train_batches=0.05, logger=logger)
        self.assertEqual(trainer.log_dir, os.path.join(os.getcwd(), ".neptune"))
        trainer.fit(model)
        trainer.test(model)
        self.assertEqual(trainer.log_dir, os.path.join(os.getcwd(), ".neptune"))

    @pytest.mark.usefixtures("tmpdir_unittest_fixture")
    def test_neptune_leave_open_experiment_after_fit(self, neptune):
        """Verify that neptune experiment was NOT closed after training."""
        # given
        logger, run_instance_mock, _ = self._get_logger_with_mocks(api_key="test", project="project")

        # when
        self._fit_and_test(
            logger=logger,
            model=BoringModel(),
        )

        # then
        self.assertEqual(run_instance_mock.stop.call_count, 0)

    @pytest.mark.usefixtures("tmpdir_unittest_fixture")
    def test_neptune_log_metrics_on_trained_model(self, neptune):
        """Verify that trained models do log data."""

        # given
        class LoggingModel(BoringModel):
            def validation_epoch_end(self, outputs):
                self.log("some/key", 42)

        # and
        logger, run_instance_mock, _ = self._get_logger_with_mocks(api_key="test", project="project")

        # when
        self._fit_and_test(
            logger=logger,
            model=LoggingModel(),
        )

        # then
        run_instance_mock.__getitem__.assert_any_call("training/some/key")
        run_instance_mock.__getitem__.return_value.log.assert_has_calls([call(42)])

    def test_log_hyperparams(self, neptune):
        """Hyperparameters are assigned in one piece under ``<prefix>/hyperparams``."""
        params = {"foo": "bar", "nested_foo": {"bar": 42}}
        test_variants = [
            ({}, "training/hyperparams"),
            ({"prefix": "custom_prefix"}, "custom_prefix/hyperparams"),
            ({"prefix": "custom/nested/prefix"}, "custom/nested/prefix/hyperparams"),
        ]
        for logger_kwargs, hyperparams_key in test_variants:
            with self.subTest(hyperparams_key=hyperparams_key):
                # given:
                logger, run_instance_mock, _ = self._get_logger_with_mocks(
                    api_key="test", project="project", **logger_kwargs
                )

                # when: log hyperparams
                logger.log_hyperparams(params)

                # then
                self.assertEqual(run_instance_mock.__setitem__.call_count, 1)
                self.assertEqual(run_instance_mock.__getitem__.call_count, 0)
                run_instance_mock.__setitem__.assert_called_once_with(hyperparams_key, params)

    def test_log_metrics(self, neptune):
        """Each metric is logged through its own ``<prefix>/<metric name>`` attribute."""
        metrics = {
            "foo": 42,
            "bar": 555,
        }
        test_variants = [
            ({}, ("training/foo", "training/bar")),
            ({"prefix": "custom_prefix"}, ("custom_prefix/foo", "custom_prefix/bar")),
            ({"prefix": "custom/nested/prefix"}, ("custom/nested/prefix/foo", "custom/nested/prefix/bar")),
        ]

        for logger_kwargs, (metrics_foo_key, metrics_bar_key) in test_variants:
            with self.subTest(metrics_foo_key=metrics_foo_key, metrics_bar_key=metrics_bar_key):
                # given:
                logger, run_instance_mock, run_attr_mock = self._get_logger_with_mocks(
                    api_key="test", project="project", **logger_kwargs
                )

                # when: log metrics
                logger.log_metrics(metrics)

                # then:
                self.assertEqual(run_instance_mock.__setitem__.call_count, 0)
                self.assertEqual(run_instance_mock.__getitem__.call_count, 2)
                run_instance_mock.__getitem__.assert_any_call(metrics_foo_key)
                run_instance_mock.__getitem__.assert_any_call(metrics_bar_key)
                run_attr_mock.log.assert_has_calls([call(42), call(555)])

    def test_log_model_summary(self, neptune):
        """The model summary is assigned under ``<prefix>/model/summary`` as a file built from content."""
        model = BoringModel()
        test_variants = [
            ({}, "training/model/summary"),
            ({"prefix": "custom_prefix"}, "custom_prefix/model/summary"),
            ({"prefix": "custom/nested/prefix"}, "custom/nested/prefix/model/summary"),
        ]

        for logger_kwargs, model_summary_key in test_variants:
            with self.subTest(model_summary_key=model_summary_key):
                # given:
                logger, run_instance_mock, _ = self._get_logger_with_mocks(
                    api_key="test", project="project", **logger_kwargs
                )
                # Read the canned return value directly instead of calling the mock, so the
                # expectation does not add a call of its own to ``from_content``.
                file_from_content_mock = neptune.types.File.from_content.return_value

                # when: log model summary
                logger.log_model_summary(model)

                # then:
                self.assertEqual(run_instance_mock.__setitem__.call_count, 1)
                self.assertEqual(run_instance_mock.__getitem__.call_count, 0)
                run_instance_mock.__setitem__.assert_called_once_with(model_summary_key, file_from_content_mock)

    def test_after_save_checkpoint(self, neptune):
        """Checkpoints are uploaded under ``<prefix>/model/checkpoints`` and the best model path is recorded."""
        test_variants = [
            ({}, "training/model"),
            ({"prefix": "custom_prefix"}, "custom_prefix/model"),
            ({"prefix": "custom/nested/prefix"}, "custom/nested/prefix/model"),
        ]

        for logger_kwargs, model_key_prefix in test_variants:
            with self.subTest(model_key_prefix=model_key_prefix):
                # given:
                logger, run_instance_mock, run_attr_mock = self._get_logger_with_mocks(
                    api_key="test", project="project", **logger_kwargs
                )
                cb_mock = MagicMock(
                    dirpath="path/to/models",
                    last_model_path="path/to/models/last",
                    best_k_models={
                        "path/to/models/model1": None,
                        "path/to/models/model2/with/slashes": None,
                    },
                    best_model_path="path/to/models/best_model",
                    best_model_score=None,
                )

                # when: save checkpoint
                logger.after_save_checkpoint(cb_mock)

                # then:
                self.assertEqual(run_instance_mock.__setitem__.call_count, 1)
                self.assertEqual(run_instance_mock.__getitem__.call_count, 3)
                self.assertEqual(run_attr_mock.upload.call_count, 3)
                run_instance_mock.__setitem__.assert_called_once_with(
                    f"{model_key_prefix}/best_model_path", "path/to/models/best_model"
                )
                run_instance_mock.__getitem__.assert_any_call(f"{model_key_prefix}/checkpoints/last")
                run_instance_mock.__getitem__.assert_any_call(f"{model_key_prefix}/checkpoints/model1")
                run_instance_mock.__getitem__.assert_any_call(f"{model_key_prefix}/checkpoints/model2/with/slashes")
                run_attr_mock.upload.assert_has_calls(
                    [
                        call("path/to/models/last"),
                        call("path/to/models/model1"),
                        call("path/to/models/model2/with/slashes"),
                    ]
                )

    def test_save_dir(self, neptune):
        """``save_dir`` is the ``.neptune`` directory inside the current working directory."""
        # given
        logger = NeptuneLogger(api_key="test", project="project")

        # expect
        self.assertEqual(logger.save_dir, os.path.join(os.getcwd(), ".neptune"))


class TestNeptuneLoggerDeprecatedUsages(unittest.TestCase):
    """Tests that legacy arguments are rejected and deprecated methods warn while still logging."""

    @staticmethod
    def _assert_legacy_usage(callback, *args, **kwargs):
        """Assert that ``callback(*args, **kwargs)`` raises ``ValueError``."""
        with pytest.raises(ValueError):
            callback(*args, **kwargs)

    def test_legacy_kwargs(self):
        """Every legacy keyword argument makes the constructor raise ``ValueError``."""
        legacy_neptune_kwargs = [
            # NeptuneLegacyLogger kwargs
            "project_name",
            "offline_mode",
            "experiment_name",
            "experiment_id",
            "params",
            "properties",
            "upload_source_files",
            "abort_callback",
            "logger",
            "upload_stdout",
            "upload_stderr",
            "send_hardware_metrics",
            "run_monitoring_thread",
            "handle_uncaught_exceptions",
            "git_info",
            "hostname",
            "notebook_id",
            "notebook_path",
            # NeptuneLogger from neptune-pytorch-lightning package kwargs
            "base_namespace",
            "close_after_fit",
        ]
        for legacy_kwarg in legacy_neptune_kwargs:
            with self.subTest(legacy_kwarg=legacy_kwarg):
                self._assert_legacy_usage(NeptuneLogger, **{legacy_kwarg: None})

    @patch("pytorch_lightning.loggers.neptune.warnings")
    @patch("pytorch_lightning.loggers.neptune.NeptuneFile")
    @patch("pytorch_lightning.loggers.neptune.neptune")
    def test_legacy_functions(self, neptune, neptune_file_mock, warnings_mock):
        """Deprecated ``log_*`` helpers warn once each and still log; removed helpers raise ``ValueError``."""
        logger = NeptuneLogger(api_key="test", project="project")

        # test deprecated functions which will be shut down in pytorch-lightning 1.7.0
        attr_mock = logger._run_instance.__getitem__
        # Forget the item lookups made while the logger was constructed.
        attr_mock.reset_mock()
        fake_image = {}

        logger.log_metric("metric", 42)
        logger.log_text("text", "some string")
        logger.log_image("image_obj", fake_image)
        logger.log_image("image_str", "img/path")
        logger.log_artifact("artifact", "some/path")

        self.assertEqual(attr_mock.call_count, 5)
        self.assertEqual(warnings_mock.warn.call_count, 5)
        attr_mock.assert_has_calls(
            [
                call("training/metric"),
                call().log(42, step=None),
                call("training/text"),
                call().log("some string", step=None),
                call("training/image_obj"),
                call().log(fake_image, step=None),
                call("training/image_str"),
                call().log(neptune_file_mock(), step=None),
                call("training/artifacts/artifact"),
                call().log("some/path"),
            ]
        )

        # test exception-raising functions
        self._assert_legacy_usage(logger.set_property)
        self._assert_legacy_usage(logger.append_tags)


class TestNeptuneLoggerUtils(unittest.TestCase):
    """Tests of the path and key helpers exposed on ``NeptuneLogger``."""

    def test__get_full_model_name(self):
        """The model name is the checkpoint path relative to ``dirpath``, without the file extension."""
        # given:
        SimpleCheckpoint = namedtuple("SimpleCheckpoint", ["dirpath"])
        test_input_data = [
            ("key.ext", "foo/bar/key.ext", SimpleCheckpoint(dirpath="foo/bar")),
            ("key/in/parts.ext", "foo/bar/key/in/parts.ext", SimpleCheckpoint(dirpath="foo/bar")),
        ]

        # expect:
        for expected_model_name, *key_and_path in test_input_data:
            with self.subTest(expected_model_name=expected_model_name):
                self.assertEqual(NeptuneLogger._get_full_model_name(*key_and_path), expected_model_name)

    def test__get_full_model_names_from_exp_structure(self):
        """Leaf keys below the given namespace are returned as slash-joined paths relative to it."""
        # given:
        input_dict = {
            "foo": {
                "bar": {
                    "lvl1_1": {"lvl2": {"lvl3_1": "some non important value", "lvl3_2": "some non important value"}},
                    "lvl1_2": "some non important value",
                },
                "other_non_important": {"val100": 100},
            },
            "other_non_important": {"val42": 42},
        }
        expected_keys = {"lvl1_1/lvl2/lvl3_1", "lvl1_1/lvl2/lvl3_2", "lvl1_2"}

        # expect:
        self.assertEqual(NeptuneLogger._get_full_model_names_from_exp_structure(input_dict, "foo/bar"), expected_keys)