Cast hparams to dict when not using omegaconf (#4770)
* init fix * init test * more specific dict assert * update changelog * Update tests/checkpointing/test_model_checkpoint.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
This commit is contained in:
parent
4803f681b0
commit
42e59c6add
23
CHANGELOG.md
23
CHANGELOG.md
|
@ -60,6 +60,29 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## [unreleased.BugFix] - YYYY-MM-DD
|
||||||
|
|
||||||
|
### Added
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
### Changed
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
### Deprecated
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
### Removed
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
### Fixed
|
||||||
|
|
||||||
|
- Fixed checkpoint hparams dict casting when omegaconf is available ([#4770](https://github.com/PyTorchLightning/pytorch-lightning/pull/4770))
|
||||||
|
|
||||||
|
|
||||||
## [1.0.7] - 2020-11-17
|
## [1.0.7] - 2020-11-17
|
||||||
|
|
||||||
### Added
|
### Added
|
||||||
|
|
|
@ -328,10 +328,9 @@ class CheckpointConnector:
|
||||||
if hasattr(model, '_hparams_name'):
|
if hasattr(model, '_hparams_name'):
|
||||||
checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_NAME] = model._hparams_name
|
checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_NAME] = model._hparams_name
|
||||||
# dump arguments
|
# dump arguments
|
||||||
if OMEGACONF_AVAILABLE:
|
if OMEGACONF_AVAILABLE and isinstance(model.hparams, Container):
|
||||||
checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = model.hparams
|
checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = model.hparams
|
||||||
if isinstance(model.hparams, Container):
|
checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(model.hparams)
|
||||||
checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(model.hparams)
|
|
||||||
else:
|
else:
|
||||||
checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = dict(model.hparams)
|
checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = dict(model.hparams)
|
||||||
|
|
||||||
|
|
|
@ -27,6 +27,8 @@ from pathlib import Path
|
||||||
import cloudpickle
|
import cloudpickle
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
from omegaconf import Container, OmegaConf
|
||||||
|
from argparse import Namespace
|
||||||
|
|
||||||
import tests.base.develop_utils as tutils
|
import tests.base.develop_utils as tutils
|
||||||
from pytorch_lightning import Trainer, seed_everything
|
from pytorch_lightning import Trainer, seed_everything
|
||||||
|
@ -911,3 +913,35 @@ def test_current_score_when_nan(tmpdir, mode):
|
||||||
expected = float("inf" if mode == "min" else "-inf")
|
expected = float("inf" if mode == "min" else "-inf")
|
||||||
assert model_checkpoint.best_model_score == expected
|
assert model_checkpoint.best_model_score == expected
|
||||||
assert model_checkpoint.current_score == expected
|
assert model_checkpoint.current_score == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("hparams_type", [dict, Container])
|
||||||
|
def test_hparams_type(tmpdir, hparams_type):
|
||||||
|
class TestModel(BoringModel):
|
||||||
|
def __init__(self, hparams):
|
||||||
|
super().__init__()
|
||||||
|
self.save_hyperparameters(hparams)
|
||||||
|
|
||||||
|
model_checkpoint = ModelCheckpoint(
|
||||||
|
dirpath=tmpdir,
|
||||||
|
save_top_k=1,
|
||||||
|
monitor="foo",
|
||||||
|
)
|
||||||
|
trainer = Trainer(
|
||||||
|
default_root_dir=tmpdir,
|
||||||
|
fast_dev_run=True,
|
||||||
|
callbacks=[model_checkpoint],
|
||||||
|
logger=False,
|
||||||
|
weights_summary=None,
|
||||||
|
progress_bar_refresh_rate=0,
|
||||||
|
)
|
||||||
|
hp = {"test_hp_0": 1, "test_hp_1": 2}
|
||||||
|
hp = OmegaConf.create(hp) if hparams_type == Container else Namespace(**hp)
|
||||||
|
model = TestModel(hp)
|
||||||
|
trainer.fit(model)
|
||||||
|
ckpt = trainer.checkpoint_connector.dump_checkpoint()
|
||||||
|
if hparams_type == Container:
|
||||||
|
assert isinstance(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY], hparams_type)
|
||||||
|
else:
|
||||||
|
# make sure it's not AttributeDict
|
||||||
|
assert type(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY]) == hparams_type
|
||||||
|
|
Loading…
Reference in New Issue