diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index a4f343fe68..70017c8e79 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -36,10 +36,10 @@ from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin, HyperparametersMixin from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.core.saving import ModelIO -from pytorch_lightning.loggers import Logger +from pytorch_lightning.loggers import Logger, LoggerCollection from pytorch_lightning.trainer.connectors.data_connector import _DataHookSelector from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator -from pytorch_lightning.utilities import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_10, GradClipAlgorithmType +from pytorch_lightning.utilities import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_10, GradClipAlgorithmType, warnings from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.distributed import distributed_available, sync_ddp @@ -245,7 +245,26 @@ class LightningModule( @property def logger(self) -> Optional[Logger]: """Reference to the logger object in the Trainer.""" - return self.trainer.logger if self.trainer else None + # this should match the implementation of `trainer.logger` + # we don't reuse it so we can properly set the deprecation stacklevel + if self.trainer is None: + return + loggers = self.trainer.loggers + if len(loggers) == 0: + return None + if len(loggers) == 1: + return loggers[0] + else: + if not self._running_torchscript: + rank_zero_deprecation( + "Using `lightning_module.logger` when multiple loggers are configured." + " This behavior will change in v1.8 when `LoggerCollection` is removed, and" + " `lightning_module.logger` will return the first logger available.", + stacklevel=5, + ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + return LoggerCollection(loggers) @property def loggers(self) -> List[Logger]: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 70e3e3bfc8..384e1f43b3 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -2624,19 +2624,21 @@ class Trainer( @property def logger(self) -> Optional[Logger]: - if len(self.loggers) == 0: + loggers = self.loggers + if len(loggers) == 0: return None - if len(self.loggers) == 1: - return self.loggers[0] + if len(loggers) == 1: + return loggers[0] else: - rank_zero_warn( - "Using trainer.logger when Trainer is configured to use multiple loggers." - " This behavior will change in v1.8 when LoggerCollection is removed, and" - " trainer.logger will return the first logger in trainer.loggers" + rank_zero_deprecation( + "Using `trainer.logger` when multiple loggers are configured." + " This behavior will change in v1.8 when `LoggerCollection` is removed, and" + " `trainer.logger` will return the first logger available.", + stacklevel=5, ) with warnings.catch_warnings(): warnings.simplefilter("ignore") - return LoggerCollection(self.loggers) + return LoggerCollection(loggers) @logger.setter def logger(self, logger: Optional[Logger]) -> None: diff --git a/tests/core/test_lightning_module.py b/tests/core/test_lightning_module.py index c7fee3b0d5..07fcf8dadc 100644 --- a/tests/core/test_lightning_module.py +++ b/tests/core/test_lightning_module.py @@ -77,7 +77,7 @@ def test_property_logger(tmpdir): assert model.logger is None logger = TensorBoardLogger(tmpdir) - trainer = Mock(logger=logger) + trainer = Mock(loggers=[logger]) model.trainer = trainer assert model.logger == logger diff --git a/tests/deprecated_api/test_remove_1-8.py b/tests/deprecated_api/test_remove_1-8.py index 5a46411c83..f168cfcd12 100644 --- a/tests/deprecated_api/test_remove_1-8.py +++ b/tests/deprecated_api/test_remove_1-8.py @@ -758,10 +758,11 @@ def test_v1_8_0_logger_collection(tmpdir): trainer1.logger trainer1.loggers trainer2.loggers - trainer2.logger + with pytest.deprecated_call(match="logger` will return the first logger"): + _ = trainer2.logger with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"): - LoggerCollection([logger1, logger2]) + _ = LoggerCollection([logger1, logger2]) def test_v1_8_0_precision_plugin_checkpoint_hooks(tmpdir): diff --git a/tests/trainer/properties/test_loggers.py b/tests/trainer/properties/test_loggers.py index a7efe71ddb..ac3a01cba3 100644 --- a/tests/trainer/properties/test_loggers.py +++ b/tests/trainer/properties/test_loggers.py @@ -61,7 +61,8 @@ def test_trainer_loggers_setters(): assert trainer.loggers == [logger1] trainer.logger = logger_collection - assert trainer.logger._logger_iterable == logger_collection._logger_iterable + with pytest.deprecated_call(match="logger` when multiple loggers are configured"): + assert trainer.logger._logger_iterable == logger_collection._logger_iterable assert trainer.loggers == [logger1, logger2] # LoggerCollection of size 1 should result in trainer.logger becoming the contained logger. @@ -76,7 +77,8 @@ def test_trainer_loggers_setters(): # Test setters for trainer.loggers trainer.loggers = [logger1, logger2] assert trainer.loggers == [logger1, logger2] - assert trainer.logger._logger_iterable == logger_collection._logger_iterable + with pytest.deprecated_call(match="logger` when multiple loggers are configured"): + assert trainer.logger._logger_iterable == logger_collection._logger_iterable trainer.loggers = [logger1] assert trainer.loggers == [logger1]