diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py index 823b983027..8082cbac3e 100644 --- a/pytorch_lightning/loggers/base.py +++ b/pytorch_lightning/loggers/base.py @@ -3,13 +3,11 @@ import functools import operator from abc import ABC, abstractmethod from argparse import Namespace -from typing import Union, Optional, Dict, Iterable, Any, Callable, List, Sequence, Mapping, Tuple +from typing import Union, Optional, Dict, Iterable, Any, Callable, List, Sequence, Mapping, Tuple, MutableMapping import numpy as np import torch -from pytorch_lightning.utilities import rank_zero_only - class LightningLoggerBase(ABC): """ @@ -174,9 +172,9 @@ class LightningLoggerBase(ABC): def _dict_generator(input_dict, prefixes=None): prefixes = prefixes[:] if prefixes else [] - if isinstance(input_dict, dict): + if isinstance(input_dict, MutableMapping): for key, value in input_dict.items(): - if isinstance(value, (dict, Namespace)): + if isinstance(value, (MutableMapping, Namespace)): value = vars(value) if isinstance(value, Namespace) else value for d in _dict_generator(value, prefixes + [key]): yield d