lightning/tests/loggers/test_neptune.py

421 lines
16 KiB
Python

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pickle
import unittest
from collections import namedtuple
from unittest.mock import call, MagicMock, patch
import pytest
import torch
from pytorch_lightning import __version__, Trainer
from pytorch_lightning.loggers import NeptuneLogger
from tests.helpers import BoringModel
def create_neptune_mock():
    """Build a stand-in for the ``neptune`` module that yields plain-string run metadata.

    ``init()`` on the returned mock gives back a run mock whose ``_short_id`` is ``"TEST-1"`` and
    whose item access returns an attribute with ``fetch()`` equal to ``"Run test name"``.

    Real strings are used mostly because Windows tests were failing when MagicMock-based strings
    were used to create local directories in the file system.
    """
    # attribute returned by ``run[<any key>]``
    fetched_attribute = MagicMock(fetch=MagicMock(return_value="Run test name"))

    # the run object handed out by ``neptune.init(...)``
    run_mock = MagicMock(
        __getitem__=MagicMock(return_value=fetched_attribute),
        _short_id="TEST-1",
    )

    return MagicMock(init=MagicMock(return_value=run_mock))
class Run:
    """Hand-written stub of a Neptune run, used where a ``MagicMock`` is not suitable.

    It accepts exactly one item assignment (the integration version) and one item lookup
    (``"sys/name"``), and it refuses to be pickled.
    """

    # identifiers the tests compare against
    _short_id = "TEST-42"
    _project_name = "test-project"

    def __setitem__(self, key, value):
        # called once
        expected_key = "source_code/integrations/pytorch-lightning"
        assert key == expected_key
        assert value == __version__

    def wait(self):
        """Do nothing; present for test purposes only."""
        return None

    def __getitem__(self, item):
        # called once
        assert item == "sys/name"
        name_attribute = MagicMock(fetch=MagicMock(return_value="Test name"))
        return name_attribute

    def __getstate__(self):
        # any attempt to pickle this stub must fail
        raise pickle.PicklingError("Runs are unpickleable")
@pytest.fixture
def tmpdir_unittest_fixture(request, tmpdir):
    """Bridge pytest's ``tmpdir`` fixture into ``unittest.TestCase`` based tests.

    The temporary directory is stored as the ``tmpdir`` attribute of the requesting test class
    (``request.cls``), so test methods can read it as ``self.tmpdir``.

    Resources:
        * https://docs.pytest.org/en/6.2.x/tmpdir.html#the-tmpdir-fixture
        * https://towardsdatascience.com/mixing-pytest-fixture-and-unittest-testcase-for-selenium-test-9162218e8c8e
    """
    requesting_class = request.cls
    requesting_class.tmpdir = tmpdir
@patch("pytorch_lightning.loggers.neptune.neptune", new_callable=create_neptune_mock)
class TestNeptuneLogger(unittest.TestCase):
    """Tests of ``NeptuneLogger`` with the ``neptune`` module replaced by a mock.

    The class-level ``patch`` passes the mock built by ``create_neptune_mock`` as the ``neptune``
    argument of every ``test*`` method; helper methods (leading underscore) are not patched.
    """

    def test_neptune_online(self, neptune):
        """A logger created without ``run`` obtains its run from ``neptune.init``."""
        logger = NeptuneLogger(api_key="test", project="project")
        # Compare against the object the mocked ``neptune.init`` hands out; reading the expected
        # value back from ``logger._run_instance`` itself would make the check always pass.
        created_run_mock = neptune.init.return_value

        self.assertIs(logger._run_instance, created_run_mock)
        self.assertEqual(logger.name, "Run test name")
        self.assertEqual(logger.version, "TEST-1")
        self.assertEqual(neptune.init.call_count, 1)
        self.assertEqual(created_run_mock.__getitem__.call_count, 1)
        self.assertEqual(created_run_mock.__setitem__.call_count, 1)
        created_run_mock.__getitem__.assert_called_once_with(
            "sys/name",
        )
        created_run_mock.__setitem__.assert_called_once_with("source_code/integrations/pytorch-lightning", __version__)

    @patch("pytorch_lightning.loggers.neptune.Run", Run)
    def test_online_with_custom_run(self, neptune):
        """A user-supplied run is used as-is and ``neptune.init`` is not called."""
        created_run = Run()
        logger = NeptuneLogger(run=created_run)

        self.assertEqual(logger._run_instance, created_run)
        self.assertEqual(logger.version, created_run._short_id)
        self.assertEqual(neptune.init.call_count, 0)

    @patch("pytorch_lightning.loggers.neptune.Run", Run)
    def test_neptune_pickling(self, neptune):
        """A logger survives a pickle round trip even though its run cannot be pickled."""
        unpickleable_run = Run()
        logger = NeptuneLogger(run=unpickleable_run)
        self.assertEqual(0, neptune.init.call_count)

        pickled_logger = pickle.dumps(logger)
        unpickled = pickle.loads(pickled_logger)

        # after the round trip the run has been re-created from the stub's project name and short id
        neptune.init.assert_called_once_with(project="test-project", api_token=None, run="TEST-42")
        self.assertIsNotNone(unpickled.experiment)

    @patch("pytorch_lightning.loggers.neptune.Run", Run)
    def test_online_with_wrong_kwargs(self, neptune):
        """Tests combinations of kwargs together with the `run` kwarg, which makes some of the other parameters
        unavailable in init."""
        with self.assertRaises(ValueError):
            NeptuneLogger(run="some string")

        with self.assertRaises(ValueError):
            NeptuneLogger(run=Run(), project="redundant project")

        with self.assertRaises(ValueError):
            NeptuneLogger(run=Run(), api_key="redundant api key")

        with self.assertRaises(ValueError):
            NeptuneLogger(run=Run(), name="redundant api name")

        with self.assertRaises(ValueError):
            NeptuneLogger(run=Run(), foo="random **kwarg")

        # this should work
        NeptuneLogger(run=Run())
        NeptuneLogger(project="foo")
        NeptuneLogger(foo="bar")

    @staticmethod
    def _get_logger_with_mocks(**kwargs):
        """Create a logger whose run instance is swapped for fresh mocks.

        Args:
            **kwargs: forwarded unchanged to ``NeptuneLogger``.

        Returns:
            A ``(logger, run_instance_mock, run_attr_mock)`` tuple; ``run_attr_mock`` is the object
            returned by ``run_instance_mock[...]`` for any key.
        """
        logger = NeptuneLogger(**kwargs)

        # Configure ``fetch`` on the mock that item access really returns. Setting it on the default
        # ``__getitem__.return_value`` and replacing that return value afterwards would discard it.
        run_attr_mock = MagicMock()
        run_attr_mock.fetch.return_value = "exp-name"

        run_instance_mock = MagicMock()
        run_instance_mock.__getitem__.return_value = run_attr_mock
        logger._run_instance = run_instance_mock

        return logger, run_instance_mock, run_attr_mock

    def test_neptune_additional_methods(self, neptune):
        """Item access on ``logger.experiment`` is delegated to the underlying run."""
        logger, run_instance_mock, run_attr_mock = self._get_logger_with_mocks(api_key="test", project="project")

        logger.experiment["key1"].log(torch.ones(1))

        run_instance_mock.__getitem__.assert_called_once_with("key1")
        # check the returned attribute directly rather than calling ``__getitem__()`` again,
        # which would record an extra call on the mock under test
        run_attr_mock.log.assert_called_once_with(torch.ones(1))

    def _fit_and_test(self, logger, model):
        """Fit and test ``model`` with ``logger`` attached, checking ``trainer.log_dir`` before and after.

        Reads ``self.tmpdir``, which the ``tmpdir_unittest_fixture`` fixture sets on the test class.
        """
        expected_log_dir = os.path.join(os.getcwd(), ".neptune")
        trainer = Trainer(default_root_dir=self.tmpdir, max_epochs=1, limit_train_batches=0.05, logger=logger)
        assert trainer.log_dir == expected_log_dir
        trainer.fit(model)
        trainer.test(model)
        assert trainer.log_dir == expected_log_dir

    @pytest.mark.usefixtures("tmpdir_unittest_fixture")
    def test_neptune_leave_open_experiment_after_fit(self, neptune):
        """Verify that neptune experiment was NOT closed after training."""
        # given
        logger, run_instance_mock, _ = self._get_logger_with_mocks(api_key="test", project="project")

        # when
        self._fit_and_test(
            logger=logger,
            model=BoringModel(),
        )

        # then
        assert run_instance_mock.stop.call_count == 0

    @pytest.mark.usefixtures("tmpdir_unittest_fixture")
    def test_neptune_log_metrics_on_trained_model(self, neptune):
        """Verify that trained models do log data."""

        # given
        class LoggingModel(BoringModel):
            def validation_epoch_end(self, outputs):
                self.log("some/key", 42)

        # and
        logger, run_instance_mock, _ = self._get_logger_with_mocks(api_key="test", project="project")

        # when
        self._fit_and_test(
            logger=logger,
            model=LoggingModel(),
        )

        # then
        run_instance_mock.__getitem__.assert_any_call("training/some/key")
        run_instance_mock.__getitem__.return_value.log.assert_has_calls([call(42)])

    def test_log_hyperparams(self, neptune):
        """Hyperparameters are assigned under ``<prefix>/hyperparams``."""
        params = {"foo": "bar", "nested_foo": {"bar": 42}}
        test_variants = [
            ({}, "training/hyperparams"),
            ({"prefix": "custom_prefix"}, "custom_prefix/hyperparams"),
            ({"prefix": "custom/nested/prefix"}, "custom/nested/prefix/hyperparams"),
        ]
        for logger_kwargs, hyperparams_key in test_variants:
            # given:
            logger, run_instance_mock, _ = self._get_logger_with_mocks(
                api_key="test", project="project", **logger_kwargs
            )

            # when: log hyperparams
            logger.log_hyperparams(params)

            # then
            self.assertEqual(run_instance_mock.__setitem__.call_count, 1)
            self.assertEqual(run_instance_mock.__getitem__.call_count, 0)
            run_instance_mock.__setitem__.assert_called_once_with(hyperparams_key, params)

    def test_log_metrics(self, neptune):
        """Each metric is logged under ``<prefix>/<metric name>``."""
        metrics = {
            "foo": 42,
            "bar": 555,
        }
        test_variants = [
            ({}, ("training/foo", "training/bar")),
            ({"prefix": "custom_prefix"}, ("custom_prefix/foo", "custom_prefix/bar")),
            ({"prefix": "custom/nested/prefix"}, ("custom/nested/prefix/foo", "custom/nested/prefix/bar")),
        ]

        for logger_kwargs, (metrics_foo_key, metrics_bar_key) in test_variants:
            # given:
            logger, run_instance_mock, run_attr_mock = self._get_logger_with_mocks(
                api_key="test", project="project", **logger_kwargs
            )

            # when: log metrics
            logger.log_metrics(metrics)

            # then:
            self.assertEqual(run_instance_mock.__setitem__.call_count, 0)
            self.assertEqual(run_instance_mock.__getitem__.call_count, 2)
            run_instance_mock.__getitem__.assert_any_call(metrics_foo_key)
            run_instance_mock.__getitem__.assert_any_call(metrics_bar_key)
            run_attr_mock.log.assert_has_calls([call(42), call(555)])

    def test_log_model_summary(self, neptune):
        """The model summary is assigned under ``<prefix>/model/summary``."""
        model = BoringModel()
        test_variants = [
            ({}, "training/model/summary"),
            ({"prefix": "custom_prefix"}, "custom_prefix/model/summary"),
            ({"prefix": "custom/nested/prefix"}, "custom/nested/prefix/model/summary"),
        ]

        for logger_kwargs, model_summary_key in test_variants:
            # given:
            logger, run_instance_mock, _ = self._get_logger_with_mocks(
                api_key="test", project="project", **logger_kwargs
            )
            file_from_content_mock = neptune.types.File.from_content()

            # when: log model summary
            logger.log_model_summary(model)

            # then:
            self.assertEqual(run_instance_mock.__setitem__.call_count, 1)
            self.assertEqual(run_instance_mock.__getitem__.call_count, 0)
            run_instance_mock.__setitem__.assert_called_once_with(model_summary_key, file_from_content_mock)

    def test_after_save_checkpoint(self, neptune):
        """Checkpoints are uploaded under ``<prefix>/model/checkpoints`` and the best model path is stored."""
        test_variants = [
            ({}, "training/model"),
            ({"prefix": "custom_prefix"}, "custom_prefix/model"),
            ({"prefix": "custom/nested/prefix"}, "custom/nested/prefix/model"),
        ]

        for logger_kwargs, model_key_prefix in test_variants:
            # given:
            logger, run_instance_mock, run_attr_mock = self._get_logger_with_mocks(
                api_key="test", project="project", **logger_kwargs
            )
            cb_mock = MagicMock(
                dirpath="path/to/models",
                last_model_path="path/to/models/last",
                best_k_models={
                    "path/to/models/model1": None,
                    "path/to/models/model2/with/slashes": None,
                },
                best_model_path="path/to/models/best_model",
                best_model_score=None,
            )

            # when: save checkpoint
            logger.after_save_checkpoint(cb_mock)

            # then:
            self.assertEqual(run_instance_mock.__setitem__.call_count, 1)
            self.assertEqual(run_instance_mock.__getitem__.call_count, 3)
            self.assertEqual(run_attr_mock.upload.call_count, 3)
            run_instance_mock.__setitem__.assert_called_once_with(
                f"{model_key_prefix}/best_model_path", "path/to/models/best_model"
            )
            run_instance_mock.__getitem__.assert_any_call(f"{model_key_prefix}/checkpoints/last")
            run_instance_mock.__getitem__.assert_any_call(f"{model_key_prefix}/checkpoints/model1")
            run_instance_mock.__getitem__.assert_any_call(f"{model_key_prefix}/checkpoints/model2/with/slashes")
            run_attr_mock.upload.assert_has_calls(
                [
                    call("path/to/models/last"),
                    call("path/to/models/model1"),
                    call("path/to/models/model2/with/slashes"),
                ]
            )

    def test_save_dir(self, neptune):
        """``save_dir`` is the ``.neptune`` directory under the current working directory."""
        # given
        logger = NeptuneLogger(api_key="test", project="project")

        # expect
        self.assertEqual(logger.save_dir, os.path.join(os.getcwd(), ".neptune"))
class TestNeptuneLoggerDeprecatedUsages(unittest.TestCase):
    """Checks how ``NeptuneLogger`` reacts to legacy keyword arguments and legacy methods."""

    @staticmethod
    def _assert_legacy_usage(callback, *args, **kwargs):
        """Assert that ``callback(*args, **kwargs)`` raises ``ValueError``."""
        with pytest.raises(ValueError):
            callback(*args, **kwargs)

    def test_legacy_kwargs(self):
        """Every legacy keyword argument is rejected by the constructor."""
        # NeptuneLegacyLogger kwargs
        legacy_logger_kwargs = (
            "project_name",
            "offline_mode",
            "experiment_name",
            "experiment_id",
            "params",
            "properties",
            "upload_source_files",
            "abort_callback",
            "logger",
            "upload_stdout",
            "upload_stderr",
            "send_hardware_metrics",
            "run_monitoring_thread",
            "handle_uncaught_exceptions",
            "git_info",
            "hostname",
            "notebook_id",
            "notebook_path",
        )
        # NeptuneLogger from neptune-pytorch-lightning package kwargs
        integration_package_kwargs = (
            "base_namespace",
            "close_after_fit",
        )

        for legacy_kwarg in legacy_logger_kwargs + integration_package_kwargs:
            self._assert_legacy_usage(NeptuneLogger, **{legacy_kwarg: None})

    @patch("pytorch_lightning.loggers.neptune.warnings")
    @patch("pytorch_lightning.loggers.neptune.NeptuneFile")
    @patch("pytorch_lightning.loggers.neptune.neptune")
    def test_legacy_functions(self, neptune, neptune_file_mock, warnings_mock):
        """Deprecated logging helpers still log and warn; removed helpers raise ``ValueError``."""
        logger = NeptuneLogger(api_key="test", project="project")

        # test deprecated functions which will be shut down in pytorch-lightning 1.7.0
        getitem_mock = logger._run_instance.__getitem__
        getitem_mock.reset_mock()  # forget any calls recorded before this point
        image_object = {}

        logger.log_metric("metric", 42)
        logger.log_text("text", "some string")
        logger.log_image("image_obj", image_object)
        logger.log_image("image_str", "img/path")
        logger.log_artifact("artifact", "some/path")

        assert getitem_mock.call_count == 5
        assert warnings_mock.warn.call_count == 5

        expected_calls = [
            call("training/metric"),
            call().log(42, step=None),
            call("training/text"),
            call().log("some string", step=None),
            call("training/image_obj"),
            call().log(image_object, step=None),
            call("training/image_str"),
            call().log(neptune_file_mock(), step=None),
            call("training/artifacts/artifact"),
            call().log("some/path"),
        ]
        getitem_mock.assert_has_calls(expected_calls)

        # test functions that raise an exception
        for removed_function in (logger.set_property, logger.append_tags):
            self._assert_legacy_usage(removed_function)
class TestNeptuneLoggerUtils(unittest.TestCase):
    """Tests for the name-handling helper methods of ``NeptuneLogger``."""

    def test__get_full_model_name(self):
        """The checkpoint's ``dirpath`` is stripped from the model path."""
        # given:
        SimpleCheckpoint = namedtuple("SimpleCheckpoint", ["dirpath"])
        checkpoint = SimpleCheckpoint(dirpath="foo/bar")
        cases = [
            ("key.ext", "foo/bar/key.ext", checkpoint),
            ("key/in/parts.ext", "foo/bar/key/in/parts.ext", checkpoint),
        ]

        # expect:
        for expected_model_name, model_path, checkpoint_callback in cases:
            actual_model_name = NeptuneLogger._get_full_model_name(model_path, checkpoint_callback)
            self.assertEqual(actual_model_name, expected_model_name)

    def test__get_full_model_names_from_exp_structure(self):
        """Only leaves below the requested namespace are returned, as slash-joined relative keys."""
        # given:
        unimportant = "some non important value"
        checkpoints_subtree = {
            "lvl1_1": {"lvl2": {"lvl3_1": unimportant, "lvl3_2": unimportant}},
            "lvl1_2": unimportant,
        }
        input_dict = {
            "foo": {
                "bar": checkpoints_subtree,
                "other_non_important": {"val100": 100},
            },
            "other_non_important": {"val42": 42},
        }
        expected_keys = {"lvl1_1/lvl2/lvl3_1", "lvl1_1/lvl2/lvl3_2", "lvl1_2"}

        # expect:
        actual_keys = NeptuneLogger._get_full_model_names_from_exp_structure(input_dict, "foo/bar")
        self.assertEqual(actual_keys, expected_keys)