fix error when logging to progress bar with reserved name (#5620)

* warn about duplicate metrics

* update changelog

* suggestions from rohit

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* multiple values in message

* Apply suggestions from code review

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Authored by Adrian Wälchli on 2021-01-25 13:57:06 +01:00, committed by Jirka Borovec
parent 2c9f606af9
commit b3b48c188c
5 changed files with 60 additions and 15 deletions
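
For context, the pre-fix failure mode: the property changed further down merged the progress bar's standard metrics and the user-logged metrics with a single `dict(**standard, **logged)` call, and keyword expansion raises a `TypeError` as soon as both dicts share a key such as `loss`. A minimal sketch of that behaviour in plain Python (the metric values are made up; this is not code from the PR):

# Hypothetical metric dicts: "loss" is tracked by the progress bar and also logged by the user.
standard_metrics = {"loss": 0.50, "v_num": 3}
logged_metrics = {"loss": 0.25}

try:
    dict(**standard_metrics, **logged_metrics)
except TypeError as err:
    print(err)  # e.g. "dict() got multiple values for keyword argument 'loss'"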


@@ -164,6 +164,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed distributed setting and `ddp_cpu` only with `num_processes>1` ([#5297](https://github.com/PyTorchLightning/pytorch-lightning/pull/5297))
- Fixed an error when logging a progress bar metric with a reserved name ([#5620](https://github.com/PyTorchLightning/pytorch-lightning/pull/5620))
- Fixed the saved filename in `ModelCheckpoint` when it already exists ([#4861](https://github.com/PyTorchLightning/pytorch-lightning/pull/4861))
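
As a user-facing illustration of the entry above, here is a sketch against the pytorch_lightning API of this era; the toy dataset, model, and hyperparameters are invented for the example. `loss` is a name the progress bar already tracks, so logging it with `prog_bar=True` now produces the warning added in this commit instead of an error during the metric merge.

import torch
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    """Tiny random dataset, only here to make the sketch runnable."""

    def __len__(self):
        return 64

    def __getitem__(self, idx):
        return torch.randn(32)


class ReservedNameModel(LightningModule):

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).sum()
        # "loss" collides with the key the progress bar already shows for the running loss
        self.log("loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    trainer = Trainer(max_steps=2)
    trainer.fit(ReservedNameModel(), DataLoader(RandomDataset(), batch_size=8))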


@@ -14,11 +14,15 @@
import inspect
import os
from abc import ABC
from argparse import ArgumentParser, Namespace
from argparse import ArgumentParser
from argparse import Namespace
from typing import cast, List, Optional, Type, TypeVar, Union
from pytorch_lightning.accelerators.legacy.accelerator import Accelerator
from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint, ProgressBarBase
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import ProgressBarBase
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers.base import LightningLoggerBase
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
@@ -26,13 +30,15 @@ from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
from pytorch_lightning.trainer.connectors.model_connector import ModelConnector
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import _HOROVOD_AVAILABLE, _TPU_AVAILABLE, DeviceType, DistributedType
from pytorch_lightning.utilities.argparse import (
    add_argparse_args,
    from_argparse_args,
    parse_argparser,
    parse_env_variables,
)
from pytorch_lightning.utilities import _HOROVOD_AVAILABLE
from pytorch_lightning.utilities import _TPU_AVAILABLE
from pytorch_lightning.utilities import DeviceType
from pytorch_lightning.utilities import DistributedType
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.argparse import add_argparse_args
from pytorch_lightning.utilities.argparse import from_argparse_args
from pytorch_lightning.utilities.argparse import parse_argparser
from pytorch_lightning.utilities.argparse import parse_env_variables
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -193,7 +199,20 @@ class TrainerProperties(ABC):
        """ Read-only for progress bar metrics. """
        ref_model = self.get_model()
        ref_model = cast(LightningModule, ref_model)
        return dict(**ref_model.get_progress_bar_dict(), **self.logger_connector.progress_bar_metrics)
        standard_metrics = ref_model.get_progress_bar_dict()
        logged_metrics = self.progress_bar_metrics
        duplicates = list(standard_metrics.keys() & logged_metrics.keys())
        if duplicates:
            rank_zero_warn(
                f"The progress bar already tracks a metric with the name(s) '{', '.join(duplicates)}' and"
                f" `self.log('{duplicates[0]}', ..., prog_bar=True)` will overwrite this value."
                f" If this is undesired, change the name or override `get_progress_bar_dict()`"
                f" in `LightningModule`.", UserWarning
            )
        all_metrics = dict(**standard_metrics)
        all_metrics.update(**logged_metrics)
        return all_metrics

    @property
    def disable_validation(self) -> bool:
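
The heart of the new `progress_bar_dict` body above is a duplicate-key check followed by a merge in which the logged values win. A standalone sketch of the same pattern in plain Python (function and metric names invented), which also shows why the new merge cannot raise the old `TypeError`:

import warnings


def merge_metrics(standard_metrics: dict, logged_metrics: dict) -> dict:
    """Warn on overlapping keys, then merge so that the logged values overwrite the standard ones."""
    duplicates = sorted(standard_metrics.keys() & logged_metrics.keys())
    if duplicates:
        warnings.warn(
            f"Duplicate metric name(s) {duplicates}: the logged values will overwrite the existing ones.",
            UserWarning,
        )
    merged = dict(standard_metrics)  # a plain copy never raises on duplicate keys...
    merged.update(logged_metrics)    # ...and update() silently lets the second dict win
    return merged


# 'loss' appears in both dicts, so a UserWarning is emitted and the logged 0.25 is kept.
print(merge_metrics({"loss": 0.5, "v_num": 3}, {"loss": 0.25, "acc": 0.9}))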


@@ -26,7 +26,9 @@ from pytorch_lightning import Trainer
from tests.base.boring_model import BoringModel  # noqa: E402

@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest")
@pytest.mark.skipif(
    not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
)
def test_logging_sync_dist_true_ddp(tmpdir):
    """
    Tests to ensure that the sync_dist flag works with CPU (should just return the original value)
@@ -62,7 +64,9 @@ def test_logging_sync_dist_true_ddp(tmpdir):
    assert trainer.logged_metrics['bar'] == fake_result

@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest")
@pytest.mark.skipif(
    not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
)
def test__validation_step__log(tmpdir):
    """
    Tests that validation_step can log
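
As an aside, the decorator reflowed in this hunk only changes formatting; behaviourally it skips these distributed-logging tests unless the PL_RUNNING_SPECIAL_TESTS environment variable is set to '1' (typically exported by a dedicated special-tests script). A toy illustration of that gate with a hypothetical test name:

import os

import pytest


# Skipped under plain `pytest`, runs only when the caller exports PL_RUNNING_SPECIAL_TESTS=1.
@pytest.mark.skipif(
    not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
)
def test_only_runs_when_flag_is_set():
    assert True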


@@ -28,7 +28,8 @@ from pytorch_lightning.trainer import Trainer
from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import CallbackHookNameValidator
from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base.boring_model import BoringModel, RandomDataset
from tests.base.boring_model import BoringModel
from tests.base.boring_model import RandomDataset
def decorator_with_arguments(fx_name: str = '', hook_fx_name: str = None) -> Callable:
@@ -454,3 +455,21 @@ def test_metrics_holder(to_float, tmpdir):
    assert excepted_function(metrics["x"])
    assert excepted_function(metrics["y"])
    assert excepted_function(metrics["z"])


def test_logging_to_progress_bar_with_reserved_key(tmpdir):
    """ Test that logging a metric with a reserved name to the progress bar raises a warning. """

    class TestModel(BoringModel):
        def training_step(self, *args, **kwargs):
            output = super().training_step(*args, **kwargs)
            self.log("loss", output["loss"], prog_bar=True)
            return output

    model = TestModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_steps=2,
    )
    with pytest.warns(UserWarning, match="The progress bar already tracks a metric with the .* 'loss'"):
        trainer.fit(model)
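
The assertion in the new test relies on `pytest.warns(UserWarning, match=...)`, where `match` is a regular expression searched against the warning message; that is why the `.*` in the pattern can skip over the middle of the text. A tiny standalone sketch of the mechanism (toy function and a shortened message, not code from the PR):

import warnings

import pytest


def emit_warning():
    # Stand-in for the warning raised by the patched progress_bar_dict property.
    warnings.warn("The progress bar already tracks a metric with the name(s) 'loss' and ...", UserWarning)


def test_warning_message_matches_regex():
    with pytest.warns(UserWarning, match="already tracks a metric with the .* 'loss'"):
        emit_warning()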


@@ -16,13 +16,13 @@ from unittest.mock import patch
import pytest
import pytorch_lightning.utilities.xla_device as xla_utils
import pytorch_lightning.utilities.xla_device_utils as xla_utils
from pytorch_lightning.utilities import _TPU_AVAILABLE
from pytorch_lightning.utilities import _XLA_AVAILABLE
from tests.base.develop_utils import pl_multi_process_test
@pytest.mark.skipif(XLA_AVAILABLE, reason="test requires torch_xla to be absent")
@pytest.mark.skipif(_XLA_AVAILABLE, reason="test requires torch_xla to be absent")
def test_tpu_device_absence():
"""Check tpu_device_exists returns None when torch_xla is not available"""
assert xla_utils.XLADeviceUtils.tpu_device_exists() is None