[Bug Fix] Allow logger to support indexing (#4595)
* [Bug Fix] Allow logger to support indexing This should fix #4540 * Adding test for indexes for DummyLogger * Apply suggestions from code review Co-authored-by: chaton <thomas@grid.ai> * pep8 * added test for dummyexperiment Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
This commit is contained in:
parent
16fa4ed1e5
commit
849737e7ca
|
@ -409,6 +409,11 @@ class DummyExperiment(object):
|
|||
def __getattr__(self, _):
|
||||
return self.nop
|
||||
|
||||
def __getitem__(self, idx):
|
||||
# enables self.logger[0].experiment.add_image
|
||||
# and self.logger.experiment[0].add_image(...)
|
||||
return self
|
||||
|
||||
|
||||
class DummyLogger(LightningLoggerBase):
|
||||
""" Dummy logger for internal use. Is usefull if we want to disable users
|
||||
|
@ -437,6 +442,9 @@ class DummyLogger(LightningLoggerBase):
|
|||
def version(self):
|
||||
pass
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self
|
||||
|
||||
|
||||
def merge_dicts(
|
||||
dicts: Sequence[Mapping],
|
||||
|
|
|
@ -20,6 +20,7 @@ import numpy as np
|
|||
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection
|
||||
from pytorch_lightning.loggers.base import DummyLogger, DummyExperiment
|
||||
from pytorch_lightning.utilities import rank_zero_only
|
||||
from tests.base import EvalModelTemplate
|
||||
|
||||
|
@ -215,6 +216,16 @@ def test_with_accumulate_grad_batches():
|
|||
assert logger.history == {0: {'loss': 0.5623850983416314}, 1: {'loss': 0.4778883735637184}}
|
||||
|
||||
|
||||
def test_dummyexperiment_support_indexing():
|
||||
experiment = DummyExperiment()
|
||||
assert experiment[0] == experiment
|
||||
|
||||
|
||||
def test_dummylogger_support_indexing():
|
||||
logger = DummyLogger()
|
||||
assert logger[0] == logger
|
||||
|
||||
|
||||
def test_np_sanitization():
|
||||
class CustomParamsLogger(CustomLogger):
|
||||
def __init__(self):
|
||||
|
|
Loading…
Reference in New Issue