diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7c4255af8e..6a8cf96c8e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -164,6 +164,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed distributed setting and `ddp_cpu` only with `num_processes>1` ([#5297](https://github.com/PyTorchLightning/pytorch-lightning/pull/5297))
 
+- Fixed an error when logging a progress bar metric with a reserved name ([#5620](https://github.com/PyTorchLightning/pytorch-lightning/pull/5620))
+
+
 - Fixed the saved filename in `ModelCheckpoint` when it already exists ([#4861](https://github.com/PyTorchLightning/pytorch-lightning/pull/4861))
 
 
diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py
index 760a621db6..aca37802ca 100644
--- a/pytorch_lightning/trainer/properties.py
+++ b/pytorch_lightning/trainer/properties.py
@@ -14,11 +14,15 @@
 import inspect
 import os
 from abc import ABC
-from argparse import ArgumentParser, Namespace
+from argparse import ArgumentParser
+from argparse import Namespace
 from typing import cast, List, Optional, Type, TypeVar, Union
 
 from pytorch_lightning.accelerators.legacy.accelerator import Accelerator
-from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint, ProgressBarBase
+from pytorch_lightning.callbacks import Callback
+from pytorch_lightning.callbacks import EarlyStopping
+from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks import ProgressBarBase
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.loggers.base import LightningLoggerBase
 from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
@@ -26,13 +30,15 @@ from pytorch_lightning.trainer.connectors.checkpoint_connector import Checkpoint
 from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
 from pytorch_lightning.trainer.connectors.model_connector import ModelConnector
 from pytorch_lightning.trainer.states import TrainerState
-from pytorch_lightning.utilities import _HOROVOD_AVAILABLE, _TPU_AVAILABLE, DeviceType, DistributedType
-from pytorch_lightning.utilities.argparse import (
-    add_argparse_args,
-    from_argparse_args,
-    parse_argparser,
-    parse_env_variables,
-)
+from pytorch_lightning.utilities import _HOROVOD_AVAILABLE
+from pytorch_lightning.utilities import _TPU_AVAILABLE
+from pytorch_lightning.utilities import DeviceType
+from pytorch_lightning.utilities import DistributedType
+from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities.argparse import add_argparse_args
+from pytorch_lightning.utilities.argparse import from_argparse_args
+from pytorch_lightning.utilities.argparse import parse_argparser
+from pytorch_lightning.utilities.argparse import parse_env_variables
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.model_helpers import is_overridden
 
@@ -193,7 +199,20 @@ class TrainerProperties(ABC):
         """ Read-only for progress bar metrics. """
""" ref_model = self.get_model() ref_model = cast(LightningModule, ref_model) - return dict(**ref_model.get_progress_bar_dict(), **self.logger_connector.progress_bar_metrics) + + standard_metrics = ref_model.get_progress_bar_dict() + logged_metrics = self.progress_bar_metrics + duplicates = list(standard_metrics.keys() & logged_metrics.keys()) + if duplicates: + rank_zero_warn( + f"The progress bar already tracks a metric with the name(s) '{', '.join(duplicates)}' and" + f" `self.log('{duplicates[0]}', ..., prog_bar=True)` will overwrite this value. " + f" If this is undesired, change the name or override `get_progress_bar_dict()`" + f" in `LightingModule`.", UserWarning + ) + all_metrics = dict(**standard_metrics) + all_metrics.update(**logged_metrics) + return all_metrics @property def disable_validation(self) -> bool: diff --git a/tests/accelerators/legacy/test_multi_nodes_gpu.py b/tests/accelerators/legacy/test_multi_nodes_gpu.py index f17ac42fcb..d9387df2b9 100644 --- a/tests/accelerators/legacy/test_multi_nodes_gpu.py +++ b/tests/accelerators/legacy/test_multi_nodes_gpu.py @@ -26,7 +26,9 @@ from pytorch_lightning import Trainer from tests.base.boring_model import BoringModel # noqa: E402 -@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest") +@pytest.mark.skipif( + not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest" +) def test_logging_sync_dist_true_ddp(tmpdir): """ Tests to ensure that the sync_dist flag works with CPU (should just return the original value) @@ -62,7 +64,9 @@ def test_logging_sync_dist_true_ddp(tmpdir): assert trainer.logged_metrics['bar'] == fake_result -@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest") +@pytest.mark.skipif( + not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest" +) def test__validation_step__log(tmpdir): """ Tests that validation_step can log diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index 106a299e43..ffdaea8c52 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -28,7 +28,8 @@ from pytorch_lightning.trainer import Trainer from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import CallbackHookNameValidator from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.base.boring_model import BoringModel, RandomDataset +from tests.base.boring_model import BoringModel +from tests.base.boring_model import RandomDataset def decorator_with_arguments(fx_name: str = '', hook_fx_name: str = None) -> Callable: @@ -454,3 +455,21 @@ def test_metrics_holder(to_float, tmpdir): assert excepted_function(metrics["x"]) assert excepted_function(metrics["y"]) assert excepted_function(metrics["z"]) + + +def test_logging_to_progress_bar_with_reserved_key(tmpdir): + """ Test that logging a metric with a reserved name to the progress bar raises a warning. 
""" + class TestModel(BoringModel): + + def training_step(self, *args, **kwargs): + output = super().training_step(*args, **kwargs) + self.log("loss", output["loss"], prog_bar=True) + return output + + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=2, + ) + with pytest.warns(UserWarning, match="The progress bar already tracks a metric with the .* 'loss'"): + trainer.fit(model) diff --git a/tests/utilities/test_xla_device_utils.py b/tests/utilities/test_xla_device_utils.py index dcafa85092..471792da9c 100644 --- a/tests/utilities/test_xla_device_utils.py +++ b/tests/utilities/test_xla_device_utils.py @@ -16,13 +16,13 @@ from unittest.mock import patch import pytest -import pytorch_lightning.utilities.xla_device as xla_utils +import pytorch_lightning.utilities.xla_device_utils as xla_utils from pytorch_lightning.utilities import _TPU_AVAILABLE from pytorch_lightning.utilities import _XLA_AVAILABLE from tests.base.develop_utils import pl_multi_process_test -@pytest.mark.skipif(XLA_AVAILABLE, reason="test requires torch_xla to be absent") +@pytest.mark.skipif(_XLA_AVAILABLE, reason="test requires torch_xla to be absent") def test_tpu_device_absence(): """Check tpu_device_exists returns None when torch_xla is not available""" assert xla_utils.XLADeviceUtils.tpu_device_exists() is None