past checkpoints (#2160)

* past checkpoints

* OmegaConf save

* enforce type

* resolve=True

Co-authored-by: Omry Yadan <omry@fb.com>

* test omegaconf

* tests

* test past

Co-authored-by: Omry Yadan <omry@fb.com>
Jirka Borovec 2020-06-14 17:36:45 +02:00 committed by GitHub
parent c826a5f599
commit c0903b800d
5 changed files with 175 additions and 133 deletions


@@ -21,10 +21,17 @@ except ImportError:
else:
ALLOWED_CONFIG_TYPES = ALLOWED_CONFIG_TYPES + (Container, )
# the older shall be on the top
CHECKPOINT_PAST_HPARAMS_KEYS = (
'hparams',
'module_arguments', # used in 0.7.6
)
class ModelIO(object):
CHECKPOINT_KEY_HYPER_PARAMS = 'hyper_parameters'
CHECKPOINT_NAME_HYPER_PARAMS = 'hparams_name'
CHECKPOINT_HYPER_PARAMS_KEY = 'hyper_parameters'
CHECKPOINT_HYPER_PARAMS_NAME = 'hparams_name'
CHECKPOINT_HYPER_PARAMS_TYPE = 'hparams_type'
@classmethod
def load_from_metrics(cls, weights_path, tags_csv, map_location=None):
@@ -153,10 +160,13 @@ class ModelIO(object):
hparams['on_gpu'] = False
# overwrite hparams by the given file
checkpoint[cls.CHECKPOINT_KEY_HYPER_PARAMS] = hparams
checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = hparams
# override the module_arguments with values that were passed in
checkpoint[cls.CHECKPOINT_KEY_HYPER_PARAMS].update(kwargs)
# for past checkpoint need to add the new key
if cls.CHECKPOINT_HYPER_PARAMS_KEY not in checkpoint:
checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = {}
# override the hparams with values that were passed in
checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].update(kwargs)
model = cls._load_model_state(checkpoint, *args, **kwargs)
return model
@@ -164,10 +174,15 @@ class ModelIO(object):
@classmethod
def _load_model_state(cls, checkpoint: Dict[str, Any], *args, **kwargs):
# pass in the values we saved automatically
if cls.CHECKPOINT_KEY_HYPER_PARAMS in checkpoint:
# todo add some back compatibility
model_args = checkpoint[cls.CHECKPOINT_KEY_HYPER_PARAMS]
args_name = checkpoint.get(cls.CHECKPOINT_NAME_HYPER_PARAMS)
if cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
model_args = {}
# add some back compatibility, the actual one shall be last
for hparam_key in CHECKPOINT_PAST_HPARAMS_KEYS + (cls.CHECKPOINT_HYPER_PARAMS_KEY,):
if hparam_key in checkpoint:
model_args.update(checkpoint[hparam_key])
if cls.CHECKPOINT_HYPER_PARAMS_TYPE in checkpoint:
model_args = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_TYPE](model_args)
args_name = checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_NAME)
init_args_name = inspect.signature(cls).parameters.keys()
if args_name == 'kwargs':
cls_kwargs = {k: v for k, v in model_args.items() if k in init_args_name}
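
The hunks above merge hyper-parameters from legacy checkpoint keys before instantiating the model; note that `load_from_checkpoint` first seeds an empty 'hyper_parameters' entry so the merge branch in `_load_model_state` also runs for old checkpoints. Pulled out of the diff, the merge logic reads roughly like the sketch below. `collect_hparams` is a hypothetical helper added here for illustration; only the key names and the "newest key wins" precedence come from this commit.

from typing import Any, Dict

# the older keys come first so the current key wins on conflicts
CHECKPOINT_PAST_HPARAMS_KEYS = ('hparams', 'module_arguments')  # 'module_arguments' was used in 0.7.6
CHECKPOINT_HYPER_PARAMS_KEY = 'hyper_parameters'
CHECKPOINT_HYPER_PARAMS_TYPE = 'hparams_type'


def collect_hparams(checkpoint: Dict[str, Any]) -> Any:
    """Merge hyper-parameters from legacy and current checkpoint keys."""
    model_args: Dict[str, Any] = {}
    for key in CHECKPOINT_PAST_HPARAMS_KEYS + (CHECKPOINT_HYPER_PARAMS_KEY,):
        if key in checkpoint:
            model_args.update(checkpoint[key])
    # if the concrete container type (e.g. an OmegaConf DictConfig) was recorded, rebuild it
    if CHECKPOINT_HYPER_PARAMS_TYPE in checkpoint:
        model_args = checkpoint[CHECKPOINT_HYPER_PARAMS_TYPE](model_args)
    return model_args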


@@ -4,7 +4,6 @@ TensorBoard
"""
import os
import yaml
from argparse import Namespace
from typing import Optional, Dict, Union, Any
from warnings import warn
@@ -18,6 +17,11 @@ from pytorch_lightning.core.saving import save_hparams_to_yaml
from pytorch_lightning.loggers.base import LightningLoggerBase
from pytorch_lightning.utilities import rank_zero_only
try:
from omegaconf import Container
except ImportError:
Container = None
class TensorBoardLogger(LightningLoggerBase):
r"""
@@ -152,7 +156,14 @@ class TensorBoardLogger(LightningLoggerBase):
hparams_file = os.path.join(dir_path, self.NAME_HPARAMS_FILE)
# save the metatags file
save_hparams_to_yaml(hparams_file, self.hparams)
if Container is not None:
if isinstance(self.hparams, Container):
from omegaconf import OmegaConf
OmegaConf.save(self.hparams, hparams_file, resolve=True)
else:
save_hparams_to_yaml(hparams_file, self.hparams)
else:
save_hparams_to_yaml(hparams_file, self.hparams)
@rank_zero_only
def finalize(self, status: str) -> None:
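
Reading the interleaved lines above: the logger now branches on whether `self.hparams` is an OmegaConf `Container` and, if so, writes the file via `OmegaConf.save` with interpolations resolved; otherwise it falls back to the existing `save_hparams_to_yaml`. A self-contained sketch of that branch follows; the `_save_hparams` helper is hypothetical, while the `OmegaConf.save(..., resolve=True)` call is the one used in the diff.

import os

try:
    from omegaconf import Container, OmegaConf
except ImportError:
    Container = OmegaConf = None

from pytorch_lightning.core.saving import save_hparams_to_yaml


def _save_hparams(hparams, dir_path: str, filename: str = 'hparams.yaml') -> None:
    """Write hparams to YAML; OmegaConf configs go through OmegaConf.save with resolve=True."""
    hparams_file = os.path.join(dir_path, filename)
    if Container is not None and isinstance(hparams, Container):
        OmegaConf.save(hparams, hparams_file, resolve=True)
    else:
        save_hparams_to_yaml(hparams_file, hparams)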


@@ -87,12 +87,12 @@ import os
import re
import signal
from abc import ABC
from argparse import Namespace
from subprocess import call
import torch
import torch.distributed as torch_distrib
import pytorch_lightning
from pytorch_lightning import _logger as log
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers import LightningLoggerBase
@@ -100,7 +100,7 @@ from pytorch_lightning.overrides.data_parallel import (
LightningDistributedDataParallel,
LightningDataParallel,
)
from pytorch_lightning.utilities import rank_zero_warn, parsing
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.io import load as pl_load
try:
@@ -119,6 +119,11 @@ except ImportError:
else:
HOROVOD_AVAILABLE = True
try:
from omegaconf import Container
except ImportError:
Container = None
class TrainerIOMixin(ABC):
@@ -267,8 +272,8 @@ class TrainerIOMixin(ABC):
try:
self._atomic_save(checkpoint, filepath)
except AttributeError as err:
if LightningModule.CHECKPOINT_KEY_HYPER_PARAMS in checkpoint:
del checkpoint[LightningModule.CHECKPOINT_KEY_HYPER_PARAMS]
if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
rank_zero_warn('Warning, `module_arguments` dropped from checkpoint.'
f' An attribute is not picklable {err}')
self._atomic_save(checkpoint, filepath)
@@ -320,6 +325,7 @@ class TrainerIOMixin(ABC):
checkpoint = {
'epoch': self.current_epoch + 1,
'global_step': self.global_step + 1,
'pytorch-ligthning_version': pytorch_lightning.__version__,
}
if not weights_only:
@@ -356,10 +362,12 @@ class TrainerIOMixin(ABC):
if model.hparams:
if hasattr(model, '_hparams_name'):
checkpoint[LightningModule.CHECKPOINT_NAME_HYPER_PARAMS] = model._hparams_name
checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_NAME] = model._hparams_name
# add arguments to the checkpoint
# todo: add some recursion in case of OmegaConf
checkpoint[LightningModule.CHECKPOINT_KEY_HYPER_PARAMS] = dict(model.hparams)
checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = model.hparams
if Container is not None:
if isinstance(model.hparams, Container):
checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(model.hparams)
# give the model a chance to add a few things
model.on_save_checkpoint(checkpoint)
@@ -473,8 +481,8 @@ class TrainerIOMixin(ABC):
try:
self._atomic_save(checkpoint, filepath)
except AttributeError as err:
if LightningModule.CHECKPOINT_KEY_HYPER_PARAMS in checkpoint:
del checkpoint[LightningModule.CHECKPOINT_KEY_HYPER_PARAMS]
if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
rank_zero_warn('warning, `module_arguments` dropped from checkpoint.'
f' An attribute is not picklable {err}')
self._atomic_save(checkpoint, filepath)
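
Because the rendered hunks interleave removed and added lines, here is how the hyper-parameter part of checkpoint dumping reads after this change, condensed into a standalone sketch. The helper name `add_hparams_to_checkpoint` and the literal key strings are illustrative; the behaviour (store `hparams` as-is, record its OmegaConf type when applicable) mirrors the added lines.

from typing import Any, Dict

try:
    from omegaconf import Container
except ImportError:
    Container = None


def add_hparams_to_checkpoint(model, checkpoint: Dict[str, Any]) -> None:
    """Record hparams, their name, and (for OmegaConf configs) their concrete type."""
    if not model.hparams:
        return
    if hasattr(model, '_hparams_name'):
        checkpoint['hparams_name'] = model._hparams_name
    # hparams are stored as-is; the previous dict(model.hparams) coercion is gone
    checkpoint['hyper_parameters'] = model.hparams
    if Container is not None and isinstance(model.hparams, Container):
        # remembering the type lets loading rebuild the same container class
        checkpoint['hparams_type'] = type(model.hparams)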


@@ -30,31 +30,34 @@ class AssignHparamsModel(EvalModelTemplate):
# -------------------------
# STANDARD TESTS
# -------------------------
def _run_standard_hparams_test(tmpdir, model, cls):
def _run_standard_hparams_test(tmpdir, model, cls, try_overwrite=False):
"""
Tests for the existence of an arg 'test_arg=14'
"""
hparam_type = type(model.hparams)
# test proper property assignments
assert model.hparams.test_arg == 14
# verify we can train
trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_pct=0.5)
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, overfit_pct=0.5)
trainer.fit(model)
# make sure the raw checkpoint saved the properties
raw_checkpoint_path = _raw_checkpoint_path(trainer)
raw_checkpoint = torch.load(raw_checkpoint_path)
assert LightningModule.CHECKPOINT_KEY_HYPER_PARAMS in raw_checkpoint
assert raw_checkpoint[LightningModule.CHECKPOINT_KEY_HYPER_PARAMS]['test_arg'] == 14
assert LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint
assert raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]['test_arg'] == 14
# verify that model loads correctly
model = cls.load_from_checkpoint(raw_checkpoint_path)
assert model.hparams.test_arg == 14
model2 = cls.load_from_checkpoint(raw_checkpoint_path)
assert model2.hparams.test_arg == 14
# todo
# verify that we can overwrite the property
# model = cls.load_from_checkpoint(raw_checkpoint_path, test_arg=78)
# assert model.hparams.test_arg == 78
assert isinstance(model2.hparams, hparam_type)
if try_overwrite:
# verify that we can overwrite the property
model3 = cls.load_from_checkpoint(raw_checkpoint_path, test_arg=78)
assert model3.hparams.test_arg == 78
return raw_checkpoint_path
@@ -82,14 +85,16 @@ def test_omega_conf_hparams(tmpdir, cls):
# init model
conf = OmegaConf.create(dict(test_arg=14, mylist=[15.4, dict(a=1, b=2)]))
model = cls(hparams=conf)
assert isinstance(model.hparams, Container)
# run standard test suite
raw_checkpoint_path = _run_standard_hparams_test(tmpdir, model, cls)
model = cls.load_from_checkpoint(raw_checkpoint_path)
model2 = cls.load_from_checkpoint(raw_checkpoint_path)
assert isinstance(model2.hparams, Container)
# config specific tests
assert model.hparams.test_arg == 14
assert model.hparams.mylist[0] == 15.4
assert model2.hparams.test_arg == 14
assert model2.hparams.mylist[0] == 15.4
def test_explicit_args_hparams(tmpdir):
@@ -157,8 +162,8 @@ def test_explicit_missing_args_hparams(tmpdir):
# make sure the raw checkpoint saved the properties
raw_checkpoint_path = _raw_checkpoint_path(trainer)
raw_checkpoint = torch.load(raw_checkpoint_path)
assert LightningModule.CHECKPOINT_KEY_HYPER_PARAMS in raw_checkpoint
assert raw_checkpoint[LightningModule.CHECKPOINT_KEY_HYPER_PARAMS]['test_arg'] == 14
assert LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint
assert raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]['test_arg'] == 14
# verify that model loads correctly
model = TestModel.load_from_checkpoint(raw_checkpoint_path, test_arg2=123)
@@ -199,110 +204,91 @@ def test_class_nesting():
A().test()
@pytest.mark.xfail(sys.version_info >= (3, 6), reason='OmegaConf only for Python >= 3.8')
def test_omegaconf(tmpdir):
class OmegaConfModel(EvalModelTemplate):
def __init__(self, ogc):
super().__init__()
self.ogc = ogc
self.size = ogc.list[0]
class SubClassEvalModel(EvalModelTemplate):
any_other_loss = torch.nn.CrossEntropyLoss()
conf = OmegaConf.create({"k": "v", "list": [15.4, {"a": "1", "b": "2"}]})
model = OmegaConfModel(conf)
def __init__(self, *args, subclass_arg=1200, **kwargs):
super().__init__(*args, **kwargs)
self.save_hyperparameters()
# ensure ogc passed values correctly
assert model.size == 15.4
class SubSubClassEvalModel(SubClassEvalModel):
pass
class AggSubClassEvalModel(SubClassEvalModel):
def __init__(self, *args, my_loss=torch.nn.CrossEntropyLoss(), **kwargs):
super().__init__(*args, **kwargs)
self.save_hyperparameters()
class UnconventionalArgsEvalModel(EvalModelTemplate):
""" A model that has unconventional names for "self", "*args" and "**kwargs". """
def __init__(obj, *more_args, other_arg=300, **more_kwargs):
# intentionally named obj
super().__init__(*more_args, **more_kwargs)
obj.save_hyperparameters()
class DictConfSubClassEvalModel(SubClassEvalModel):
def __init__(self, *args, dict_conf=OmegaConf.create(dict(my_param='something')), **kwargs):
super().__init__(*args, **kwargs)
self.save_hyperparameters()
@pytest.mark.parametrize("cls", [
EvalModelTemplate,
SubClassEvalModel,
SubSubClassEvalModel,
AggSubClassEvalModel,
UnconventionalArgsEvalModel,
DictConfSubClassEvalModel,
])
def test_collect_init_arguments(tmpdir, cls):
""" Test that the model automatically saves the arguments passed into the constructor """
extra_args = {}
if cls is AggSubClassEvalModel:
extra_args.update(my_loss=torch.nn.CosineEmbeddingLoss())
elif cls is DictConfSubClassEvalModel:
extra_args.update(dict_conf=OmegaConf.create(dict(my_param='anything')))
model = cls(**extra_args)
assert model.hparams.batch_size == 32
model = cls(batch_size=179, **extra_args)
assert model.hparams.batch_size == 179
if isinstance(model, SubClassEvalModel):
assert model.hparams.subclass_arg == 1200
if isinstance(model, AggSubClassEvalModel):
assert isinstance(model.hparams.my_loss, torch.nn.CosineEmbeddingLoss)
# verify that the checkpoint saved the correct values
trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_pct=0.5)
result = trainer.fit(model)
trainer.fit(model)
raw_checkpoint_path = _raw_checkpoint_path(trainer)
assert result == 1
raw_checkpoint = torch.load(raw_checkpoint_path)
assert LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint
assert raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]['batch_size'] == 179
# verify that model loads correctly
# TODO: uncomment and get it pass
# model = cls.load_from_checkpoint(raw_checkpoint_path)
# assert model.hparams.batch_size == 179
#
# if isinstance(model, AggSubClassEvalModel):
# assert isinstance(model.hparams.my_loss, torch.nn.CrossEntropyLoss)
#
# if isinstance(model, DictConfSubClassEvalModel):
# assert isinstance(model.hparams.dict_conf, Container)
# assert model.hparams.dict_conf == 'anything'
# class SubClassEvalModel(EvalModelTemplate):
# any_other_loss = torch.nn.CrossEntropyLoss()
#
# def __init__(self, *args, subclass_arg=1200, **kwargs):
# super().__init__(*args, **kwargs)
# self.save_hyperparameters()
#
#
# class SubSubClassEvalModel(SubClassEvalModel):
# pass
#
#
# class AggSubClassEvalModel(SubClassEvalModel):
#
# def __init__(self, *args, my_loss=torch.nn.CrossEntropyLoss(), **kwargs):
# super().__init__(*args, **kwargs)
# self.save_hyperparameters()
#
#
# class UnconventionalArgsEvalModel(EvalModelTemplate):
# """ A model that has unconventional names for "self", "*args" and "**kwargs". """
#
# def __init__(obj, *more_args, other_arg=300, **more_kwargs):
# # intentionally named obj
# super().__init__(*more_args, **more_kwargs)
# obj.save_hyperparameters()
#
#
# class DictConfSubClassEvalModel(SubClassEvalModel):
# def __init__(self, *args, dict_conf=OmegaConf.create(dict(my_param='something')), **kwargs):
# super().__init__(*args, **kwargs)
# self.save_hyperparameters()
#
#
# @pytest.mark.parametrize("cls", [
# EvalModelTemplate,
# SubClassEvalModel,
# SubSubClassEvalModel,
# AggSubClassEvalModel,
# UnconventionalArgsEvalModel,
# DictConfSubClassEvalModel,
# ])
# def test_collect_init_arguments(tmpdir, cls):
# """ Test that the model automatically saves the arguments passed into the constructor """
# extra_args = {}
# if cls is AggSubClassEvalModel:
# extra_args.update(my_loss=torch.nn.CosineEmbeddingLoss())
# elif cls is DictConfSubClassEvalModel:
# extra_args.update(dict_conf=OmegaConf.create(dict(my_param='anything')))
#
# model = cls(**extra_args)
# assert model.batch_size == 32
# model = cls(batch_size=179, **extra_args)
# assert model.batch_size == 179
#
# if isinstance(model, SubClassEvalModel):
# assert model.subclass_arg == 1200
#
# if isinstance(model, AggSubClassEvalModel):
# assert isinstance(model.my_loss, torch.nn.CosineEmbeddingLoss)
#
# # verify that the checkpoint saved the correct values
# trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_pct=0.5)
# trainer.fit(model)
# raw_checkpoint_path = _raw_checkpoint_path(trainer)
#
# raw_checkpoint = torch.load(raw_checkpoint_path)
# assert LightningModule.CHECKPOINT_KEY_HYPER_PARAMS in raw_checkpoint
# assert raw_checkpoint[LightningModule.CHECKPOINT_KEY_HYPER_PARAMS]['batch_size'] == 179
#
# # verify that model loads correctly
# model = cls.load_from_checkpoint(raw_checkpoint_path)
# assert model.batch_size == 179
#
# if isinstance(model, AggSubClassEvalModel):
# assert isinstance(model.my_loss, torch.nn.CrossEntropyLoss)
#
# if isinstance(model, DictConfSubClassEvalModel):
# assert isinstance(model.dict_conf, DictConfig)
# assert model.dict_conf == 'anything'
#
# # verify that we can overwrite whatever we want
# model = cls.load_from_checkpoint(raw_checkpoint_path, batch_size=99)
# assert model.batch_size == 99
# verify that we can overwrite whatever we want
model = cls.load_from_checkpoint(raw_checkpoint_path, batch_size=99)
assert model.hparams.batch_size == 99
def _raw_checkpoint_path(trainer) -> str:
@@ -394,6 +380,28 @@ def test_single_config_models_fail(tmpdir, cls, config):
_ = cls(**config)
@pytest.mark.parametrize("past_key", ['module_arguments'])
def test_load_past_checkpoint(tmpdir, past_key):
model = EvalModelTemplate()
# verify we can train
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
trainer.fit(model)
# make sure the raw checkpoint saved the properties
raw_checkpoint_path = _raw_checkpoint_path(trainer)
raw_checkpoint = torch.load(raw_checkpoint_path)
raw_checkpoint[past_key] = raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
raw_checkpoint[past_key]['batch_size'] = -17
del raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
# save back the checkpoint
torch.save(raw_checkpoint, raw_checkpoint_path)
# verify that model loads correctly
model2 = EvalModelTemplate.load_from_checkpoint(raw_checkpoint_path)
assert model2.hparams.batch_size == -17
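
In user code the same back-compatibility means a pre-0.8 checkpoint needs no manual migration. A small hedged usage sketch (the path is a placeholder; `EvalModelTemplate` is the test model used throughout this file):

import torch

from tests.base import EvalModelTemplate  # model template from Lightning's test suite

ckpt_path = 'path/to/old_0.7.6.ckpt'  # placeholder: any checkpoint written by Lightning 0.7.6
ckpt = torch.load(ckpt_path, map_location='cpu')
# the legacy key is still there ...
assert 'module_arguments' in ckpt
# ... yet loading works without rewriting the file
model = EvalModelTemplate.load_from_checkpoint(ckpt_path)
print(model.hparams.batch_size)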
def test_hparams_pickle(tmpdir):
ad = AttributeDict({'key1': 1, 'key2': 'abc'})
pkl = pickle.dumps(ad)


@@ -48,7 +48,7 @@ def test_no_val_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
# assert ckpt has hparams
ckpt = torch.load(new_weights_path)
assert LightningModule.CHECKPOINT_KEY_HYPER_PARAMS in ckpt.keys(), 'module_arguments missing from checkpoints'
assert LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in ckpt.keys(), 'module_arguments missing from checkpoints'
# load new model
hparams_path = tutils.get_data_path(logger, path_dir=tmpdir)