diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py index fc40db4e69..a27998366b 100644 --- a/pytorch_lightning/loggers/base.py +++ b/pytorch_lightning/loggers/base.py @@ -409,6 +409,11 @@ class DummyExperiment(object): def __getattr__(self, _): return self.nop + def __getitem__(self, idx): + # enables self.logger[0].experiment.add_image + # and self.logger.experiment[0].add_image(...) + return self + class DummyLogger(LightningLoggerBase): """ Dummy logger for internal use. Is usefull if we want to disable users @@ -437,6 +442,9 @@ class DummyLogger(LightningLoggerBase): def version(self): pass + def __getitem__(self, idx): + return self + def merge_dicts( dicts: Sequence[Mapping], diff --git a/tests/loggers/test_base.py b/tests/loggers/test_base.py index 4b270927a4..3d89c8cd85 100644 --- a/tests/loggers/test_base.py +++ b/tests/loggers/test_base.py @@ -20,6 +20,7 @@ import numpy as np from pytorch_lightning import Trainer from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection +from pytorch_lightning.loggers.base import DummyLogger, DummyExperiment from pytorch_lightning.utilities import rank_zero_only from tests.base import EvalModelTemplate @@ -215,6 +216,16 @@ def test_with_accumulate_grad_batches(): assert logger.history == {0: {'loss': 0.5623850983416314}, 1: {'loss': 0.4778883735637184}} +def test_dummyexperiment_support_indexing(): + experiment = DummyExperiment() + assert experiment[0] == experiment + + +def test_dummylogger_support_indexing(): + logger = DummyLogger() + assert logger[0] == logger + + def test_np_sanitization(): class CustomParamsLogger(CustomLogger): def __init__(self):