Support **DictConfig hparam serialization (#2519)
* change to OmegaConf API Co-authored-by: Omry Yadan <omry@fb.com> * Swapped Container for OmegaConf sentinel; Limited ds copying * Add Namespace check. * Container removed. Pass local tests. Co-authored-by: Omry Yadan <omry@fb.com>
This commit is contained in:
parent
a46130cdc1
commit
f9d88f8088
|
@ -16,11 +16,9 @@ from pytorch_lightning.utilities.cloud_io import gfile, cloud_open
|
|||
PRIMITIVE_TYPES = (bool, int, float, str)
|
||||
ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace)
|
||||
try:
|
||||
from omegaconf import Container
|
||||
from omegaconf import OmegaConf
|
||||
except ImportError:
|
||||
OMEGACONF_AVAILABLE = False
|
||||
else:
|
||||
OMEGACONF_AVAILABLE = True
|
||||
OmegaConf = None
|
||||
|
||||
# the older shall be on the top
|
||||
CHECKPOINT_PAST_HPARAMS_KEYS = (
|
||||
|
@ -330,23 +328,27 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
|
|||
if not gfile.isdir(os.path.dirname(config_yaml)):
|
||||
raise RuntimeError(f"Missing folder: {os.path.dirname(config_yaml)}.")
|
||||
|
||||
if OMEGACONF_AVAILABLE and isinstance(hparams, Container):
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
OmegaConf.save(hparams, config_yaml, resolve=True)
|
||||
return
|
||||
|
||||
# saving the standard way
|
||||
# convert Namespace or AD to dict
|
||||
if isinstance(hparams, Namespace):
|
||||
hparams = vars(hparams)
|
||||
elif isinstance(hparams, AttributeDict):
|
||||
hparams = dict(hparams)
|
||||
|
||||
# saving with OmegaConf objects
|
||||
if OmegaConf is not None:
|
||||
if OmegaConf.is_config(hparams):
|
||||
OmegaConf.save(hparams, config_yaml, resolve=True)
|
||||
return
|
||||
for v in hparams.values():
|
||||
if OmegaConf.is_config(v):
|
||||
OmegaConf.save(OmegaConf.create(hparams), config_yaml, resolve=True)
|
||||
return
|
||||
|
||||
# saving the standard way
|
||||
assert isinstance(hparams, dict)
|
||||
|
||||
with cloud_open(config_yaml, "w", newline="") as fp:
|
||||
with open(config_yaml, 'w', newline='') as fp:
|
||||
yaml.dump(hparams, fp)
|
||||
|
||||
|
||||
def convert(val: str) -> Union[int, float, bool, str]:
|
||||
try:
|
||||
return ast.literal_eval(val)
|
||||
|
|
Loading…
Reference in New Issue