lightning/tests/tests_pytorch/loggers/test_neptune.py

382 lines
15 KiB
Python

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pickle
import unittest
from collections import namedtuple
from unittest import mock
from unittest.mock import call, MagicMock, patch
import pytest
import torch
from pytorch_lightning import __version__, Trainer
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.loggers import NeptuneLogger
def fetchable_paths(value):
    """Return a stub Neptune field for the requested run path.

    ``"sys/id"`` and ``"sys/name"`` produce mocks whose ``fetch()`` returns a fixed string; every other path
    produces a plain ``MagicMock``.
    """
    canned_fields = (
        ("sys/id", "TEST-1"),
        ("sys/name", "Run test name"),
    )
    for known_path, fetched_value in canned_fields:
        if value == known_path:
            return MagicMock(fetch=MagicMock(return_value=fetched_value))
    return MagicMock()
def create_run_mock(mode="async", **kwargs):
    """Build a mocked Neptune run for the given ``mode``.

    Item access on the returned mock is routed through ``fetchable_paths``. Its ``exists()`` returns ``False``
    when ``mode`` is ``"offline"`` and ``True`` for any other mode. Extra keyword arguments are accepted and
    ignored.
    """
    is_offline = mode == "offline"
    return MagicMock(
        __getitem__=MagicMock(side_effect=fetchable_paths),
        exists=MagicMock(return_value=not is_offline),
    )
def create_neptune_mock():
    """Build a stand-in for the ``neptune`` module.

    Its ``init_run`` delegates to ``create_run_mock``, so created runs provide nice `logger.name` and
    `logger.version` values and the `mode` argument can be used to test different Neptune modes.

    This exists mostly because Windows tests were failing with MagicMock-based strings, which were used to
    create local directories on the file system.
    """
    neptune_module = MagicMock()
    neptune_module.init_run = MagicMock(side_effect=create_run_mock)
    return neptune_module
class Run:
    """Hand-written stand-in for a Neptune run object.

    It is substituted for ``pytorch_lightning.loggers.neptune.Run`` in tests that pass a custom run to the
    logger. Only the members exercised by those tests are implemented; anything unexpected fails loudly.
    """

    _project_name = "test-project"

    def __setitem__(self, key, value):
        """Accept only the integration-version entry the logger is expected to write."""
        # called once
        assert key == "source_code/integrations/pytorch-lightning"
        assert value == __version__

    def wait(self):
        """Do nothing; present for test purposes only."""

    def __getitem__(self, item):
        """Return a fetchable stub for ``sys/name`` or ``sys/id``.

        Raises:
            AssertionError: if any other path is requested.
        """
        if item == "sys/name":
            return MagicMock(fetch=MagicMock(return_value="Test name"))
        if item == "sys/id":
            return MagicMock(fetch=MagicMock(return_value="TEST-42"))
        # Raise explicitly instead of `assert False`: asserts are stripped under `python -O`, which would make
        # an unexpected lookup silently return None.
        raise AssertionError(f"Unexpected call '{item}'")

    def __getstate__(self):
        """Refuse pickling so tests can verify the logger never pickles the run itself."""
        raise pickle.PicklingError("Runs are unpickleable")

    def exists(self, value):
        """Report every path as existing."""
        return True
@pytest.fixture
def tmpdir_unittest_fixture(request, tmpdir):
    """Expose pytest's ``tmpdir`` fixture to ``unittest.TestCase`` classes.

    The temporary directory is stored as the ``tmpdir`` attribute on the requesting test class.

    Resources:
    * https://docs.pytest.org/en/6.2.x/tmpdir.html#the-tmpdir-fixture
    * https://towardsdatascience.com/mixing-pytest-fixture-and-unittest-testcase-for-selenium-test-9162218e8c8e
    """
    test_class = request.cls
    setattr(test_class, "tmpdir", tmpdir)
@patch("pytorch_lightning.loggers.neptune.neptune", new_callable=create_neptune_mock)
class TestNeptuneLogger(unittest.TestCase):
    """Tests for ``NeptuneLogger`` executed against a mocked ``neptune`` module.

    The class-level patch replaces ``pytorch_lightning.loggers.neptune.neptune`` with the mock built by
    ``create_neptune_mock`` and hands it to every test method as the ``neptune`` argument.
    """

    def run(self, *args, **kwargs):
        """Run the test case with the Neptune availability flag patched.

        Returns:
            The value returned by ``unittest.TestCase.run``, so callers still receive the test result.
        """
        # The flag is replaced by a (truthy) MagicMock; `return_value=True` only matters if the flag is called.
        # NOTE(review): a plain `True` would probably be enough - confirm against the logger module.
        with mock.patch("pytorch_lightning.loggers.neptune._NEPTUNE_AVAILABLE", return_value=True):
            return super().run(*args, **kwargs)

    def test_neptune_online(self, neptune):
        """A logger created with credentials initializes exactly one run and reads its id and name."""
        logger = NeptuneLogger(api_key="test", project="project")
        created_run_mock = logger.run

        self.assertEqual(logger._run_instance, created_run_mock)
        created_run_mock.exists.assert_called_once_with("sys/id")
        self.assertEqual(logger.name, "Run test name")
        self.assertEqual(logger.version, "TEST-1")
        self.assertEqual(neptune.init_run.call_count, 1)
        self.assertEqual(created_run_mock.__getitem__.call_count, 2)
        self.assertEqual(created_run_mock.__setitem__.call_count, 1)
        created_run_mock.__getitem__.assert_has_calls([call("sys/id"), call("sys/name")], any_order=True)
        created_run_mock.__setitem__.assert_called_once_with("source_code/integrations/pytorch-lightning", __version__)

    def test_neptune_offline(self, neptune):
        """In offline mode the logger falls back to the fixed offline id and name."""
        logger = NeptuneLogger(mode="offline")
        created_run_mock = logger.run
        logger.experiment["foo"] = "bar"

        created_run_mock.exists.assert_called_once_with("sys/id")
        self.assertEqual(logger._run_short_id, "OFFLINE")
        self.assertEqual(logger._run_name, "offline-name")

    @patch("pytorch_lightning.loggers.neptune.Run", Run)
    def test_online_with_custom_run(self, neptune):
        """A user-supplied run is used as-is and no new run is initialized."""
        created_run = Run()
        logger = NeptuneLogger(run=created_run)

        self.assertEqual(logger._run_instance, created_run)
        self.assertEqual(logger.version, "TEST-42")
        self.assertEqual(neptune.init_run.call_count, 0)

    @patch("pytorch_lightning.loggers.neptune.Run", Run)
    def test_neptune_pickling(self, neptune):
        """Pickling the logger must not pickle the run; unpickling re-initializes it by id and name."""
        unpickleable_run = Run()
        logger = NeptuneLogger(run=unpickleable_run)
        self.assertEqual(0, neptune.init_run.call_count)

        pickled_logger = pickle.dumps(logger)
        unpickled = pickle.loads(pickled_logger)

        neptune.init_run.assert_called_once_with(name="Test name", run="TEST-42")
        self.assertIsNotNone(unpickled.experiment)

    @patch("pytorch_lightning.loggers.neptune.Run", Run)
    def test_online_with_wrong_kwargs(self, neptune):
        """Tests combinations of kwargs together with `run` kwarg which makes some of other parameters unavailable
        in init."""
        with self.assertRaises(ValueError):
            NeptuneLogger(run="some string")

        with self.assertRaises(ValueError):
            NeptuneLogger(run=Run(), project="redundant project")

        with self.assertRaises(ValueError):
            NeptuneLogger(run=Run(), api_key="redundant api key")

        with self.assertRaises(ValueError):
            NeptuneLogger(run=Run(), name="redundant api name")

        with self.assertRaises(ValueError):
            NeptuneLogger(run=Run(), foo="random **kwarg")

        # this should work
        NeptuneLogger(run=Run())
        NeptuneLogger(project="foo")
        NeptuneLogger(foo="bar")

    @staticmethod
    def _get_logger_with_mocks(**kwargs):
        """Create a logger whose run instance is replaced by a fresh mock.

        Returns:
            A ``(logger, run_instance_mock, run_attr_mock)`` tuple, where ``run_attr_mock`` is what
            ``run_instance_mock[...]`` returns for any key.
        """
        logger = NeptuneLogger(**kwargs)
        run_instance_mock = MagicMock()
        logger._run_instance = run_instance_mock

        run_attr_mock = MagicMock()
        # Configure the fetch stub on the mock that is actually returned by item access. Previously it was set on
        # a default return value that was replaced right afterwards, so it never took effect.
        run_attr_mock.fetch.return_value = "exp-name"
        run_instance_mock.__getitem__.return_value = run_attr_mock
        return logger, run_instance_mock, run_attr_mock

    def test_neptune_additional_methods(self, neptune):
        """Item access on ``logger.experiment`` is forwarded to the underlying run."""
        logger, run_instance_mock, run_attr_mock = self._get_logger_with_mocks(api_key="test", project="project")

        logger.experiment["key1"].log(torch.ones(1))

        run_instance_mock.__getitem__.assert_called_once_with("key1")
        run_attr_mock.log.assert_called_once_with(torch.ones(1))

    def _fit_and_test(self, logger, model):
        """Fit and test ``model`` with ``logger``, checking the trainer's log dir before and after.

        Relies on ``self.tmpdir``, which is set by ``tmpdir_unittest_fixture``.
        """
        expected_log_dir = os.path.join(os.getcwd(), ".neptune")
        trainer = Trainer(default_root_dir=self.tmpdir, max_epochs=1, limit_train_batches=0.05, logger=logger)
        self.assertEqual(trainer.log_dir, expected_log_dir)
        trainer.fit(model)
        trainer.test(model)
        self.assertEqual(trainer.log_dir, expected_log_dir)

    @pytest.mark.usefixtures("tmpdir_unittest_fixture")
    def test_neptune_leave_open_experiment_after_fit(self, neptune):
        """Verify that neptune experiment was NOT closed after training."""
        # given
        logger, run_instance_mock, _ = self._get_logger_with_mocks(api_key="test", project="project")

        # when
        self._fit_and_test(
            logger=logger,
            model=BoringModel(),
        )

        # then
        self.assertEqual(run_instance_mock.stop.call_count, 0)

    @pytest.mark.usefixtures("tmpdir_unittest_fixture")
    def test_neptune_log_metrics_on_trained_model(self, neptune):
        """Verify that trained models do log data."""

        # given
        class LoggingModel(BoringModel):
            def validation_epoch_end(self, outputs):
                self.log("some/key", 42)

        # and
        logger, run_instance_mock, _ = self._get_logger_with_mocks(api_key="test", project="project")

        # when
        self._fit_and_test(
            logger=logger,
            model=LoggingModel(),
        )

        # then
        run_instance_mock.__getitem__.assert_any_call("training/some/key")
        run_instance_mock.__getitem__.return_value.log.assert_has_calls([call(42)])

    def test_log_hyperparams(self, neptune):
        """Hyperparameters are stored under ``<prefix>/hyperparams`` with a single item assignment."""
        params = {"foo": "bar", "nested_foo": {"bar": 42}}
        test_variants = [
            ({}, "training/hyperparams"),
            ({"prefix": "custom_prefix"}, "custom_prefix/hyperparams"),
            ({"prefix": "custom/nested/prefix"}, "custom/nested/prefix/hyperparams"),
        ]
        for prefix, hyperparams_key in test_variants:
            # given:
            logger, run_instance_mock, _ = self._get_logger_with_mocks(api_key="test", project="project", **prefix)

            # when: log hyperparams
            logger.log_hyperparams(params)

            # then
            self.assertEqual(run_instance_mock.__setitem__.call_count, 1)
            self.assertEqual(run_instance_mock.__getitem__.call_count, 0)
            run_instance_mock.__setitem__.assert_called_once_with(hyperparams_key, params)

    def test_log_metrics(self, neptune):
        """Each metric is logged under ``<prefix>/<metric name>`` via item access followed by ``log``."""
        metrics = {
            "foo": 42,
            "bar": 555,
        }
        test_variants = [
            ({}, ("training/foo", "training/bar")),
            ({"prefix": "custom_prefix"}, ("custom_prefix/foo", "custom_prefix/bar")),
            ({"prefix": "custom/nested/prefix"}, ("custom/nested/prefix/foo", "custom/nested/prefix/bar")),
        ]
        for prefix, (metrics_foo_key, metrics_bar_key) in test_variants:
            # given:
            logger, run_instance_mock, run_attr_mock = self._get_logger_with_mocks(
                api_key="test", project="project", **prefix
            )

            # when: log metrics
            logger.log_metrics(metrics)

            # then:
            self.assertEqual(run_instance_mock.__setitem__.call_count, 0)
            self.assertEqual(run_instance_mock.__getitem__.call_count, 2)
            run_instance_mock.__getitem__.assert_any_call(metrics_foo_key)
            run_instance_mock.__getitem__.assert_any_call(metrics_bar_key)
            run_attr_mock.log.assert_has_calls([call(42), call(555)])

    def test_log_model_summary(self, neptune):
        """The model summary is stored under ``<prefix>/model/summary`` as a Neptune file object."""
        model = BoringModel()
        test_variants = [
            ({}, "training/model/summary"),
            ({"prefix": "custom_prefix"}, "custom_prefix/model/summary"),
            ({"prefix": "custom/nested/prefix"}, "custom/nested/prefix/model/summary"),
        ]
        for prefix, model_summary_key in test_variants:
            # given:
            logger, run_instance_mock, _ = self._get_logger_with_mocks(api_key="test", project="project", **prefix)
            # calling the mocked factory yields its (shared) return value, i.e. what the logger will store
            file_from_content_mock = neptune.types.File.from_content()

            # when: log model summary
            logger.log_model_summary(model)

            # then:
            self.assertEqual(run_instance_mock.__setitem__.call_count, 1)
            self.assertEqual(run_instance_mock.__getitem__.call_count, 0)
            run_instance_mock.__setitem__.assert_called_once_with(model_summary_key, file_from_content_mock)

    def test_after_save_checkpoint(self, neptune):
        """Checkpoints are uploaded under ``<prefix>/model/checkpoints`` and the best model path is recorded."""
        test_variants = [
            ({}, "training/model"),
            ({"prefix": "custom_prefix"}, "custom_prefix/model"),
            ({"prefix": "custom/nested/prefix"}, "custom/nested/prefix/model"),
        ]
        for prefix, model_key_prefix in test_variants:
            # given:
            logger, run_instance_mock, run_attr_mock = self._get_logger_with_mocks(
                api_key="test", project="project", **prefix
            )
            models_root_dir = os.path.join("path", "to", "models")
            cb_mock = MagicMock(
                dirpath=models_root_dir,
                last_model_path=os.path.join(models_root_dir, "last"),
                best_k_models={
                    os.path.join(models_root_dir, "model1"): None,
                    os.path.join(models_root_dir, "model2/with/slashes"): None,
                },
                best_model_path=os.path.join(models_root_dir, "best_model"),
                best_model_score=None,
            )

            # when: save checkpoint
            logger.after_save_checkpoint(cb_mock)

            # then:
            self.assertEqual(run_instance_mock.__setitem__.call_count, 1)
            self.assertEqual(run_instance_mock.__getitem__.call_count, 4)
            self.assertEqual(run_attr_mock.upload.call_count, 4)
            run_instance_mock.__setitem__.assert_called_once_with(
                f"{model_key_prefix}/best_model_path", os.path.join(models_root_dir, "best_model")
            )
            run_instance_mock.__getitem__.assert_any_call(f"{model_key_prefix}/checkpoints/last")
            run_instance_mock.__getitem__.assert_any_call(f"{model_key_prefix}/checkpoints/model1")
            run_instance_mock.__getitem__.assert_any_call(f"{model_key_prefix}/checkpoints/model2/with/slashes")
            run_instance_mock.__getitem__.assert_any_call(f"{model_key_prefix}/checkpoints/best_model")
            run_attr_mock.upload.assert_has_calls(
                [
                    call(os.path.join(models_root_dir, "last")),
                    call(os.path.join(models_root_dir, "model1")),
                    call(os.path.join(models_root_dir, "model2/with/slashes")),
                    call(os.path.join(models_root_dir, "best_model")),
                ]
            )

    def test_save_dir(self, neptune):
        """``save_dir`` points at ``.neptune`` inside the current working directory."""
        # given
        logger = NeptuneLogger(api_key="test", project="project")

        # expect
        self.assertEqual(logger.save_dir, os.path.join(os.getcwd(), ".neptune"))
class TestNeptuneLoggerUtils(unittest.TestCase):
    """Tests for the name/path helper functions exposed on ``NeptuneLogger``."""

    def test__get_full_model_name(self):
        """The helper is expected to map a checkpoint file path to its extension-less, dirpath-relative name."""
        # given:
        SimpleCheckpoint = namedtuple("SimpleCheckpoint", ["dirpath"])
        cases = [
            (
                "key",
                os.path.join("foo", "bar", "key.ext"),
                SimpleCheckpoint(dirpath=os.path.join("foo", "bar")),
            ),
            (
                "key/in/parts",
                os.path.join("foo", "bar", "key/in/parts.ext"),
                SimpleCheckpoint(dirpath=os.path.join("foo", "bar")),
            ),
        ]

        # expect:
        for expected_name, model_path, checkpoint in cases:
            actual_name = NeptuneLogger._get_full_model_name(model_path, checkpoint)
            self.assertEqual(actual_name, expected_name)

    def test__get_full_model_names_from_exp_structure(self):
        """Only leaves below the requested namespace are expected, as slash-joined relative names."""
        # given:
        models_subtree = {
            "lvl1_1": {"lvl2": {"lvl3_1": "some non important value", "lvl3_2": "some non important value"}},
            "lvl1_2": "some non important value",
        }
        exp_structure = {
            "foo": {
                "bar": models_subtree,
                "other_non_important": {"val100": 100},
            },
            "other_non_important": {"val42": 42},
        }

        # when:
        found_names = NeptuneLogger._get_full_model_names_from_exp_structure(exp_structure, "foo/bar")

        # then:
        self.assertEqual(found_names, {"lvl1_1/lvl2/lvl3_1", "lvl1_1/lvl2/lvl3_2", "lvl1_2"})