Support **DictConfig hparam serialization (#2519)

* change to OmegaConf API

Co-authored-by: Omry Yadan <omry@fb.com>

* Swapped Container for OmegaConf sentinel; Limited ds copying

* Add Namespace check.

* Container removed. Pass local tests.

Co-authored-by: Omry Yadan <omry@fb.com>
This commit is contained in:
Rosario Scalise 2020-08-12 05:10:17 -07:00 committed by GitHub
parent a46130cdc1
commit f9d88f8088
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 16 additions and 14 deletions

View File

@ -16,11 +16,9 @@ from pytorch_lightning.utilities.cloud_io import gfile, cloud_open
PRIMITIVE_TYPES = (bool, int, float, str)
ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace)
try:
from omegaconf import Container
from omegaconf import OmegaConf
except ImportError:
OMEGACONF_AVAILABLE = False
else:
OMEGACONF_AVAILABLE = True
OmegaConf = None
# the older shall be on the top
CHECKPOINT_PAST_HPARAMS_KEYS = (
@ -330,23 +328,27 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
if not gfile.isdir(os.path.dirname(config_yaml)):
raise RuntimeError(f"Missing folder: {os.path.dirname(config_yaml)}.")
if OMEGACONF_AVAILABLE and isinstance(hparams, Container):
from omegaconf import OmegaConf
OmegaConf.save(hparams, config_yaml, resolve=True)
return
# saving the standard way
# convert Namespace or AD to dict
if isinstance(hparams, Namespace):
hparams = vars(hparams)
elif isinstance(hparams, AttributeDict):
hparams = dict(hparams)
# saving with OmegaConf objects
if OmegaConf is not None:
if OmegaConf.is_config(hparams):
OmegaConf.save(hparams, config_yaml, resolve=True)
return
for v in hparams.values():
if OmegaConf.is_config(v):
OmegaConf.save(OmegaConf.create(hparams), config_yaml, resolve=True)
return
# saving the standard way
assert isinstance(hparams, dict)
with cloud_open(config_yaml, "w", newline="") as fp:
with open(config_yaml, 'w', newline='') as fp:
yaml.dump(hparams, fp)
def convert(val: str) -> Union[int, float, bool, str]:
try:
return ast.literal_eval(val)