Fix `trainer.logger` deprecation message (#12671)
This commit is contained in:
parent
70754bea83
commit
10c7a7c84f
|
@ -36,10 +36,10 @@ from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks
|
|||
from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin, HyperparametersMixin
|
||||
from pytorch_lightning.core.optimizer import LightningOptimizer
|
||||
from pytorch_lightning.core.saving import ModelIO
|
||||
from pytorch_lightning.loggers import Logger
|
||||
from pytorch_lightning.loggers import Logger, LoggerCollection
|
||||
from pytorch_lightning.trainer.connectors.data_connector import _DataHookSelector
|
||||
from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator
|
||||
from pytorch_lightning.utilities import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_10, GradClipAlgorithmType
|
||||
from pytorch_lightning.utilities import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_10, GradClipAlgorithmType, warnings
|
||||
from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
|
||||
from pytorch_lightning.utilities.cloud_io import get_filesystem
|
||||
from pytorch_lightning.utilities.distributed import distributed_available, sync_ddp
|
||||
|
@ -245,7 +245,26 @@ class LightningModule(
|
|||
@property
|
||||
def logger(self) -> Optional[Logger]:
|
||||
"""Reference to the logger object in the Trainer."""
|
||||
return self.trainer.logger if self.trainer else None
|
||||
# this should match the implementation of `trainer.logger`
|
||||
# we don't reuse it so we can properly set the deprecation stacklevel
|
||||
if self.trainer is None:
|
||||
return
|
||||
loggers = self.trainer.loggers
|
||||
if len(loggers) == 0:
|
||||
return None
|
||||
if len(loggers) == 1:
|
||||
return loggers[0]
|
||||
else:
|
||||
if not self._running_torchscript:
|
||||
rank_zero_deprecation(
|
||||
"Using `lightning_module.logger` when multiple loggers are configured."
|
||||
" This behavior will change in v1.8 when `LoggerCollection` is removed, and"
|
||||
" `lightning_module.logger` will return the first logger available.",
|
||||
stacklevel=5,
|
||||
)
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
return LoggerCollection(loggers)
|
||||
|
||||
@property
|
||||
def loggers(self) -> List[Logger]:
|
||||
|
|
|
@ -2624,19 +2624,21 @@ class Trainer(
|
|||
|
||||
@property
|
||||
def logger(self) -> Optional[Logger]:
|
||||
if len(self.loggers) == 0:
|
||||
loggers = self.loggers
|
||||
if len(loggers) == 0:
|
||||
return None
|
||||
if len(self.loggers) == 1:
|
||||
return self.loggers[0]
|
||||
if len(loggers) == 1:
|
||||
return loggers[0]
|
||||
else:
|
||||
rank_zero_warn(
|
||||
"Using trainer.logger when Trainer is configured to use multiple loggers."
|
||||
" This behavior will change in v1.8 when LoggerCollection is removed, and"
|
||||
" trainer.logger will return the first logger in trainer.loggers"
|
||||
rank_zero_deprecation(
|
||||
"Using `trainer.logger` when multiple loggers are configured."
|
||||
" This behavior will change in v1.8 when `LoggerCollection` is removed, and"
|
||||
" `trainer.logger` will return the first logger available.",
|
||||
stacklevel=5,
|
||||
)
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
return LoggerCollection(self.loggers)
|
||||
return LoggerCollection(loggers)
|
||||
|
||||
@logger.setter
|
||||
def logger(self, logger: Optional[Logger]) -> None:
|
||||
|
|
|
@ -77,7 +77,7 @@ def test_property_logger(tmpdir):
|
|||
assert model.logger is None
|
||||
|
||||
logger = TensorBoardLogger(tmpdir)
|
||||
trainer = Mock(logger=logger)
|
||||
trainer = Mock(loggers=[logger])
|
||||
model.trainer = trainer
|
||||
assert model.logger == logger
|
||||
|
||||
|
|
|
@ -758,10 +758,11 @@ def test_v1_8_0_logger_collection(tmpdir):
|
|||
trainer1.logger
|
||||
trainer1.loggers
|
||||
trainer2.loggers
|
||||
trainer2.logger
|
||||
|
||||
with pytest.deprecated_call(match="logger` will return the first logger"):
|
||||
_ = trainer2.logger
|
||||
with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"):
|
||||
LoggerCollection([logger1, logger2])
|
||||
_ = LoggerCollection([logger1, logger2])
|
||||
|
||||
|
||||
def test_v1_8_0_precision_plugin_checkpoint_hooks(tmpdir):
|
||||
|
|
|
@ -61,7 +61,8 @@ def test_trainer_loggers_setters():
|
|||
assert trainer.loggers == [logger1]
|
||||
|
||||
trainer.logger = logger_collection
|
||||
assert trainer.logger._logger_iterable == logger_collection._logger_iterable
|
||||
with pytest.deprecated_call(match="logger` when multiple loggers are configured"):
|
||||
assert trainer.logger._logger_iterable == logger_collection._logger_iterable
|
||||
assert trainer.loggers == [logger1, logger2]
|
||||
|
||||
# LoggerCollection of size 1 should result in trainer.logger becoming the contained logger.
|
||||
|
@ -76,7 +77,8 @@ def test_trainer_loggers_setters():
|
|||
# Test setters for trainer.loggers
|
||||
trainer.loggers = [logger1, logger2]
|
||||
assert trainer.loggers == [logger1, logger2]
|
||||
assert trainer.logger._logger_iterable == logger_collection._logger_iterable
|
||||
with pytest.deprecated_call(match="logger` when multiple loggers are configured"):
|
||||
assert trainer.logger._logger_iterable == logger_collection._logger_iterable
|
||||
|
||||
trainer.loggers = [logger1]
|
||||
assert trainer.loggers == [logger1]
|
||||
|
|
Loading…
Reference in New Issue