Fix `trainer.logger` deprecation message (#12671)

This commit is contained in:
Carlos Mocholí 2022-04-27 16:11:34 +02:00 committed by GitHub
parent 70754bea83
commit 10c7a7c84f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 40 additions and 16 deletions

View File

@ -36,10 +36,10 @@ from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks
from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin, HyperparametersMixin
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.core.saving import ModelIO
from pytorch_lightning.loggers import Logger
from pytorch_lightning.loggers import Logger, LoggerCollection
from pytorch_lightning.trainer.connectors.data_connector import _DataHookSelector
from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator
from pytorch_lightning.utilities import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_10, GradClipAlgorithmType
from pytorch_lightning.utilities import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_10, GradClipAlgorithmType, warnings
from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.distributed import distributed_available, sync_ddp
@ -245,7 +245,26 @@ class LightningModule(
@property
def logger(self) -> Optional[Logger]:
"""Reference to the logger object in the Trainer."""
return self.trainer.logger if self.trainer else None
# this should match the implementation of `trainer.logger`
# we don't reuse it so we can properly set the deprecation stacklevel
if self.trainer is None:
return
loggers = self.trainer.loggers
if len(loggers) == 0:
return None
if len(loggers) == 1:
return loggers[0]
else:
if not self._running_torchscript:
rank_zero_deprecation(
"Using `lightning_module.logger` when multiple loggers are configured."
" This behavior will change in v1.8 when `LoggerCollection` is removed, and"
" `lightning_module.logger` will return the first logger available.",
stacklevel=5,
)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
return LoggerCollection(loggers)
@property
def loggers(self) -> List[Logger]:

View File

@ -2624,19 +2624,21 @@ class Trainer(
@property
def logger(self) -> Optional[Logger]:
if len(self.loggers) == 0:
loggers = self.loggers
if len(loggers) == 0:
return None
if len(self.loggers) == 1:
return self.loggers[0]
if len(loggers) == 1:
return loggers[0]
else:
rank_zero_warn(
"Using trainer.logger when Trainer is configured to use multiple loggers."
" This behavior will change in v1.8 when LoggerCollection is removed, and"
" trainer.logger will return the first logger in trainer.loggers"
rank_zero_deprecation(
"Using `trainer.logger` when multiple loggers are configured."
" This behavior will change in v1.8 when `LoggerCollection` is removed, and"
" `trainer.logger` will return the first logger available.",
stacklevel=5,
)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
return LoggerCollection(self.loggers)
return LoggerCollection(loggers)
@logger.setter
def logger(self, logger: Optional[Logger]) -> None:

View File

@ -77,7 +77,7 @@ def test_property_logger(tmpdir):
assert model.logger is None
logger = TensorBoardLogger(tmpdir)
trainer = Mock(logger=logger)
trainer = Mock(loggers=[logger])
model.trainer = trainer
assert model.logger == logger

View File

@ -758,10 +758,11 @@ def test_v1_8_0_logger_collection(tmpdir):
trainer1.logger
trainer1.loggers
trainer2.loggers
trainer2.logger
with pytest.deprecated_call(match="logger` will return the first logger"):
_ = trainer2.logger
with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"):
LoggerCollection([logger1, logger2])
_ = LoggerCollection([logger1, logger2])
def test_v1_8_0_precision_plugin_checkpoint_hooks(tmpdir):

View File

@ -61,7 +61,8 @@ def test_trainer_loggers_setters():
assert trainer.loggers == [logger1]
trainer.logger = logger_collection
assert trainer.logger._logger_iterable == logger_collection._logger_iterable
with pytest.deprecated_call(match="logger` when multiple loggers are configured"):
assert trainer.logger._logger_iterable == logger_collection._logger_iterable
assert trainer.loggers == [logger1, logger2]
# LoggerCollection of size 1 should result in trainer.logger becoming the contained logger.
@ -76,7 +77,8 @@ def test_trainer_loggers_setters():
# Test setters for trainer.loggers
trainer.loggers = [logger1, logger2]
assert trainer.loggers == [logger1, logger2]
assert trainer.logger._logger_iterable == logger_collection._logger_iterable
with pytest.deprecated_call(match="logger` when multiple loggers are configured"):
assert trainer.logger._logger_iterable == logger_collection._logger_iterable
trainer.loggers = [logger1]
assert trainer.loggers == [logger1]