Support **DictConfig hparam serialization (#2519)
* change to OmegaConf API Co-authored-by: Omry Yadan <omry@fb.com> * Swapped Container for OmegaConf sentinel; Limited ds copying * Add Namespace check. * Container removed. Pass local tests. Co-authored-by: Omry Yadan <omry@fb.com>
This commit is contained in:
parent
a46130cdc1
commit
f9d88f8088
|
@ -16,11 +16,9 @@ from pytorch_lightning.utilities.cloud_io import gfile, cloud_open
|
||||||
PRIMITIVE_TYPES = (bool, int, float, str)
|
PRIMITIVE_TYPES = (bool, int, float, str)
|
||||||
ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace)
|
ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace)
|
||||||
try:
|
try:
|
||||||
from omegaconf import Container
|
from omegaconf import OmegaConf
|
||||||
except ImportError:
|
except ImportError:
|
||||||
OMEGACONF_AVAILABLE = False
|
OmegaConf = None
|
||||||
else:
|
|
||||||
OMEGACONF_AVAILABLE = True
|
|
||||||
|
|
||||||
# the older shall be on the top
|
# the older shall be on the top
|
||||||
CHECKPOINT_PAST_HPARAMS_KEYS = (
|
CHECKPOINT_PAST_HPARAMS_KEYS = (
|
||||||
|
@ -330,23 +328,27 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
|
||||||
if not gfile.isdir(os.path.dirname(config_yaml)):
|
if not gfile.isdir(os.path.dirname(config_yaml)):
|
||||||
raise RuntimeError(f"Missing folder: {os.path.dirname(config_yaml)}.")
|
raise RuntimeError(f"Missing folder: {os.path.dirname(config_yaml)}.")
|
||||||
|
|
||||||
if OMEGACONF_AVAILABLE and isinstance(hparams, Container):
|
# convert Namespace or AD to dict
|
||||||
from omegaconf import OmegaConf
|
|
||||||
|
|
||||||
OmegaConf.save(hparams, config_yaml, resolve=True)
|
|
||||||
return
|
|
||||||
|
|
||||||
# saving the standard way
|
|
||||||
if isinstance(hparams, Namespace):
|
if isinstance(hparams, Namespace):
|
||||||
hparams = vars(hparams)
|
hparams = vars(hparams)
|
||||||
elif isinstance(hparams, AttributeDict):
|
elif isinstance(hparams, AttributeDict):
|
||||||
hparams = dict(hparams)
|
hparams = dict(hparams)
|
||||||
|
|
||||||
|
# saving with OmegaConf objects
|
||||||
|
if OmegaConf is not None:
|
||||||
|
if OmegaConf.is_config(hparams):
|
||||||
|
OmegaConf.save(hparams, config_yaml, resolve=True)
|
||||||
|
return
|
||||||
|
for v in hparams.values():
|
||||||
|
if OmegaConf.is_config(v):
|
||||||
|
OmegaConf.save(OmegaConf.create(hparams), config_yaml, resolve=True)
|
||||||
|
return
|
||||||
|
|
||||||
|
# saving the standard way
|
||||||
assert isinstance(hparams, dict)
|
assert isinstance(hparams, dict)
|
||||||
|
with open(config_yaml, 'w', newline='') as fp:
|
||||||
with cloud_open(config_yaml, "w", newline="") as fp:
|
|
||||||
yaml.dump(hparams, fp)
|
yaml.dump(hparams, fp)
|
||||||
|
|
||||||
|
|
||||||
def convert(val: str) -> Union[int, float, bool, str]:
|
def convert(val: str) -> Union[int, float, bool, str]:
|
||||||
try:
|
try:
|
||||||
return ast.literal_eval(val)
|
return ast.literal_eval(val)
|
||||||
|
|
Loading…
Reference in New Issue