Return only unique names/versions for LoggerCollection (#10976)
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
This commit is contained in:
parent
576a5d62a0
commit
0b9034baef
|
@ -112,6 +112,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Changed the name of the temporary checkpoint that the `DDPSpawnPlugin` and related plugins save ([#10934](https://github.com/PyTorchLightning/pytorch-lightning/pull/10934))
|
||||
|
||||
|
||||
- `LoggerCollection` returns only unique logger names and versions ([#10976](https://github.com/PyTorchLightning/pytorch-lightning/pull/10976))
|
||||
|
||||
|
||||
- Redesigned process creation for spawn-based plugins (`DDPSpawnPlugin`, `TPUSpawnPlugin`, etc.) ([#10896](https://github.com/PyTorchLightning/pytorch-lightning/pull/10896))
|
||||
* All spawn-based plugins now spawn processes immediately upon calling `Trainer.{fit,validate,test,predict}`
|
||||
* The hooks/callbacks `prepare_data`, `setup`, `configure_sharded_model` and `teardown` now run under initialized process group for spawn-based plugins just like their non-spawn counterparts
|
||||
|
|
|
@ -452,13 +452,15 @@ class LoggerCollection(LightningLoggerBase):
|
|||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Returns the experiment names for all the loggers in the logger collection joined by an underscore."""
|
||||
return "_".join(str(logger.name) for logger in self._logger_iterable)
|
||||
"""Returns the unique experiment names for all the loggers in the logger collection joined by an
|
||||
underscore."""
|
||||
return "_".join(dict.fromkeys(str(logger.name) for logger in self._logger_iterable))
|
||||
|
||||
@property
|
||||
def version(self) -> str:
|
||||
"""Returns the experiment versions for all the loggers in the logger collection joined by an underscore."""
|
||||
return "_".join(str(logger.version) for logger in self._logger_iterable)
|
||||
"""Returns the unique experiment versions for all the loggers in the logger collection joined by an
|
||||
underscore."""
|
||||
return "_".join(dict.fromkeys(str(logger.version) for logger in self._logger_iterable))
|
||||
|
||||
|
||||
class DummyExperiment:
|
||||
|
|
|
@ -56,9 +56,44 @@ def test_logger_collection():
|
|||
mock2.finalize.assert_called_once()
|
||||
|
||||
|
||||
def test_logger_collection_unique_names():
|
||||
unique_name = "name1"
|
||||
logger1 = CustomLogger(name=unique_name)
|
||||
logger2 = CustomLogger(name=unique_name)
|
||||
|
||||
logger = LoggerCollection([logger1, logger2])
|
||||
|
||||
assert logger.name == unique_name
|
||||
|
||||
|
||||
def test_logger_collection_names_order():
|
||||
loggers = [CustomLogger(name=n) for n in ("name1", "name2", "name1", "name3")]
|
||||
logger = LoggerCollection(loggers)
|
||||
assert logger.name == f"{loggers[0].name}_{loggers[1].name}_{loggers[3].name}"
|
||||
|
||||
|
||||
def test_logger_collection_unique_versions():
|
||||
unique_version = "1"
|
||||
logger1 = CustomLogger(version=unique_version)
|
||||
logger2 = CustomLogger(version=unique_version)
|
||||
|
||||
logger = LoggerCollection([logger1, logger2])
|
||||
|
||||
assert logger.version == unique_version
|
||||
|
||||
|
||||
def test_logger_collection_versions_order():
|
||||
loggers = [CustomLogger(version=v) for v in ("1", "2", "1", "3")]
|
||||
logger = LoggerCollection(loggers)
|
||||
assert logger.version == f"{loggers[0].version}_{loggers[1].version}_{loggers[3].version}"
|
||||
|
||||
|
||||
class CustomLogger(LightningLoggerBase):
|
||||
def __init__(self):
|
||||
def __init__(self, experiment: str = "test", name: str = "name", version: str = "1"):
|
||||
super().__init__()
|
||||
self._experiment = experiment
|
||||
self._name = name
|
||||
self._version = version
|
||||
self.hparams_logged = None
|
||||
self.metrics_logged = {}
|
||||
self.finalized = False
|
||||
|
@ -66,7 +101,7 @@ class CustomLogger(LightningLoggerBase):
|
|||
|
||||
@property
|
||||
def experiment(self):
|
||||
return "test"
|
||||
return self._experiment
|
||||
|
||||
@rank_zero_only
|
||||
def log_hyperparams(self, params):
|
||||
|
@ -88,11 +123,11 @@ class CustomLogger(LightningLoggerBase):
|
|||
|
||||
@property
|
||||
def name(self):
|
||||
return "name"
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def version(self):
|
||||
return "1"
|
||||
return self._version
|
||||
|
||||
def after_save_checkpoint(self, checkpoint_callback):
|
||||
self.after_save_checkpoint_called = True
|
||||
|
|
Loading…
Reference in New Issue