diff --git a/CHANGELOG.md b/CHANGELOG.md index d9ea2f29ce..bb326232fc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -112,6 +112,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Changed the name of the temporary checkpoint that the `DDPSpawnPlugin` and related plugins save ([#10934](https://github.com/PyTorchLightning/pytorch-lightning/pull/10934)) +- `LoggerCollection` returns only unique logger names and versions ([#10976](https://github.com/PyTorchLightning/pytorch-lightning/pull/10976)) + + - Redesigned process creation for spawn-based plugins (`DDPSpawnPlugin`, `TPUSpawnPlugin`, etc.) ([#10896](https://github.com/PyTorchLightning/pytorch-lightning/pull/10896)) * All spawn-based plugins now spawn processes immediately upon calling `Trainer.{fit,validate,test,predict}` * The hooks/callbacks `prepare_data`, `setup`, `configure_sharded_model` and `teardown` now run under initialized process group for spawn-based plugins just like their non-spawn counterparts diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py index 0698a40976..0cbbc134ca 100644 --- a/pytorch_lightning/loggers/base.py +++ b/pytorch_lightning/loggers/base.py @@ -452,13 +452,15 @@ class LoggerCollection(LightningLoggerBase): @property def name(self) -> str: - """Returns the experiment names for all the loggers in the logger collection joined by an underscore.""" - return "_".join(str(logger.name) for logger in self._logger_iterable) + """Returns the unique experiment names for all the loggers in the logger collection joined by an + underscore.""" + return "_".join(dict.fromkeys(str(logger.name) for logger in self._logger_iterable)) @property def version(self) -> str: - """Returns the experiment versions for all the loggers in the logger collection joined by an underscore.""" - return "_".join(str(logger.version) for logger in self._logger_iterable) + """Returns the unique experiment versions for all the loggers in the logger collection joined by an + underscore.""" + return "_".join(dict.fromkeys(str(logger.version) for logger in self._logger_iterable)) class DummyExperiment: diff --git a/tests/loggers/test_base.py b/tests/loggers/test_base.py index 224271709f..cb5721d6a9 100644 --- a/tests/loggers/test_base.py +++ b/tests/loggers/test_base.py @@ -56,9 +56,44 @@ def test_logger_collection(): mock2.finalize.assert_called_once() +def test_logger_collection_unique_names(): + unique_name = "name1" + logger1 = CustomLogger(name=unique_name) + logger2 = CustomLogger(name=unique_name) + + logger = LoggerCollection([logger1, logger2]) + + assert logger.name == unique_name + + +def test_logger_collection_names_order(): + loggers = [CustomLogger(name=n) for n in ("name1", "name2", "name1", "name3")] + logger = LoggerCollection(loggers) + assert logger.name == f"{loggers[0].name}_{loggers[1].name}_{loggers[3].name}" + + +def test_logger_collection_unique_versions(): + unique_version = "1" + logger1 = CustomLogger(version=unique_version) + logger2 = CustomLogger(version=unique_version) + + logger = LoggerCollection([logger1, logger2]) + + assert logger.version == unique_version + + +def test_logger_collection_versions_order(): + loggers = [CustomLogger(version=v) for v in ("1", "2", "1", "3")] + logger = LoggerCollection(loggers) + assert logger.version == f"{loggers[0].version}_{loggers[1].version}_{loggers[3].version}" + + class CustomLogger(LightningLoggerBase): - def __init__(self): + def __init__(self, experiment: str = "test", name: str = "name", version: str = "1"): super().__init__() + self._experiment = experiment + self._name = name + self._version = version self.hparams_logged = None self.metrics_logged = {} self.finalized = False @@ -66,7 +101,7 @@ class CustomLogger(LightningLoggerBase): @property def experiment(self): - return "test" + return self._experiment @rank_zero_only def log_hyperparams(self, params): @@ -88,11 +123,11 @@ class CustomLogger(LightningLoggerBase): @property def name(self): - return "name" + return self._name @property def version(self): - return "1" + return self._version def after_save_checkpoint(self, checkpoint_callback): self.after_save_checkpoint_called = True