Fixed NeptuneLogger when using DDP (#11030)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
This commit is contained in:
parent
62f1e82e03
commit
3cc69f992b
|
@ -293,6 +293,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
### Fixed
|
||||
|
||||
- Fixed `NeptuneLogger` when using DDP ([#11030](https://github.com/PyTorchLightning/pytorch-lightning/pull/11030))
|
||||
|
||||
|
||||
- Fixed security vulnerabilities CVE-2020-1747 and CVE-2020-14343 caused by the `PyYAML` dependency ([#11099](https://github.com/PyTorchLightning/pytorch-lightning/pull/11099))
|
||||
|
||||
|
||||
|
|
|
@ -51,4 +51,4 @@ dependencies:
|
|||
- mlflow>=1.0.0
|
||||
- comet_ml>=3.1.12
|
||||
- wandb>=0.8.21
|
||||
- neptune-client>=0.4.109
|
||||
- neptune-client>=0.10.0
|
||||
|
|
|
@ -44,7 +44,7 @@ if _NEPTUNE_AVAILABLE and _NEPTUNE_GREATER_EQUAL_0_9:
|
|||
from neptune.new.types import File as NeptuneFile
|
||||
except ModuleNotFoundError:
|
||||
import neptune
|
||||
from neptune.exceptions import NeptuneLegacyProjectException
|
||||
from neptune.exceptions import NeptuneLegacyProjectException, NeptuneOfflineModeFetchException
|
||||
from neptune.run import Run
|
||||
from neptune.types import File as NeptuneFile
|
||||
else:
|
||||
|
@ -266,51 +266,64 @@ class NeptuneLogger(LightningLoggerBase):
|
|||
prefix: str = "training",
|
||||
**neptune_run_kwargs,
|
||||
):
|
||||
|
||||
# verify if user passed proper init arguments
|
||||
self._verify_input_arguments(api_key, project, name, run, neptune_run_kwargs)
|
||||
if neptune is None:
|
||||
raise ModuleNotFoundError(
|
||||
"You want to use the `Neptune` logger which is not installed yet, install it with"
|
||||
" `pip install neptune-client`."
|
||||
)
|
||||
|
||||
super().__init__()
|
||||
self._log_model_checkpoints = log_model_checkpoints
|
||||
self._prefix = prefix
|
||||
self._run_name = name
|
||||
self._project_name = project
|
||||
self._api_key = api_key
|
||||
self._run_instance = run
|
||||
self._neptune_run_kwargs = neptune_run_kwargs
|
||||
self._run_short_id = None
|
||||
|
||||
self._run_instance = self._init_run_instance(api_key, project, name, run, neptune_run_kwargs)
|
||||
if self._run_instance is not None:
|
||||
self._retrieve_run_data()
|
||||
|
||||
self._run_short_id = self.run._short_id # skipcq: PYL-W0212
|
||||
# make sure that we've logged the integration version for outside `Run` instances
|
||||
self._run_instance[_INTEGRATION_VERSION_KEY] = __version__
|
||||
|
||||
def _retrieve_run_data(self):
|
||||
try:
|
||||
self.run.wait()
|
||||
self._run_instance.wait()
|
||||
self._run_short_id = self.run._short_id # skipcq: PYL-W0212
|
||||
self._run_name = self._run_instance["sys/name"].fetch()
|
||||
except NeptuneOfflineModeFetchException:
|
||||
self._run_name = "offline-name"
|
||||
|
||||
def _init_run_instance(self, api_key, project, name, run, neptune_run_kwargs) -> Run:
|
||||
if run is not None:
|
||||
run_instance = run
|
||||
else:
|
||||
try:
|
||||
run_instance = neptune.init(
|
||||
project=project,
|
||||
api_token=api_key,
|
||||
name=name,
|
||||
**neptune_run_kwargs,
|
||||
)
|
||||
except NeptuneLegacyProjectException as e:
|
||||
raise TypeError(
|
||||
f"""Project {project} has not been migrated to the new structure.
|
||||
You can still integrate it with the Neptune logger using legacy Python API
|
||||
available as part of neptune-contrib package:
|
||||
- https://docs-legacy.neptune.ai/integrations/pytorch_lightning.html\n
|
||||
"""
|
||||
) from e
|
||||
@property
|
||||
def _neptune_init_args(self):
|
||||
args = {}
|
||||
# Backward compatibility in case of previous version retrieval
|
||||
try:
|
||||
args = self._neptune_run_kwargs
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
# make sure that we've logged the integration version for both newly created and outside `Run` instances
|
||||
run_instance[_INTEGRATION_VERSION_KEY] = __version__
|
||||
if self._project_name is not None:
|
||||
args["project"] = self._project_name
|
||||
|
||||
# keep api_key and project, they will be required when resuming Run for pickled logger
|
||||
self._api_key = api_key
|
||||
self._project_name = run_instance._project_name # skipcq: PYL-W0212
|
||||
if self._api_key is not None:
|
||||
args["api_token"] = self._api_key
|
||||
|
||||
return run_instance
|
||||
if self._run_short_id is not None:
|
||||
args["run"] = self._run_short_id
|
||||
|
||||
# Backward compatibility in case of previous version retrieval
|
||||
try:
|
||||
if self._run_name is not None:
|
||||
args["name"] = self._run_name
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
return args
|
||||
|
||||
def _construct_path_with_prefix(self, *keys) -> str:
|
||||
"""Return sequence of keys joined by `LOGGER_JOIN_CHAR`, started with `_prefix` if defined."""
|
||||
|
@ -379,7 +392,7 @@ class NeptuneLogger(LightningLoggerBase):
|
|||
|
||||
def __setstate__(self, state):
|
||||
self.__dict__ = state
|
||||
self._run_instance = neptune.init(project=self._project_name, api_token=self._api_key, run=self._run_short_id)
|
||||
self._run_instance = neptune.init(**self._neptune_init_args)
|
||||
|
||||
@property
|
||||
@rank_zero_experiment
|
||||
|
@ -412,8 +425,23 @@ class NeptuneLogger(LightningLoggerBase):
|
|||
return self.run
|
||||
|
||||
@property
|
||||
@rank_zero_experiment
|
||||
def run(self) -> Run:
|
||||
return self._run_instance
|
||||
try:
|
||||
if not self._run_instance:
|
||||
self._run_instance = neptune.init(**self._neptune_init_args)
|
||||
self._retrieve_run_data()
|
||||
# make sure that we've logged the integration version for the newly created run
|
||||
self._run_instance[_INTEGRATION_VERSION_KEY] = __version__
|
||||
|
||||
return self._run_instance
|
||||
except NeptuneLegacyProjectException as e:
|
||||
raise TypeError(
|
||||
f"Project {self._project_name} has not been migrated to the new structure."
|
||||
" You can still integrate it with the Neptune logger using legacy Python API"
|
||||
" available as part of neptune-contrib package:"
|
||||
" https://docs-legacy.neptune.ai/integrations/pytorch_lightning.html\n"
|
||||
) from e
|
||||
|
||||
@rank_zero_only
|
||||
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: # skipcq: PYL-W0221
|
||||
|
@ -474,12 +502,12 @@ class NeptuneLogger(LightningLoggerBase):
|
|||
for key, val in metrics.items():
|
||||
# `step` is ignored because Neptune expects strictly increasing step values which
|
||||
# Lightning does not always guarantee.
|
||||
self.experiment[key].log(val)
|
||||
self.run[key].log(val)
|
||||
|
||||
@rank_zero_only
|
||||
def finalize(self, status: str) -> None:
|
||||
if status:
|
||||
self.experiment[self._construct_path_with_prefix("status")] = status
|
||||
self.run[self._construct_path_with_prefix("status")] = status
|
||||
|
||||
super().finalize(status)
|
||||
|
||||
|
@ -493,12 +521,14 @@ class NeptuneLogger(LightningLoggerBase):
|
|||
"""
|
||||
return os.path.join(os.getcwd(), ".neptune")
|
||||
|
||||
@rank_zero_only
|
||||
def log_model_summary(self, model, max_depth=-1):
|
||||
model_str = str(ModelSummary(model=model, max_depth=max_depth))
|
||||
self.experiment[self._construct_path_with_prefix("model/summary")] = neptune.types.File.from_content(
|
||||
self.run[self._construct_path_with_prefix("model/summary")] = neptune.types.File.from_content(
|
||||
content=model_str, extension="txt"
|
||||
)
|
||||
|
||||
@rank_zero_only
|
||||
def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None:
|
||||
"""Automatically log checkpointed model. Called after model checkpoint callback saves a new checkpoint.
|
||||
|
||||
|
@ -515,35 +545,33 @@ class NeptuneLogger(LightningLoggerBase):
|
|||
if checkpoint_callback.last_model_path:
|
||||
model_last_name = self._get_full_model_name(checkpoint_callback.last_model_path, checkpoint_callback)
|
||||
file_names.add(model_last_name)
|
||||
self.experiment[f"{checkpoints_namespace}/{model_last_name}"].upload(checkpoint_callback.last_model_path)
|
||||
self.run[f"{checkpoints_namespace}/{model_last_name}"].upload(checkpoint_callback.last_model_path)
|
||||
|
||||
# save best k models
|
||||
for key in checkpoint_callback.best_k_models.keys():
|
||||
model_name = self._get_full_model_name(key, checkpoint_callback)
|
||||
file_names.add(model_name)
|
||||
self.experiment[f"{checkpoints_namespace}/{model_name}"].upload(key)
|
||||
self.run[f"{checkpoints_namespace}/{model_name}"].upload(key)
|
||||
|
||||
# log best model path and checkpoint
|
||||
if checkpoint_callback.best_model_path:
|
||||
self.experiment[
|
||||
self._construct_path_with_prefix("model/best_model_path")
|
||||
] = checkpoint_callback.best_model_path
|
||||
self.run[self._construct_path_with_prefix("model/best_model_path")] = checkpoint_callback.best_model_path
|
||||
|
||||
model_name = self._get_full_model_name(checkpoint_callback.best_model_path, checkpoint_callback)
|
||||
file_names.add(model_name)
|
||||
self.experiment[f"{checkpoints_namespace}/{model_name}"].upload(checkpoint_callback.best_model_path)
|
||||
self.run[f"{checkpoints_namespace}/{model_name}"].upload(checkpoint_callback.best_model_path)
|
||||
|
||||
# remove old models logged to experiment if they are not part of best k models at this point
|
||||
if self.experiment.exists(checkpoints_namespace):
|
||||
exp_structure = self.experiment.get_structure()
|
||||
if self.run.exists(checkpoints_namespace):
|
||||
exp_structure = self.run.get_structure()
|
||||
uploaded_model_names = self._get_full_model_names_from_exp_structure(exp_structure, checkpoints_namespace)
|
||||
|
||||
for file_to_drop in list(uploaded_model_names - file_names):
|
||||
del self.experiment[f"{checkpoints_namespace}/{file_to_drop}"]
|
||||
del self.run[f"{checkpoints_namespace}/{file_to_drop}"]
|
||||
|
||||
# log best model score
|
||||
if checkpoint_callback.best_model_score:
|
||||
self.experiment[self._construct_path_with_prefix("model/best_model_score")] = (
|
||||
self.run[self._construct_path_with_prefix("model/best_model_score")] = (
|
||||
checkpoint_callback.best_model_score.cpu().detach().numpy()
|
||||
)
|
||||
|
||||
|
@ -637,13 +665,11 @@ class NeptuneLogger(LightningLoggerBase):
|
|||
self._signal_deprecated_api_usage("log_artifact", f"logger.run['{key}].log('path_to_file')")
|
||||
self.run[key].log(destination)
|
||||
|
||||
@rank_zero_only
|
||||
def set_property(self, *args, **kwargs):
|
||||
self._signal_deprecated_api_usage(
|
||||
"log_artifact", f"logger.run['{self._prefix}/{self.PARAMETERS_KEY}/key'].log(value)", raise_exception=True
|
||||
)
|
||||
|
||||
@rank_zero_only
|
||||
def append_tags(self, *args, **kwargs):
|
||||
self._signal_deprecated_api_usage(
|
||||
"append_tags", "logger.run['sys/tags'].add(['foo', 'bar'])", raise_exception=True
|
||||
|
|
|
@ -47,6 +47,8 @@ def _get_logger_args(logger_class, save_dir):
|
|||
logger_args.update(offline_mode=True)
|
||||
if "offline" in inspect.getfullargspec(logger_class).args:
|
||||
logger_args.update(offline=True)
|
||||
if issubclass(logger_class, NeptuneLogger):
|
||||
logger_args.update(mode="offline")
|
||||
return logger_args
|
||||
|
||||
|
||||
|
@ -328,7 +330,9 @@ class RankZeroLoggerCheck(Callback):
|
|||
|
||||
|
||||
@RunIf(skip_windows=True, skip_49370=True, skip_hanging_spawn=True)
|
||||
@pytest.mark.parametrize("logger_class", [CometLogger, CSVLogger, MLFlowLogger, TensorBoardLogger, TestTubeLogger])
|
||||
@pytest.mark.parametrize(
|
||||
"logger_class", [CometLogger, CSVLogger, MLFlowLogger, NeptuneLogger, TensorBoardLogger, TestTubeLogger]
|
||||
)
|
||||
def test_logger_created_on_rank_zero_only(tmpdir, monkeypatch, logger_class):
|
||||
"""Test that loggers get replaced by dummy loggers on global rank > 0."""
|
||||
_patch_comet_atexit(monkeypatch)
|
||||
|
|
|
@ -77,7 +77,7 @@ def tmpdir_unittest_fixture(request, tmpdir):
|
|||
class TestNeptuneLogger(unittest.TestCase):
|
||||
def test_neptune_online(self, neptune):
|
||||
logger = NeptuneLogger(api_key="test", project="project")
|
||||
created_run_mock = logger._run_instance
|
||||
created_run_mock = logger.run
|
||||
|
||||
self.assertEqual(logger._run_instance, created_run_mock)
|
||||
self.assertEqual(logger.name, "Run test name")
|
||||
|
@ -109,7 +109,7 @@ class TestNeptuneLogger(unittest.TestCase):
|
|||
pickled_logger = pickle.dumps(logger)
|
||||
unpickled = pickle.loads(pickled_logger)
|
||||
|
||||
neptune.init.assert_called_once_with(project="test-project", api_token=None, run="TEST-42")
|
||||
neptune.init.assert_called_once_with(name="Test name", run=unpickleable_run._short_id)
|
||||
self.assertIsNotNone(unpickled.experiment)
|
||||
|
||||
@patch("pytorch_lightning.loggers.neptune.Run", Run)
|
||||
|
@ -360,7 +360,7 @@ class TestNeptuneLoggerDeprecatedUsages(unittest.TestCase):
|
|||
logger = NeptuneLogger(api_key="test", project="project")
|
||||
|
||||
# test deprecated functions which will be shut down in pytorch-lightning 1.7.0
|
||||
attr_mock = logger._run_instance.__getitem__
|
||||
attr_mock = logger.run.__getitem__
|
||||
attr_mock.reset_mock()
|
||||
fake_image = {}
|
||||
|
||||
|
|
Loading…
Reference in New Issue