past checkpoints (#2160)
* past checkpoints * omegaConf save * enforce type * resolve=True Co-authored-by: Omry Yadan <omry@fb.com> * test omegaconf * tests * test past Co-authored-by: Omry Yadan <omry@fb.com>
This commit is contained in:
parent
c826a5f599
commit
c0903b800d
|
@ -21,10 +21,17 @@ except ImportError:
|
|||
else:
|
||||
ALLOWED_CONFIG_TYPES = ALLOWED_CONFIG_TYPES + (Container, )
|
||||
|
||||
# the older shall be on the top
|
||||
CHECKPOINT_PAST_HPARAMS_KEYS = (
|
||||
'hparams',
|
||||
'module_arguments', # used in 0.7.6
|
||||
)
|
||||
|
||||
|
||||
class ModelIO(object):
|
||||
CHECKPOINT_KEY_HYPER_PARAMS = 'hyper_parameters'
|
||||
CHECKPOINT_NAME_HYPER_PARAMS = 'hparams_name'
|
||||
CHECKPOINT_HYPER_PARAMS_KEY = 'hyper_parameters'
|
||||
CHECKPOINT_HYPER_PARAMS_NAME = 'hparams_name'
|
||||
CHECKPOINT_HYPER_PARAMS_TYPE = 'hparams_type'
|
||||
|
||||
@classmethod
|
||||
def load_from_metrics(cls, weights_path, tags_csv, map_location=None):
|
||||
|
@ -153,10 +160,13 @@ class ModelIO(object):
|
|||
hparams['on_gpu'] = False
|
||||
|
||||
# overwrite hparams by the given file
|
||||
checkpoint[cls.CHECKPOINT_KEY_HYPER_PARAMS] = hparams
|
||||
checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = hparams
|
||||
|
||||
# override the module_arguments with values that were passed in
|
||||
checkpoint[cls.CHECKPOINT_KEY_HYPER_PARAMS].update(kwargs)
|
||||
# for past checkpoint need to add the new key
|
||||
if cls.CHECKPOINT_HYPER_PARAMS_KEY not in checkpoint:
|
||||
checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = {}
|
||||
# override the hparams with values that were passed in
|
||||
checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].update(kwargs)
|
||||
|
||||
model = cls._load_model_state(checkpoint, *args, **kwargs)
|
||||
return model
|
||||
|
@ -164,10 +174,15 @@ class ModelIO(object):
|
|||
@classmethod
|
||||
def _load_model_state(cls, checkpoint: Dict[str, Any], *args, **kwargs):
|
||||
# pass in the values we saved automatically
|
||||
if cls.CHECKPOINT_KEY_HYPER_PARAMS in checkpoint:
|
||||
# todo add some back compatibility
|
||||
model_args = checkpoint[cls.CHECKPOINT_KEY_HYPER_PARAMS]
|
||||
args_name = checkpoint.get(cls.CHECKPOINT_NAME_HYPER_PARAMS)
|
||||
if cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
|
||||
model_args = {}
|
||||
# add some back compatibility, the actual one shall be last
|
||||
for hparam_key in CHECKPOINT_PAST_HPARAMS_KEYS + (cls.CHECKPOINT_HYPER_PARAMS_KEY,):
|
||||
if hparam_key in checkpoint:
|
||||
model_args.update(checkpoint[hparam_key])
|
||||
if cls.CHECKPOINT_HYPER_PARAMS_TYPE in checkpoint:
|
||||
model_args = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_TYPE](model_args)
|
||||
args_name = checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_NAME)
|
||||
init_args_name = inspect.signature(cls).parameters.keys()
|
||||
if args_name == 'kwargs':
|
||||
cls_kwargs = {k: v for k, v in model_args.items() if k in init_args_name}
|
||||
|
|
|
@ -4,7 +4,6 @@ TensorBoard
|
|||
"""
|
||||
|
||||
import os
|
||||
import yaml
|
||||
from argparse import Namespace
|
||||
from typing import Optional, Dict, Union, Any
|
||||
from warnings import warn
|
||||
|
@ -18,6 +17,11 @@ from pytorch_lightning.core.saving import save_hparams_to_yaml
|
|||
from pytorch_lightning.loggers.base import LightningLoggerBase
|
||||
from pytorch_lightning.utilities import rank_zero_only
|
||||
|
||||
try:
|
||||
from omegaconf import Container
|
||||
except ImportError:
|
||||
Container = None
|
||||
|
||||
|
||||
class TensorBoardLogger(LightningLoggerBase):
|
||||
r"""
|
||||
|
@ -152,7 +156,14 @@ class TensorBoardLogger(LightningLoggerBase):
|
|||
hparams_file = os.path.join(dir_path, self.NAME_HPARAMS_FILE)
|
||||
|
||||
# save the metatags file
|
||||
save_hparams_to_yaml(hparams_file, self.hparams)
|
||||
if Container is not None:
|
||||
if isinstance(self.hparams, Container):
|
||||
from omegaconf import OmegaConf
|
||||
OmegaConf.save(self.hparams, hparams_file, resolve=True)
|
||||
else:
|
||||
save_hparams_to_yaml(hparams_file, self.hparams)
|
||||
else:
|
||||
save_hparams_to_yaml(hparams_file, self.hparams)
|
||||
|
||||
@rank_zero_only
|
||||
def finalize(self, status: str) -> None:
|
||||
|
|
|
@ -87,12 +87,12 @@ import os
|
|||
import re
|
||||
import signal
|
||||
from abc import ABC
|
||||
from argparse import Namespace
|
||||
from subprocess import call
|
||||
|
||||
import torch
|
||||
import torch.distributed as torch_distrib
|
||||
|
||||
import pytorch_lightning
|
||||
from pytorch_lightning import _logger as log
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
from pytorch_lightning.loggers import LightningLoggerBase
|
||||
|
@ -100,7 +100,7 @@ from pytorch_lightning.overrides.data_parallel import (
|
|||
LightningDistributedDataParallel,
|
||||
LightningDataParallel,
|
||||
)
|
||||
from pytorch_lightning.utilities import rank_zero_warn, parsing
|
||||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
from pytorch_lightning.utilities.io import load as pl_load
|
||||
|
||||
try:
|
||||
|
@ -119,6 +119,11 @@ except ImportError:
|
|||
else:
|
||||
HOROVOD_AVAILABLE = True
|
||||
|
||||
try:
|
||||
from omegaconf import Container
|
||||
except ImportError:
|
||||
Container = None
|
||||
|
||||
|
||||
class TrainerIOMixin(ABC):
|
||||
|
||||
|
@ -267,8 +272,8 @@ class TrainerIOMixin(ABC):
|
|||
try:
|
||||
self._atomic_save(checkpoint, filepath)
|
||||
except AttributeError as err:
|
||||
if LightningModule.CHECKPOINT_KEY_HYPER_PARAMS in checkpoint:
|
||||
del checkpoint[LightningModule.CHECKPOINT_KEY_HYPER_PARAMS]
|
||||
if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
|
||||
del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
|
||||
rank_zero_warn('Warning, `module_arguments` dropped from checkpoint.'
|
||||
f' An attribute is not picklable {err}')
|
||||
self._atomic_save(checkpoint, filepath)
|
||||
|
@ -320,6 +325,7 @@ class TrainerIOMixin(ABC):
|
|||
checkpoint = {
|
||||
'epoch': self.current_epoch + 1,
|
||||
'global_step': self.global_step + 1,
|
||||
'pytorch-ligthning_version': pytorch_lightning.__version__,
|
||||
}
|
||||
|
||||
if not weights_only:
|
||||
|
@ -356,10 +362,12 @@ class TrainerIOMixin(ABC):
|
|||
|
||||
if model.hparams:
|
||||
if hasattr(model, '_hparams_name'):
|
||||
checkpoint[LightningModule.CHECKPOINT_NAME_HYPER_PARAMS] = model._hparams_name
|
||||
checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_NAME] = model._hparams_name
|
||||
# add arguments to the checkpoint
|
||||
# todo: add some recursion in case of OmegaConf
|
||||
checkpoint[LightningModule.CHECKPOINT_KEY_HYPER_PARAMS] = dict(model.hparams)
|
||||
checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = model.hparams
|
||||
if Container is not None:
|
||||
if isinstance(model.hparams, Container):
|
||||
checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(model.hparams)
|
||||
|
||||
# give the model a chance to add a few things
|
||||
model.on_save_checkpoint(checkpoint)
|
||||
|
@ -473,8 +481,8 @@ class TrainerIOMixin(ABC):
|
|||
try:
|
||||
self._atomic_save(checkpoint, filepath)
|
||||
except AttributeError as err:
|
||||
if LightningModule.CHECKPOINT_KEY_HYPER_PARAMS in checkpoint:
|
||||
del checkpoint[LightningModule.CHECKPOINT_KEY_HYPER_PARAMS]
|
||||
if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
|
||||
del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
|
||||
rank_zero_warn('warning, `module_arguments` dropped from checkpoint.'
|
||||
f' An attribute is not picklable {err}')
|
||||
self._atomic_save(checkpoint, filepath)
|
||||
|
|
|
@ -30,31 +30,34 @@ class AssignHparamsModel(EvalModelTemplate):
|
|||
# -------------------------
|
||||
# STANDARD TESTS
|
||||
# -------------------------
|
||||
def _run_standard_hparams_test(tmpdir, model, cls):
|
||||
def _run_standard_hparams_test(tmpdir, model, cls, try_overwrite=False):
|
||||
"""
|
||||
Tests for the existence of an arg 'test_arg=14'
|
||||
"""
|
||||
hparam_type = type(model.hparams)
|
||||
# test proper property assignments
|
||||
assert model.hparams.test_arg == 14
|
||||
|
||||
# verify we can train
|
||||
trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_pct=0.5)
|
||||
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, overfit_pct=0.5)
|
||||
trainer.fit(model)
|
||||
|
||||
# make sure the raw checkpoint saved the properties
|
||||
raw_checkpoint_path = _raw_checkpoint_path(trainer)
|
||||
raw_checkpoint = torch.load(raw_checkpoint_path)
|
||||
assert LightningModule.CHECKPOINT_KEY_HYPER_PARAMS in raw_checkpoint
|
||||
assert raw_checkpoint[LightningModule.CHECKPOINT_KEY_HYPER_PARAMS]['test_arg'] == 14
|
||||
assert LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint
|
||||
assert raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]['test_arg'] == 14
|
||||
|
||||
# verify that model loads correctly
|
||||
model = cls.load_from_checkpoint(raw_checkpoint_path)
|
||||
assert model.hparams.test_arg == 14
|
||||
model2 = cls.load_from_checkpoint(raw_checkpoint_path)
|
||||
assert model2.hparams.test_arg == 14
|
||||
|
||||
# todo
|
||||
# verify that we can overwrite the property
|
||||
# model = cls.load_from_checkpoint(raw_checkpoint_path, test_arg=78)
|
||||
# assert model.hparams.test_arg == 78
|
||||
assert isinstance(model2.hparams, hparam_type)
|
||||
|
||||
if try_overwrite:
|
||||
# verify that we can overwrite the property
|
||||
model3 = cls.load_from_checkpoint(raw_checkpoint_path, test_arg=78)
|
||||
assert model3.hparams.test_arg == 78
|
||||
|
||||
return raw_checkpoint_path
|
||||
|
||||
|
@ -82,14 +85,16 @@ def test_omega_conf_hparams(tmpdir, cls):
|
|||
# init model
|
||||
conf = OmegaConf.create(dict(test_arg=14, mylist=[15.4, dict(a=1, b=2)]))
|
||||
model = cls(hparams=conf)
|
||||
assert isinstance(model.hparams, Container)
|
||||
|
||||
# run standard test suite
|
||||
raw_checkpoint_path = _run_standard_hparams_test(tmpdir, model, cls)
|
||||
model = cls.load_from_checkpoint(raw_checkpoint_path)
|
||||
model2 = cls.load_from_checkpoint(raw_checkpoint_path)
|
||||
assert isinstance(model2.hparams, Container)
|
||||
|
||||
# config specific tests
|
||||
assert model.hparams.test_arg == 14
|
||||
assert model.hparams.mylist[0] == 15.4
|
||||
assert model2.hparams.test_arg == 14
|
||||
assert model2.hparams.mylist[0] == 15.4
|
||||
|
||||
|
||||
def test_explicit_args_hparams(tmpdir):
|
||||
|
@ -157,8 +162,8 @@ def test_explicit_missing_args_hparams(tmpdir):
|
|||
# make sure the raw checkpoint saved the properties
|
||||
raw_checkpoint_path = _raw_checkpoint_path(trainer)
|
||||
raw_checkpoint = torch.load(raw_checkpoint_path)
|
||||
assert LightningModule.CHECKPOINT_KEY_HYPER_PARAMS in raw_checkpoint
|
||||
assert raw_checkpoint[LightningModule.CHECKPOINT_KEY_HYPER_PARAMS]['test_arg'] == 14
|
||||
assert LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint
|
||||
assert raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]['test_arg'] == 14
|
||||
|
||||
# verify that model loads correctly
|
||||
model = TestModel.load_from_checkpoint(raw_checkpoint_path, test_arg2=123)
|
||||
|
@ -199,110 +204,91 @@ def test_class_nesting():
|
|||
A().test()
|
||||
|
||||
|
||||
@pytest.mark.xfail(sys.version_info >= (3, 6), reason='OmegaConf only for Python >= 3.8')
|
||||
def test_omegaconf(tmpdir):
|
||||
class OmegaConfModel(EvalModelTemplate):
|
||||
def __init__(self, ogc):
|
||||
super().__init__()
|
||||
self.ogc = ogc
|
||||
self.size = ogc.list[0]
|
||||
class SubClassEvalModel(EvalModelTemplate):
|
||||
any_other_loss = torch.nn.CrossEntropyLoss()
|
||||
|
||||
conf = OmegaConf.create({"k": "v", "list": [15.4, {"a": "1", "b": "2"}]})
|
||||
model = OmegaConfModel(conf)
|
||||
def __init__(self, *args, subclass_arg=1200, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.save_hyperparameters()
|
||||
|
||||
# ensure ogc passed values correctly
|
||||
assert model.size == 15.4
|
||||
|
||||
class SubSubClassEvalModel(SubClassEvalModel):
|
||||
pass
|
||||
|
||||
|
||||
class AggSubClassEvalModel(SubClassEvalModel):
|
||||
|
||||
def __init__(self, *args, my_loss=torch.nn.CrossEntropyLoss(), **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.save_hyperparameters()
|
||||
|
||||
|
||||
class UnconventionalArgsEvalModel(EvalModelTemplate):
|
||||
""" A model that has unconventional names for "self", "*args" and "**kwargs". """
|
||||
|
||||
def __init__(obj, *more_args, other_arg=300, **more_kwargs):
|
||||
# intentionally named obj
|
||||
super().__init__(*more_args, **more_kwargs)
|
||||
obj.save_hyperparameters()
|
||||
|
||||
|
||||
class DictConfSubClassEvalModel(SubClassEvalModel):
|
||||
def __init__(self, *args, dict_conf=OmegaConf.create(dict(my_param='something')), **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.save_hyperparameters()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("cls", [
|
||||
EvalModelTemplate,
|
||||
SubClassEvalModel,
|
||||
SubSubClassEvalModel,
|
||||
AggSubClassEvalModel,
|
||||
UnconventionalArgsEvalModel,
|
||||
DictConfSubClassEvalModel,
|
||||
])
|
||||
def test_collect_init_arguments(tmpdir, cls):
|
||||
""" Test that the model automatically saves the arguments passed into the constructor """
|
||||
extra_args = {}
|
||||
if cls is AggSubClassEvalModel:
|
||||
extra_args.update(my_loss=torch.nn.CosineEmbeddingLoss())
|
||||
elif cls is DictConfSubClassEvalModel:
|
||||
extra_args.update(dict_conf=OmegaConf.create(dict(my_param='anything')))
|
||||
|
||||
model = cls(**extra_args)
|
||||
assert model.hparams.batch_size == 32
|
||||
model = cls(batch_size=179, **extra_args)
|
||||
assert model.hparams.batch_size == 179
|
||||
|
||||
if isinstance(model, SubClassEvalModel):
|
||||
assert model.hparams.subclass_arg == 1200
|
||||
|
||||
if isinstance(model, AggSubClassEvalModel):
|
||||
assert isinstance(model.hparams.my_loss, torch.nn.CosineEmbeddingLoss)
|
||||
|
||||
# verify that the checkpoint saved the correct values
|
||||
trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_pct=0.5)
|
||||
result = trainer.fit(model)
|
||||
trainer.fit(model)
|
||||
raw_checkpoint_path = _raw_checkpoint_path(trainer)
|
||||
|
||||
assert result == 1
|
||||
raw_checkpoint = torch.load(raw_checkpoint_path)
|
||||
assert LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint
|
||||
assert raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]['batch_size'] == 179
|
||||
|
||||
# verify that model loads correctly
|
||||
# TODO: uncomment and get it pass
|
||||
# model = cls.load_from_checkpoint(raw_checkpoint_path)
|
||||
# assert model.hparams.batch_size == 179
|
||||
#
|
||||
# if isinstance(model, AggSubClassEvalModel):
|
||||
# assert isinstance(model.hparams.my_loss, torch.nn.CrossEntropyLoss)
|
||||
#
|
||||
# if isinstance(model, DictConfSubClassEvalModel):
|
||||
# assert isinstance(model.hparams.dict_conf, Container)
|
||||
# assert model.hparams.dict_conf == 'anything'
|
||||
|
||||
# class SubClassEvalModel(EvalModelTemplate):
|
||||
# any_other_loss = torch.nn.CrossEntropyLoss()
|
||||
#
|
||||
# def __init__(self, *args, subclass_arg=1200, **kwargs):
|
||||
# super().__init__(*args, **kwargs)
|
||||
# self.save_hyperparameters()
|
||||
#
|
||||
#
|
||||
# class SubSubClassEvalModel(SubClassEvalModel):
|
||||
# pass
|
||||
#
|
||||
#
|
||||
# class AggSubClassEvalModel(SubClassEvalModel):
|
||||
#
|
||||
# def __init__(self, *args, my_loss=torch.nn.CrossEntropyLoss(), **kwargs):
|
||||
# super().__init__(*args, **kwargs)
|
||||
# self.save_hyperparameters()
|
||||
#
|
||||
#
|
||||
# class UnconventionalArgsEvalModel(EvalModelTemplate):
|
||||
# """ A model that has unconventional names for "self", "*args" and "**kwargs". """
|
||||
#
|
||||
# def __init__(obj, *more_args, other_arg=300, **more_kwargs):
|
||||
# # intentionally named obj
|
||||
# super().__init__(*more_args, **more_kwargs)
|
||||
# obj.save_hyperparameters()
|
||||
#
|
||||
#
|
||||
# class DictConfSubClassEvalModel(SubClassEvalModel):
|
||||
# def __init__(self, *args, dict_conf=OmegaConf.create(dict(my_param='something')), **kwargs):
|
||||
# super().__init__(*args, **kwargs)
|
||||
# self.save_hyperparameters()
|
||||
#
|
||||
#
|
||||
# @pytest.mark.parametrize("cls", [
|
||||
# EvalModelTemplate,
|
||||
# SubClassEvalModel,
|
||||
# SubSubClassEvalModel,
|
||||
# AggSubClassEvalModel,
|
||||
# UnconventionalArgsEvalModel,
|
||||
# DictConfSubClassEvalModel,
|
||||
# ])
|
||||
# def test_collect_init_arguments(tmpdir, cls):
|
||||
# """ Test that the model automatically saves the arguments passed into the constructor """
|
||||
# extra_args = {}
|
||||
# if cls is AggSubClassEvalModel:
|
||||
# extra_args.update(my_loss=torch.nn.CosineEmbeddingLoss())
|
||||
# elif cls is DictConfSubClassEvalModel:
|
||||
# extra_args.update(dict_conf=OmegaConf.create(dict(my_param='anything')))
|
||||
#
|
||||
# model = cls(**extra_args)
|
||||
# assert model.batch_size == 32
|
||||
# model = cls(batch_size=179, **extra_args)
|
||||
# assert model.batch_size == 179
|
||||
#
|
||||
# if isinstance(model, SubClassEvalModel):
|
||||
# assert model.subclass_arg == 1200
|
||||
#
|
||||
# if isinstance(model, AggSubClassEvalModel):
|
||||
# assert isinstance(model.my_loss, torch.nn.CosineEmbeddingLoss)
|
||||
#
|
||||
# # verify that the checkpoint saved the correct values
|
||||
# trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_pct=0.5)
|
||||
# trainer.fit(model)
|
||||
# raw_checkpoint_path = _raw_checkpoint_path(trainer)
|
||||
#
|
||||
# raw_checkpoint = torch.load(raw_checkpoint_path)
|
||||
# assert LightningModule.CHECKPOINT_KEY_HYPER_PARAMS in raw_checkpoint
|
||||
# assert raw_checkpoint[LightningModule.CHECKPOINT_KEY_HYPER_PARAMS]['batch_size'] == 179
|
||||
#
|
||||
# # verify that model loads correctly
|
||||
# model = cls.load_from_checkpoint(raw_checkpoint_path)
|
||||
# assert model.batch_size == 179
|
||||
#
|
||||
# if isinstance(model, AggSubClassEvalModel):
|
||||
# assert isinstance(model.my_loss, torch.nn.CrossEntropyLoss)
|
||||
#
|
||||
# if isinstance(model, DictConfSubClassEvalModel):
|
||||
# assert isinstance(model.dict_conf, DictConfig)
|
||||
# assert model.dict_conf == 'anything'
|
||||
#
|
||||
# # verify that we can overwrite whatever we want
|
||||
# model = cls.load_from_checkpoint(raw_checkpoint_path, batch_size=99)
|
||||
# assert model.batch_size == 99
|
||||
# verify that we can overwrite whatever we want
|
||||
model = cls.load_from_checkpoint(raw_checkpoint_path, batch_size=99)
|
||||
assert model.hparams.batch_size == 99
|
||||
|
||||
|
||||
def _raw_checkpoint_path(trainer) -> str:
|
||||
|
@ -394,6 +380,28 @@ def test_single_config_models_fail(tmpdir, cls, config):
|
|||
_ = cls(**config)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("past_key", ['module_arguments'])
|
||||
def test_load_past_checkpoint(tmpdir, past_key):
|
||||
model = EvalModelTemplate()
|
||||
|
||||
# verify we can train
|
||||
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
|
||||
trainer.fit(model)
|
||||
|
||||
# make sure the raw checkpoint saved the properties
|
||||
raw_checkpoint_path = _raw_checkpoint_path(trainer)
|
||||
raw_checkpoint = torch.load(raw_checkpoint_path)
|
||||
raw_checkpoint[past_key] = raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
|
||||
raw_checkpoint[past_key]['batch_size'] = -17
|
||||
del raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
|
||||
# save back the checkpoint
|
||||
torch.save(raw_checkpoint, raw_checkpoint_path)
|
||||
|
||||
# verify that model loads correctly
|
||||
model2 = EvalModelTemplate.load_from_checkpoint(raw_checkpoint_path)
|
||||
assert model2.hparams.batch_size == -17
|
||||
|
||||
|
||||
def test_hparams_pickle(tmpdir):
|
||||
ad = AttributeDict({'key1': 1, 'key2': 'abc'})
|
||||
pkl = pickle.dumps(ad)
|
||||
|
|
|
@ -48,7 +48,7 @@ def test_no_val_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
|
|||
|
||||
# assert ckpt has hparams
|
||||
ckpt = torch.load(new_weights_path)
|
||||
assert LightningModule.CHECKPOINT_KEY_HYPER_PARAMS in ckpt.keys(), 'module_arguments missing from checkpoints'
|
||||
assert LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in ckpt.keys(), 'module_arguments missing from checkpoints'
|
||||
|
||||
# load new model
|
||||
hparams_path = tutils.get_data_path(logger, path_dir=tmpdir)
|
||||
|
|
Loading…
Reference in New Issue