Deprecate `LoggerCollection` in favor of `trainer.loggers` (#12147)
Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
parent: 0db85d633c
commit: 1f7298d326
@@ -454,6 +454,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Deprecated `BaseProfiler.profile_iterable` ([#12102](https://github.com/PyTorchLightning/pytorch-lightning/pull/12102))
|
||||
|
||||
|
||||
- Deprecated `LoggerCollection` in favor of `trainer.loggers` ([#12147](https://github.com/PyTorchLightning/pytorch-lightning/pull/12147))
|
||||
|
||||
|
||||
- Deprecated `PrecisionPlugin.on_{save,load}_checkpoint` in favor of `PrecisionPlugin.{state_dict,load_state_dict}` ([#11978](https://github.com/PyTorchLightning/pytorch-lightning/pull/11978))
|
||||
|
||||
|
||||
|
|
|
@@ -221,6 +221,10 @@ class LightningLoggerBase(ABC):
|
|||
class LoggerCollection(LightningLoggerBase):
|
||||
"""The :class:`LoggerCollection` class is used to iterate all logging actions over the given `logger_iterable`.
|
||||
|
||||
.. deprecated:: v1.6
|
||||
`LoggerCollection` is deprecated in v1.6 and will be removed in v1.8.
|
||||
Directly pass a list of loggers to the Trainer and access the list via the `trainer.loggers` attribute.
|
||||
|
||||
Args:
|
||||
logger_iterable: An iterable collection of loggers
|
||||
"""
|
||||
|
@@ -228,6 +232,10 @@ class LoggerCollection(LightningLoggerBase):
|
|||
def __init__(self, logger_iterable: Iterable[LightningLoggerBase]):
|
||||
super().__init__()
|
||||
self._logger_iterable = logger_iterable
|
||||
rank_zero_deprecation(
|
||||
"`LoggerCollection` is deprecated in v1.6 and will be removed in v1.8. Directly pass a list of loggers"
|
||||
" to the Trainer and access the list via the `trainer.loggers` attribute."
|
||||
)
|
||||
|
||||
def __getitem__(self, index: int) -> LightningLoggerBase:
|
||||
return list(self._logger_iterable)[index]
|
||||
|
|
|
@@ -2623,6 +2623,8 @@ class Trainer(
|
|||
" This behavior will change in v1.8 when LoggerCollection is removed, and"
|
||||
" trainer.logger will return the first logger in trainer.loggers"
|
||||
)
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
return LoggerCollection(self.loggers)
|
||||
|
||||
@logger.setter
|
||||
|
|
|
@@ -21,7 +21,7 @@ import torch
|
|||
from torch import optim
|
||||
|
||||
from pytorch_lightning import Callback, Trainer
|
||||
from pytorch_lightning.loggers import CSVLogger, LightningLoggerBase
|
||||
from pytorch_lightning.loggers import CSVLogger, LightningLoggerBase, LoggerCollection
|
||||
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
|
||||
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
|
||||
from pytorch_lightning.plugins.training_type.ddp2 import DDP2Plugin
|
||||
|
@@ -662,6 +662,23 @@ def test_simple_profiler_iterable_durations(tmpdir, action: str, expected: list)
|
|||
np.testing.assert_allclose(recorded_total_duration, expected_total_duration, rtol=0.2)
|
||||
|
||||
|
||||
def test_v1_8_0_logger_collection(tmpdir):
|
||||
logger1 = CSVLogger(tmpdir)
|
||||
logger2 = CSVLogger(tmpdir)
|
||||
|
||||
trainer1 = Trainer(logger=logger1)
|
||||
trainer2 = Trainer(logger=[logger1, logger2])
|
||||
|
||||
# Should have no deprecation warning
|
||||
trainer1.logger
|
||||
trainer1.loggers
|
||||
trainer2.loggers
|
||||
trainer2.logger
|
||||
|
||||
with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"):
|
||||
LoggerCollection([logger1, logger2])
|
||||
|
||||
|
||||
def test_v1_8_0_precision_plugin_checkpoint_hooks(tmpdir):
|
||||
class PrecisionPluginSaveHook(PrecisionPlugin):
|
||||
def on_save_checkpoint(self, checkpoint):
|
||||
|
|
|
@@ -34,6 +34,7 @@ def test_logger_collection():
|
|||
mock1 = MagicMock()
|
||||
mock2 = MagicMock()
|
||||
|
||||
with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"):
|
||||
logger = LoggerCollection([mock1, mock2])
|
||||
|
||||
assert logger[0] == mock1
|
||||
|
@@ -62,6 +63,7 @@ def test_logger_collection_unique_names():
|
|||
logger1 = CustomLogger(name=unique_name)
|
||||
logger2 = CustomLogger(name=unique_name)
|
||||
|
||||
with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"):
|
||||
logger = LoggerCollection([logger1, logger2])
|
||||
|
||||
assert logger.name == unique_name
|
||||
|
@@ -69,6 +71,7 @@ def test_logger_collection_unique_names():
|
|||
|
||||
def test_logger_collection_names_order():
|
||||
loggers = [CustomLogger(name=n) for n in ("name1", "name2", "name1", "name3")]
|
||||
with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"):
|
||||
logger = LoggerCollection(loggers)
|
||||
assert logger.name == f"{loggers[0].name}_{loggers[1].name}_{loggers[3].name}"
|
||||
|
||||
|
@@ -78,6 +81,7 @@ def test_logger_collection_unique_versions():
|
|||
logger1 = CustomLogger(version=unique_version)
|
||||
logger2 = CustomLogger(version=unique_version)
|
||||
|
||||
with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"):
|
||||
logger = LoggerCollection([logger1, logger2])
|
||||
|
||||
assert logger.version == unique_version
|
||||
|
@@ -85,6 +89,7 @@ def test_logger_collection_unique_versions():
|
|||
|
||||
def test_logger_collection_versions_order():
|
||||
loggers = [CustomLogger(version=v) for v in ("1", "2", "1", "3")]
|
||||
with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"):
|
||||
logger = LoggerCollection(loggers)
|
||||
assert logger.version == f"{loggers[0].version}_{loggers[1].version}_{loggers[3].version}"
|
||||
|
||||
|
|
|
@@ -24,7 +24,7 @@ import torch
|
|||
|
||||
from pytorch_lightning import Callback, Trainer
|
||||
from pytorch_lightning.callbacks import EarlyStopping, StochasticWeightAveraging
|
||||
from pytorch_lightning.loggers import CSVLogger, LoggerCollection, TensorBoardLogger
|
||||
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
|
||||
from pytorch_lightning.profiler import AdvancedProfiler, PassThroughProfiler, PyTorchProfiler, SimpleProfiler
|
||||
from pytorch_lightning.profiler.pytorch import RegisterRecordFunction, warning_cache
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
|
@@ -450,9 +450,9 @@ def test_pytorch_profiler_nested(tmpdir):
|
|||
assert events_name == expected, (events_name, torch.__version__, platform.system())
|
||||
|
||||
|
||||
def test_pytorch_profiler_logger_collection(tmpdir):
|
||||
"""Tests whether the PyTorch profiler is able to write its trace locally when the Trainer's logger is an
|
||||
instance of LoggerCollection.
|
||||
def test_pytorch_profiler_multiple_loggers(tmpdir):
|
||||
"""Tests whether the PyTorch profiler is able to write its trace locally when the Trainer is configured with
|
||||
multiple loggers.
|
||||
|
||||
See issue #8157.
|
||||
"""
|
||||
|
@@ -465,10 +465,9 @@ def test_pytorch_profiler_logger_collection(tmpdir):
|
|||
assert not look_for_trace(tmpdir)
|
||||
|
||||
model = BoringModel()
|
||||
# Wrap the logger in a list so it becomes a LoggerCollection
|
||||
logger = [TensorBoardLogger(save_dir=tmpdir), CSVLogger(tmpdir)]
|
||||
trainer = Trainer(default_root_dir=tmpdir, profiler="pytorch", logger=logger, limit_train_batches=5, max_epochs=1)
|
||||
assert isinstance(trainer.logger, LoggerCollection)
|
||||
loggers = [TensorBoardLogger(save_dir=tmpdir), CSVLogger(tmpdir)]
|
||||
trainer = Trainer(default_root_dir=tmpdir, profiler="pytorch", logger=loggers, limit_train_batches=5, max_epochs=1)
|
||||
assert len(trainer.loggers) == 2
|
||||
trainer.fit(model)
|
||||
assert look_for_trace(tmpdir)
|
||||
|
||||
|
|
|
@@ -15,7 +15,7 @@ import os
|
|||
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
from pytorch_lightning.loggers import CSVLogger, LoggerCollection, TensorBoardLogger
|
||||
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
|
||||
from tests.helpers.boring_model import BoringModel
|
||||
|
||||
|
||||
|
@@ -109,8 +109,8 @@ def test_logdir_custom_logger(tmpdir):
|
|||
assert trainer.log_dir == expected
|
||||
|
||||
|
||||
def test_logdir_logger_collection(tmpdir):
|
||||
"""Tests that the logdir equals the default_root_dir when the logger is a LoggerCollection."""
|
||||
def test_logdir_multiple_loggers(tmpdir):
|
||||
"""Tests that the logdir equals the default_root_dir when trainer has multiple loggers."""
|
||||
default_root_dir = tmpdir / "default_root_dir"
|
||||
save_dir = tmpdir / "save_dir"
|
||||
model = TestModel(default_root_dir)
|
||||
|
@@ -119,7 +119,6 @@ def test_logdir_logger_collection(tmpdir):
|
|||
max_steps=2,
|
||||
logger=[TensorBoardLogger(save_dir=save_dir, name="custom_logs"), CSVLogger(tmpdir)],
|
||||
)
|
||||
assert isinstance(trainer.logger, LoggerCollection)
|
||||
assert trainer.log_dir == default_root_dir
|
||||
|
||||
trainer.fit(model)
|
||||
|
|
|
@@ -12,6 +12,8 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pytest
|
||||
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
|
||||
from tests.loggers.test_base import CustomLogger
|
||||
|
@@ -50,7 +52,9 @@ def test_trainer_loggers_setters():
|
|||
"""Test the behavior of setters for trainer.logger and trainer.loggers."""
|
||||
logger1 = CustomLogger()
|
||||
logger2 = CustomLogger()
|
||||
with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"):
|
||||
logger_collection = LoggerCollection([logger1, logger2])
|
||||
with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"):
|
||||
logger_collection_2 = LoggerCollection([logger2])
|
||||
|
||||
trainer = Trainer()
|
||||
|
|
Loading…
Reference in New Issue