fix error when logging to progress bar with reserved name (#5620)
* warn about duplicate metrics
* update changelog
* suggestions from rohit
* multiple values in message
* Apply suggestions from code review

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
parent 2c9f606af9
commit b3b48c188c
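For context: the change below makes `trainer.progress_bar_dict` warn (instead of erroring) when a metric logged with `self.log(..., prog_bar=True)` reuses a name the progress bar already tracks, and lets the logged value overwrite the standard one. A minimal way to trigger the new warning, sketched after the test added at the bottom of this diff; it assumes it runs inside this repository's test suite (so `tests.base.boring_model.BoringModel` is importable), and the `CollidingModel` name is illustrative only:

```python
# Illustrative sketch only -- mirrors the new test added in this commit.
from pytorch_lightning import Trainer
from tests.base.boring_model import BoringModel  # test helper from this repo


class CollidingModel(BoringModel):

    def training_step(self, *args, **kwargs):
        output = super().training_step(*args, **kwargs)
        # "loss" is already a standard progress bar entry, so this call now
        # emits a UserWarning and the logged value overwrites the standard one.
        self.log("loss", output["loss"], prog_bar=True)
        return output


if __name__ == "__main__":
    trainer = Trainer(max_steps=2)
    trainer.fit(CollidingModel())
```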
@@ -164,6 +164,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed distributed setting and `ddp_cpu` only with `num_processes>1` ([#5297](https://github.com/PyTorchLightning/pytorch-lightning/pull/5297))
 
 
+- Fixed an error when logging a progress bar metric with a reserved name ([#5620](https://github.com/PyTorchLightning/pytorch-lightning/pull/5620))
+
+
 - Fixed the saved filename in `ModelCheckpoint` when it already exists ([#4861](https://github.com/PyTorchLightning/pytorch-lightning/pull/4861))
 
 
@@ -14,11 +14,15 @@
 import inspect
 import os
 from abc import ABC
-from argparse import ArgumentParser, Namespace
+from argparse import ArgumentParser
+from argparse import Namespace
 from typing import cast, List, Optional, Type, TypeVar, Union
 
 from pytorch_lightning.accelerators.legacy.accelerator import Accelerator
-from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint, ProgressBarBase
+from pytorch_lightning.callbacks import Callback
+from pytorch_lightning.callbacks import EarlyStopping
+from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks import ProgressBarBase
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.loggers.base import LightningLoggerBase
 from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
@@ -26,13 +30,15 @@ from pytorch_lightning.trainer.connectors.checkpoint_connector import Checkpoint
 from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
 from pytorch_lightning.trainer.connectors.model_connector import ModelConnector
 from pytorch_lightning.trainer.states import TrainerState
-from pytorch_lightning.utilities import _HOROVOD_AVAILABLE, _TPU_AVAILABLE, DeviceType, DistributedType
-from pytorch_lightning.utilities.argparse import (
-    add_argparse_args,
-    from_argparse_args,
-    parse_argparser,
-    parse_env_variables,
-)
+from pytorch_lightning.utilities import _HOROVOD_AVAILABLE
+from pytorch_lightning.utilities import _TPU_AVAILABLE
+from pytorch_lightning.utilities import DeviceType
+from pytorch_lightning.utilities import DistributedType
+from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities.argparse import add_argparse_args
+from pytorch_lightning.utilities.argparse import from_argparse_args
+from pytorch_lightning.utilities.argparse import parse_argparser
+from pytorch_lightning.utilities.argparse import parse_env_variables
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.model_helpers import is_overridden
 
@@ -193,7 +199,20 @@ class TrainerProperties(ABC):
         """ Read-only for progress bar metrics. """
         ref_model = self.get_model()
         ref_model = cast(LightningModule, ref_model)
-        return dict(**ref_model.get_progress_bar_dict(), **self.logger_connector.progress_bar_metrics)
+
+        standard_metrics = ref_model.get_progress_bar_dict()
+        logged_metrics = self.progress_bar_metrics
+        duplicates = list(standard_metrics.keys() & logged_metrics.keys())
+        if duplicates:
+            rank_zero_warn(
+                f"The progress bar already tracks a metric with the name(s) '{', '.join(duplicates)}' and"
+                f" `self.log('{duplicates[0]}', ..., prog_bar=True)` will overwrite this value. "
+                f" If this is undesired, change the name or override `get_progress_bar_dict()`"
+                f" in `LightningModule`.", UserWarning
+            )
+        all_metrics = dict(**standard_metrics)
+        all_metrics.update(**logged_metrics)
+        return all_metrics
 
     @property
     def disable_validation(self) -> bool:
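A note on the duplicate check above: `dict.keys()` returns a set-like view, so intersecting the two views yields the clashing metric names, and because `all_metrics.update(**logged_metrics)` runs last, the user-logged value wins. A standalone sketch with made-up values:

```python
# Standalone illustration of the merge logic above; the metric values are made up.
standard_metrics = {"loss": 0.25, "v_num": 3}   # what get_progress_bar_dict() might return
logged_metrics = {"loss": 0.99, "acc": 0.80}    # what self.log(..., prog_bar=True) collected

duplicates = list(standard_metrics.keys() & logged_metrics.keys())
print(duplicates)  # ['loss'] -- dict key views support set intersection

all_metrics = dict(**standard_metrics)
all_metrics.update(**logged_metrics)            # logged values overwrite the standard ones
print(all_metrics)  # {'loss': 0.99, 'v_num': 3, 'acc': 0.8}
```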
@@ -26,7 +26,9 @@ from pytorch_lightning import Trainer
 from tests.base.boring_model import BoringModel  # noqa: E402
 
 
-@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest")
+@pytest.mark.skipif(
+    not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
+)
 def test_logging_sync_dist_true_ddp(tmpdir):
     """
     Tests to ensure that the sync_dist flag works with CPU (should just return the original value)
@@ -62,7 +64,9 @@ def test_logging_sync_dist_true_ddp(tmpdir):
     assert trainer.logged_metrics['bar'] == fake_result
 
 
-@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest")
+@pytest.mark.skipif(
+    not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
+)
 def test__validation_step__log(tmpdir):
     """
     Tests that validation_step can log
@@ -28,7 +28,8 @@ from pytorch_lightning.trainer import Trainer
 from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import CallbackHookNameValidator
 from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.base.boring_model import BoringModel, RandomDataset
+from tests.base.boring_model import BoringModel
+from tests.base.boring_model import RandomDataset
 
 
 def decorator_with_arguments(fx_name: str = '', hook_fx_name: str = None) -> Callable:
@@ -454,3 +455,21 @@ def test_metrics_holder(to_float, tmpdir):
     assert excepted_function(metrics["x"])
     assert excepted_function(metrics["y"])
     assert excepted_function(metrics["z"])
+
+
+def test_logging_to_progress_bar_with_reserved_key(tmpdir):
+    """ Test that logging a metric with a reserved name to the progress bar raises a warning. """
+    class TestModel(BoringModel):
+
+        def training_step(self, *args, **kwargs):
+            output = super().training_step(*args, **kwargs)
+            self.log("loss", output["loss"], prog_bar=True)
+            return output
+
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_steps=2,
+    )
+    with pytest.warns(UserWarning, match="The progress bar already tracks a metric with the .* 'loss'"):
+        trainer.fit(model)
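The warning text introduced above offers two remedies: rename the logged metric, or override `get_progress_bar_dict()` so the standard entry no longer clashes. A sketch of the second option follows; it is not part of this commit, and removing `v_num` is just one example of dropping a standard key:

```python
# Sketch of the remedy suggested by the warning message; not part of this commit.
from pytorch_lightning import LightningModule


class MyModel(LightningModule):

    def get_progress_bar_dict(self):
        items = super().get_progress_bar_dict()
        # Remove a standard progress bar entry (here "v_num") so a metric
        # logged under the same name no longer collides with it.
        items.pop("v_num", None)
        return items
```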
@@ -16,13 +16,13 @@ from unittest.mock import patch
 
 import pytest
 
-import pytorch_lightning.utilities.xla_device as xla_utils
-from pytorch_lightning.utilities import _TPU_AVAILABLE
+import pytorch_lightning.utilities.xla_device_utils as xla_utils
+from pytorch_lightning.utilities import _XLA_AVAILABLE
 from tests.base.develop_utils import pl_multi_process_test
 
 
-@pytest.mark.skipif(XLA_AVAILABLE, reason="test requires torch_xla to be absent")
+@pytest.mark.skipif(_XLA_AVAILABLE, reason="test requires torch_xla to be absent")
 def test_tpu_device_absence():
     """Check tpu_device_exists returns None when torch_xla is not available"""
     assert xla_utils.XLADeviceUtils.tpu_device_exists() is None
 