Cast hparams to dict when not using omegaconf (#4770)
* init fix * init test * more specific dict assert * update changelog * Update tests/checkpointing/test_model_checkpoint.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
This commit is contained in:
parent
4803f681b0
commit
42e59c6add
23
CHANGELOG.md
23
CHANGELOG.md
|
@ -60,6 +60,29 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
|
||||
|
||||
## [unreleased.BugFix] - YYYY-MM-DD
|
||||
|
||||
### Added
|
||||
|
||||
|
||||
|
||||
### Changed
|
||||
|
||||
|
||||
|
||||
### Deprecated
|
||||
|
||||
|
||||
|
||||
### Removed
|
||||
|
||||
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed checkpoint hparams dict casting when omegaconf is available ([#4770](https://github.com/PyTorchLightning/pytorch-lightning/pull/4770))
|
||||
|
||||
|
||||
## [1.0.7] - 2020-11-17
|
||||
|
||||
### Added
|
||||
|
|
|
@ -328,10 +328,9 @@ class CheckpointConnector:
|
|||
if hasattr(model, '_hparams_name'):
|
||||
checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_NAME] = model._hparams_name
|
||||
# dump arguments
|
||||
if OMEGACONF_AVAILABLE:
|
||||
if OMEGACONF_AVAILABLE and isinstance(model.hparams, Container):
|
||||
checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = model.hparams
|
||||
if isinstance(model.hparams, Container):
|
||||
checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(model.hparams)
|
||||
checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(model.hparams)
|
||||
else:
|
||||
checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = dict(model.hparams)
|
||||
|
||||
|
|
|
@ -27,6 +27,8 @@ from pathlib import Path
|
|||
import cloudpickle
|
||||
import pytest
|
||||
import torch
|
||||
from omegaconf import Container, OmegaConf
|
||||
from argparse import Namespace
|
||||
|
||||
import tests.base.develop_utils as tutils
|
||||
from pytorch_lightning import Trainer, seed_everything
|
||||
|
@ -911,3 +913,35 @@ def test_current_score_when_nan(tmpdir, mode):
|
|||
expected = float("inf" if mode == "min" else "-inf")
|
||||
assert model_checkpoint.best_model_score == expected
|
||||
assert model_checkpoint.current_score == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize("hparams_type", [dict, Container])
|
||||
def test_hparams_type(tmpdir, hparams_type):
|
||||
class TestModel(BoringModel):
|
||||
def __init__(self, hparams):
|
||||
super().__init__()
|
||||
self.save_hyperparameters(hparams)
|
||||
|
||||
model_checkpoint = ModelCheckpoint(
|
||||
dirpath=tmpdir,
|
||||
save_top_k=1,
|
||||
monitor="foo",
|
||||
)
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
fast_dev_run=True,
|
||||
callbacks=[model_checkpoint],
|
||||
logger=False,
|
||||
weights_summary=None,
|
||||
progress_bar_refresh_rate=0,
|
||||
)
|
||||
hp = {"test_hp_0": 1, "test_hp_1": 2}
|
||||
hp = OmegaConf.create(hp) if hparams_type == Container else Namespace(**hp)
|
||||
model = TestModel(hp)
|
||||
trainer.fit(model)
|
||||
ckpt = trainer.checkpoint_connector.dump_checkpoint()
|
||||
if hparams_type == Container:
|
||||
assert isinstance(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY], hparams_type)
|
||||
else:
|
||||
# make sure it's not AttributeDict
|
||||
assert type(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY]) == hparams_type
|
||||
|
|
Loading…
Reference in New Issue