Remove access to `_short_id` in NeptuneLogger (#11517)

This commit is contained in:
Rafał Jankowski 2022-01-20 13:07:42 +01:00 committed by GitHub
parent 16a04b29eb
commit e78d658c8d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 26 additions and 21 deletions

View File

@ -416,6 +416,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed `Strategy.on_tpu` property ([#11536](https://github.com/PyTorchLightning/pytorch-lightning/pull/11536))
- Removed access to `_short_id` in `NeptuneLogger` ([#11517](https://github.com/PyTorchLightning/pytorch-lightning/pull/11517))
### Fixed
- Fixed security vulnerabilities CVE-2020-1747 and CVE-2020-14343 caused by the `PyYAML` dependency ([#11099](https://github.com/PyTorchLightning/pytorch-lightning/pull/11099))

View File

@ -293,9 +293,10 @@ class NeptuneLogger(LightningLoggerBase):
def _retrieve_run_data(self):
try:
self._run_instance.wait()
self._run_short_id = self.run._short_id # skipcq: PYL-W0212
self._run_short_id = self._run_instance["sys/id"].fetch()
self._run_name = self._run_instance["sys/name"].fetch()
except NeptuneOfflineModeFetchException:
self._run_short_id = "OFFLINE"
self._run_name = "offline-name"
@property

View File

@ -387,9 +387,9 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch):
# Neptune
with mock.patch("pytorch_lightning.loggers.neptune.neptune"):
logger = _instantiate_logger(NeptuneLogger, api_key="test", project="project", save_dir=tmpdir, prefix=prefix)
assert logger.experiment.__getitem__.call_count == 1
logger.log_metrics({"test": 1.0}, step=0)
assert logger.experiment.__getitem__.call_count == 2
logger.log_metrics({"test": 1.0}, step=0)
assert logger.experiment.__getitem__.call_count == 3
logger.experiment.__getitem__.assert_called_with("tmp/test")
logger.experiment.__getitem__().log.assert_called_once_with(1.0)

View File

@ -25,24 +25,24 @@ from pytorch_lightning.loggers import NeptuneLogger
from tests.helpers import BoringModel
def fetchable_paths(value):
if value == "sys/id":
return MagicMock(fetch=MagicMock(return_value="TEST-1"))
elif value == "sys/name":
return MagicMock(fetch=MagicMock(return_value="Run test name"))
return MagicMock()
def create_neptune_mock():
"""Mock with provides nice `logger.name` and `logger.version` values.
Mostly due to fact, that windows tests were failing with MagicMock based strings, which were used to create local
directories in FS.
"""
return MagicMock(
init=MagicMock(
return_value=MagicMock(
__getitem__=MagicMock(return_value=MagicMock(fetch=MagicMock(return_value="Run test name"))),
_short_id="TEST-1",
)
)
)
return MagicMock(init=MagicMock(return_value=MagicMock(__getitem__=MagicMock(side_effect=fetchable_paths))))
class Run:
_short_id = "TEST-42"
_project_name = "test-project"
def __setitem__(self, key, value):
@ -55,9 +55,12 @@ class Run:
pass
def __getitem__(self, item):
# called once
assert item == "sys/name"
return MagicMock(fetch=MagicMock(return_value="Test name"))
if item == "sys/name":
return MagicMock(fetch=MagicMock(return_value="Test name"))
elif item == "sys/id":
return MagicMock(fetch=MagicMock(return_value="TEST-42"))
assert False, f"Unexpected call '{item}'"
def __getstate__(self):
raise pickle.PicklingError("Runs are unpickleable")
@ -83,11 +86,9 @@ class TestNeptuneLogger(unittest.TestCase):
self.assertEqual(logger.name, "Run test name")
self.assertEqual(logger.version, "TEST-1")
self.assertEqual(neptune.init.call_count, 1)
self.assertEqual(created_run_mock.__getitem__.call_count, 1)
self.assertEqual(created_run_mock.__getitem__.call_count, 2)
self.assertEqual(created_run_mock.__setitem__.call_count, 1)
created_run_mock.__getitem__.assert_called_once_with(
"sys/name",
)
created_run_mock.__getitem__.assert_has_calls([call("sys/id"), call("sys/name")], any_order=True)
created_run_mock.__setitem__.assert_called_once_with("source_code/integrations/pytorch-lightning", __version__)
@patch("pytorch_lightning.loggers.neptune.Run", Run)
@ -97,7 +98,7 @@ class TestNeptuneLogger(unittest.TestCase):
assert logger._run_instance == created_run
self.assertEqual(logger._run_instance, created_run)
self.assertEqual(logger.version, created_run._short_id)
self.assertEqual(logger.version, "TEST-42")
self.assertEqual(neptune.init.call_count, 0)
@patch("pytorch_lightning.loggers.neptune.Run", Run)
@ -109,7 +110,7 @@ class TestNeptuneLogger(unittest.TestCase):
pickled_logger = pickle.dumps(logger)
unpickled = pickle.loads(pickled_logger)
neptune.init.assert_called_once_with(name="Test name", run=unpickleable_run._short_id)
neptune.init.assert_called_once_with(name="Test name", run="TEST-42")
self.assertIsNotNone(unpickled.experiment)
@patch("pytorch_lightning.loggers.neptune.Run", Run)