Deprecate `LoggerCollection` in favor of `trainer.loggers` (#12147)

Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
Akash Kwatra 2022-03-04 15:01:43 -08:00 committed by GitHub
parent 0db85d633c
commit 1f7298d326
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 58 additions and 21 deletions

View File

@ -454,6 +454,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `BaseProfiler.profile_iterable` ([#12102](https://github.com/PyTorchLightning/pytorch-lightning/pull/12102))
- Deprecated `LoggerCollection` in favor of `trainer.loggers` ([#12147](https://github.com/PyTorchLightning/pytorch-lightning/pull/12147))
- Deprecated `PrecisionPlugin.on_{save,load}_checkpoint` in favor of `PrecisionPlugin.{state_dict,load_state_dict}` ([#11978](https://github.com/PyTorchLightning/pytorch-lightning/pull/11978))

View File

@ -221,6 +221,10 @@ class LightningLoggerBase(ABC):
class LoggerCollection(LightningLoggerBase):
"""The :class:`LoggerCollection` class is used to iterate all logging actions over the given `logger_iterable`.
.. deprecated:: v1.6
`LoggerCollection` is deprecated in v1.6 and will be removed in v1.8.
Directly pass a list of loggers to the Trainer and access the list via the `trainer.loggers` attribute.
Args:
logger_iterable: An iterable collection of loggers
"""
@ -228,6 +232,10 @@ class LoggerCollection(LightningLoggerBase):
def __init__(self, logger_iterable: Iterable[LightningLoggerBase]):
super().__init__()
self._logger_iterable = logger_iterable
rank_zero_deprecation(
"`LoggerCollection` is deprecated in v1.6 and will be removed in v1.8. Directly pass a list of loggers"
" to the Trainer and access the list via the `trainer.loggers` attribute."
)
def __getitem__(self, index: int) -> LightningLoggerBase:
return list(self._logger_iterable)[index]

View File

@ -2623,6 +2623,8 @@ class Trainer(
" This behavior will change in v1.8 when LoggerCollection is removed, and"
" trainer.logger will return the first logger in trainer.loggers"
)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
return LoggerCollection(self.loggers)
@logger.setter

View File

@ -21,7 +21,7 @@ import torch
from torch import optim
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.loggers import CSVLogger, LightningLoggerBase
from pytorch_lightning.loggers import CSVLogger, LightningLoggerBase, LoggerCollection
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.plugins.training_type.ddp2 import DDP2Plugin
@ -662,6 +662,23 @@ def test_simple_profiler_iterable_durations(tmpdir, action: str, expected: list)
np.testing.assert_allclose(recorded_total_duration, expected_total_duration, rtol=0.2)
def test_v1_8_0_logger_collection(tmpdir):
logger1 = CSVLogger(tmpdir)
logger2 = CSVLogger(tmpdir)
trainer1 = Trainer(logger=logger1)
trainer2 = Trainer(logger=[logger1, logger2])
# Should have no deprecation warning
trainer1.logger
trainer1.loggers
trainer2.loggers
trainer2.logger
with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"):
LoggerCollection([logger1, logger2])
def test_v1_8_0_precision_plugin_checkpoint_hooks(tmpdir):
class PrecisionPluginSaveHook(PrecisionPlugin):
def on_save_checkpoint(self, checkpoint):

View File

@ -34,6 +34,7 @@ def test_logger_collection():
mock1 = MagicMock()
mock2 = MagicMock()
with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"):
logger = LoggerCollection([mock1, mock2])
assert logger[0] == mock1
@ -62,6 +63,7 @@ def test_logger_collection_unique_names():
logger1 = CustomLogger(name=unique_name)
logger2 = CustomLogger(name=unique_name)
with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"):
logger = LoggerCollection([logger1, logger2])
assert logger.name == unique_name
@ -69,6 +71,7 @@ def test_logger_collection_unique_names():
def test_logger_collection_names_order():
loggers = [CustomLogger(name=n) for n in ("name1", "name2", "name1", "name3")]
with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"):
logger = LoggerCollection(loggers)
assert logger.name == f"{loggers[0].name}_{loggers[1].name}_{loggers[3].name}"
@ -78,6 +81,7 @@ def test_logger_collection_unique_versions():
logger1 = CustomLogger(version=unique_version)
logger2 = CustomLogger(version=unique_version)
with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"):
logger = LoggerCollection([logger1, logger2])
assert logger.version == unique_version
@ -85,6 +89,7 @@ def test_logger_collection_unique_versions():
def test_logger_collection_versions_order():
loggers = [CustomLogger(version=v) for v in ("1", "2", "1", "3")]
with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"):
logger = LoggerCollection(loggers)
assert logger.version == f"{loggers[0].version}_{loggers[1].version}_{loggers[3].version}"

View File

@ -24,7 +24,7 @@ import torch
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.callbacks import EarlyStopping, StochasticWeightAveraging
from pytorch_lightning.loggers import CSVLogger, LoggerCollection, TensorBoardLogger
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
from pytorch_lightning.profiler import AdvancedProfiler, PassThroughProfiler, PyTorchProfiler, SimpleProfiler
from pytorch_lightning.profiler.pytorch import RegisterRecordFunction, warning_cache
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@ -450,9 +450,9 @@ def test_pytorch_profiler_nested(tmpdir):
assert events_name == expected, (events_name, torch.__version__, platform.system())
def test_pytorch_profiler_logger_collection(tmpdir):
"""Tests whether the PyTorch profiler is able to write its trace locally when the Trainer's logger is an
instance of LoggerCollection.
def test_pytorch_profiler_multiple_loggers(tmpdir):
"""Tests whether the PyTorch profiler is able to write its trace locally when the Trainer is configured with
multiple loggers.
See issue #8157.
"""
@ -465,10 +465,9 @@ def test_pytorch_profiler_logger_collection(tmpdir):
assert not look_for_trace(tmpdir)
model = BoringModel()
# Wrap the logger in a list so it becomes a LoggerCollection
logger = [TensorBoardLogger(save_dir=tmpdir), CSVLogger(tmpdir)]
trainer = Trainer(default_root_dir=tmpdir, profiler="pytorch", logger=logger, limit_train_batches=5, max_epochs=1)
assert isinstance(trainer.logger, LoggerCollection)
loggers = [TensorBoardLogger(save_dir=tmpdir), CSVLogger(tmpdir)]
trainer = Trainer(default_root_dir=tmpdir, profiler="pytorch", logger=loggers, limit_train_batches=5, max_epochs=1)
assert len(trainer.loggers) == 2
trainer.fit(model)
assert look_for_trace(tmpdir)

View File

@ -15,7 +15,7 @@ import os
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger, LoggerCollection, TensorBoardLogger
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
from tests.helpers.boring_model import BoringModel
@ -109,8 +109,8 @@ def test_logdir_custom_logger(tmpdir):
assert trainer.log_dir == expected
def test_logdir_logger_collection(tmpdir):
"""Tests that the logdir equals the default_root_dir when the logger is a LoggerCollection."""
def test_logdir_multiple_loggers(tmpdir):
"""Tests that the logdir equals the default_root_dir when trainer has multiple loggers."""
default_root_dir = tmpdir / "default_root_dir"
save_dir = tmpdir / "save_dir"
model = TestModel(default_root_dir)
@ -119,7 +119,6 @@ def test_logdir_logger_collection(tmpdir):
max_steps=2,
logger=[TensorBoardLogger(save_dir=save_dir, name="custom_logs"), CSVLogger(tmpdir)],
)
assert isinstance(trainer.logger, LoggerCollection)
assert trainer.log_dir == default_root_dir
trainer.fit(model)

View File

@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
from tests.loggers.test_base import CustomLogger
@ -50,7 +52,9 @@ def test_trainer_loggers_setters():
"""Test the behavior of setters for trainer.logger and trainer.loggers."""
logger1 = CustomLogger()
logger2 = CustomLogger()
with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"):
logger_collection = LoggerCollection([logger1, logger2])
with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"):
logger_collection_2 = LoggerCollection([logger2])
trainer = Trainer()