Return only unique names/versions for LoggerCollection (#10976)

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
This commit is contained in:
twsl 2021-12-23 01:35:38 +01:00 committed by GitHub
parent 576a5d62a0
commit 0b9034baef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 48 additions and 8 deletions

View File

@ -112,6 +112,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed the name of the temporary checkpoint that the `DDPSpawnPlugin` and related plugins save ([#10934](https://github.com/PyTorchLightning/pytorch-lightning/pull/10934))
- `LoggerCollection` returns only unique logger names and versions ([#10976](https://github.com/PyTorchLightning/pytorch-lightning/pull/10976))
- Redesigned process creation for spawn-based plugins (`DDPSpawnPlugin`, `TPUSpawnPlugin`, etc.) ([#10896](https://github.com/PyTorchLightning/pytorch-lightning/pull/10896))
* All spawn-based plugins now spawn processes immediately upon calling `Trainer.{fit,validate,test,predict}`
* The hooks/callbacks `prepare_data`, `setup`, `configure_sharded_model` and `teardown` now run under initialized process group for spawn-based plugins just like their non-spawn counterparts

View File

@ -452,13 +452,15 @@ class LoggerCollection(LightningLoggerBase):
@property
def name(self) -> str:
"""Returns the experiment names for all the loggers in the logger collection joined by an underscore."""
return "_".join(str(logger.name) for logger in self._logger_iterable)
"""Returns the unique experiment names for all the loggers in the logger collection joined by an
underscore."""
return "_".join(dict.fromkeys(str(logger.name) for logger in self._logger_iterable))
@property
def version(self) -> str:
"""Returns the experiment versions for all the loggers in the logger collection joined by an underscore."""
return "_".join(str(logger.version) for logger in self._logger_iterable)
"""Returns the unique experiment versions for all the loggers in the logger collection joined by an
underscore."""
return "_".join(dict.fromkeys(str(logger.version) for logger in self._logger_iterable))
class DummyExperiment:

View File

@ -56,9 +56,44 @@ def test_logger_collection():
mock2.finalize.assert_called_once()
def test_logger_collection_unique_names():
unique_name = "name1"
logger1 = CustomLogger(name=unique_name)
logger2 = CustomLogger(name=unique_name)
logger = LoggerCollection([logger1, logger2])
assert logger.name == unique_name
def test_logger_collection_names_order():
loggers = [CustomLogger(name=n) for n in ("name1", "name2", "name1", "name3")]
logger = LoggerCollection(loggers)
assert logger.name == f"{loggers[0].name}_{loggers[1].name}_{loggers[3].name}"
def test_logger_collection_unique_versions():
unique_version = "1"
logger1 = CustomLogger(version=unique_version)
logger2 = CustomLogger(version=unique_version)
logger = LoggerCollection([logger1, logger2])
assert logger.version == unique_version
def test_logger_collection_versions_order():
loggers = [CustomLogger(version=v) for v in ("1", "2", "1", "3")]
logger = LoggerCollection(loggers)
assert logger.version == f"{loggers[0].version}_{loggers[1].version}_{loggers[3].version}"
class CustomLogger(LightningLoggerBase):
def __init__(self):
def __init__(self, experiment: str = "test", name: str = "name", version: str = "1"):
super().__init__()
self._experiment = experiment
self._name = name
self._version = version
self.hparams_logged = None
self.metrics_logged = {}
self.finalized = False
@ -66,7 +101,7 @@ class CustomLogger(LightningLoggerBase):
@property
def experiment(self):
return "test"
return self._experiment
@rank_zero_only
def log_hyperparams(self, params):
@ -88,11 +123,11 @@ class CustomLogger(LightningLoggerBase):
@property
def name(self):
return "name"
return self._name
@property
def version(self):
return "1"
return self._version
def after_save_checkpoint(self, checkpoint_callback):
self.after_save_checkpoint_called = True