Cast hparams to dict when not using omegaconf (#4770)

* init fix

* init test

* more specific dict assert

* update changelog

* Update tests/checkpointing/test_model_checkpoint.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

Co-authored-by: chaton <thomas@grid.ai>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Roger Shieh 2020-11-20 19:53:05 +08:00 committed by GitHub
parent 4803f681b0
commit 42e59c6add
3 changed files with 59 additions and 3 deletions

CHANGELOG.md

@@ -60,6 +60,29 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
## [unreleased.BugFix] - YYYY-MM-DD
### Added
### Changed
### Deprecated
### Removed
### Fixed
- Fixed checkpoint hparams dict casting when omegaconf is available ([#4770](https://github.com/PyTorchLightning/pytorch-lightning/pull/4770))
## [1.0.7] - 2020-11-17
### Added

pytorch_lightning/trainer/connectors/checkpoint_connector.py

@@ -328,10 +328,9 @@ class CheckpointConnector:
             if hasattr(model, '_hparams_name'):
                 checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_NAME] = model._hparams_name
             # dump arguments
-            if OMEGACONF_AVAILABLE:
+            if OMEGACONF_AVAILABLE and isinstance(model.hparams, Container):
                 checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = model.hparams
-                if isinstance(model.hparams, Container):
-                    checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(model.hparams)
+                checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(model.hparams)
             else:
                 checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = dict(model.hparams)
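In short, the connector now stores OmegaConf hyperparameters as-is (together with their container type) and casts anything else to a plain built-in dict, so a checkpoint no longer carries Lightning's AttributeDict just because omegaconf happens to be installed. A minimal sketch of that casting rule, using a hypothetical helper name rather than the actual CheckpointConnector code:

```python
# Sketch of the casting rule introduced above; `checkpointable_hparams` is a
# hypothetical helper, not part of pytorch-lightning.
try:
    from omegaconf import Container
    OMEGACONF_AVAILABLE = True
except ImportError:
    Container = None
    OMEGACONF_AVAILABLE = False


def checkpointable_hparams(hparams):
    """Return (value to store in the checkpoint, container type to store or None)."""
    if OMEGACONF_AVAILABLE and isinstance(hparams, Container):
        # OmegaConf config: keep it and remember its concrete type for re-wrapping on load
        return hparams, type(hparams)
    # AttributeDict / dict / any other mapping: cast to a plain built-in dict
    return dict(hparams), None
```

For example, a plain `{"lr": 0.1}` comes back as `({"lr": 0.1}, None)`, while an OmegaConf config comes back unchanged along with its concrete container type.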

tests/checkpointing/test_model_checkpoint.py

@@ -27,6 +27,8 @@ from pathlib import Path
 import cloudpickle
 import pytest
 import torch
+from omegaconf import Container, OmegaConf
+from argparse import Namespace
 
 import tests.base.develop_utils as tutils
 from pytorch_lightning import Trainer, seed_everything
@@ -911,3 +913,35 @@ def test_current_score_when_nan(tmpdir, mode):
     expected = float("inf" if mode == "min" else "-inf")
     assert model_checkpoint.best_model_score == expected
     assert model_checkpoint.current_score == expected
+
+
+@pytest.mark.parametrize("hparams_type", [dict, Container])
+def test_hparams_type(tmpdir, hparams_type):
+    class TestModel(BoringModel):
+        def __init__(self, hparams):
+            super().__init__()
+            self.save_hyperparameters(hparams)
+
+    model_checkpoint = ModelCheckpoint(
+        dirpath=tmpdir,
+        save_top_k=1,
+        monitor="foo",
+    )
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        fast_dev_run=True,
+        callbacks=[model_checkpoint],
+        logger=False,
+        weights_summary=None,
+        progress_bar_refresh_rate=0,
+    )
+    hp = {"test_hp_0": 1, "test_hp_1": 2}
+    hp = OmegaConf.create(hp) if hparams_type == Container else Namespace(**hp)
+    model = TestModel(hp)
+    trainer.fit(model)
+    ckpt = trainer.checkpoint_connector.dump_checkpoint()
+    if hparams_type == Container:
+        assert isinstance(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY], hparams_type)
+    else:
+        # make sure it's not AttributeDict
+        assert type(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY]) == hparams_type
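
The parametrized test asserts the same thing a user could verify by hand on a saved checkpoint: with non-OmegaConf hyperparameters, the stored hyperparameters load back as a plain dict rather than an AttributeDict. A rough manual spot-check, assuming an already-saved checkpoint at the placeholder path "example.ckpt":

```python
# Manual spot-check of the behavior covered by test_hparams_type;
# "example.ckpt" is a placeholder path for a checkpoint saved with
# plain (non-OmegaConf) hyperparameters.
import torch
from pytorch_lightning import LightningModule

ckpt = torch.load("example.ckpt", map_location="cpu")
hparams = ckpt[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
assert type(hparams) is dict, f"expected plain dict, got {type(hparams)}"
```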