added override for hparams in load_from_ckpt (#1797)

* added override for hparams in load_from_ckpt

* override hparams

* override hparams

* Apply suggestions from code review

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* update doctest

* typo

* chlog

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
William Falcon · 2020-05-13 04:27:22 -04:00 · committed by GitHub
parent 10ce1c0256
commit 35fe2efe27
3 changed files with 47 additions and 2 deletions

CHANGELOG.md View File

@@ -22,6 +22,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added option to provide seed to random generators to ensure reproducibility ([#1572](https://github.com/PyTorchLightning/pytorch-lightning/pull/1572))
- Added override for hparams in `load_from_ckpt` ([#1797](https://github.com/PyTorchLightning/pytorch-lightning/pull/1797))
### Changed
- Reduction when `batch_size < num_gpus` ([#1609](https://github.com/PyTorchLightning/pytorch-lightning/pull/1609))

pytorch_lightning/core/lightning.py View File

@@ -16,8 +16,8 @@ from pytorch_lightning import _logger as log
from pytorch_lightning.core.grads import GradInformation
from pytorch_lightning.core.hooks import ModelHooks
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.core.saving import ModelIO, load_hparams_from_tags_csv, update_hparams
from pytorch_lightning.core.properties import DeviceDtypeModuleMixin
from pytorch_lightning.core.saving import ModelIO, load_hparams_from_tags_csv
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities import rank_zero_warn
@@ -1439,6 +1439,7 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
checkpoint_path: str,
map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
tags_csv: Optional[str] = None,
hparam_overrides: Optional[Dict] = None,
*args, **kwargs
) -> 'LightningModule':
r"""
@@ -1480,6 +1481,7 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
use this method to pass in a .csv file with the hparams you'd like to use.
These will be converted into a :class:`~argparse.Namespace` and passed into your
:class:`LightningModule` for use.
hparam_overrides: A dictionary with keys to override in the hparams
Return:
:class:`LightningModule` with loaded weights and hyperparameters (if available).
@@ -1503,6 +1505,12 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
tags_csv='/path/to/hparams_file.csv'
)
# override some of the params with new values
MyLightningModule.load_from_checkpoint(
PATH,
hparam_overrides={'num_layers': 128, 'pretrained_ckpt_path': NEW_PATH}
)
# or load passing whatever args the model takes to load
MyLightningModule.load_from_checkpoint(
'path/to/checkpoint.ckpt',
@@ -1521,12 +1529,16 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
else:
checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
# add the hparams from csv file to checkpoint
if tags_csv is not None:
# add the hparams from csv file to checkpoint
hparams = load_hparams_from_tags_csv(tags_csv)
hparams.__setattr__('on_gpu', False)
checkpoint['hparams'] = vars(hparams)
# override the hparam keys that were passed in
if hparam_overrides is not None:
update_hparams(hparams, hparam_overrides)
model = cls._load_model_state(checkpoint, *args, **kwargs)
return model
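A small usage sketch of the loading flow added in this hunk: hparams are first read from the csv file, then the override dictionary is merged on top of them before the weights are restored. `LitMNIST`, the file paths, and the hparam keys below are placeholders for illustration; `checkpoint_path`, `tags_csv`, and `hparam_overrides` are the actual arguments shown in the diff above.

```python
from my_project import LitMNIST  # hypothetical LightningModule subclass

# load weights, take hparams from the csv file, then override two keys
model = LitMNIST.load_from_checkpoint(
    checkpoint_path='checkpoints/last.ckpt',          # placeholder path
    tags_csv='checkpoints/meta_tags.csv',              # placeholder path
    hparam_overrides={'learning_rate': 1e-4, 'num_layers': 128},
)
model.eval()
```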

pytorch_lightning/core/saving.py View File

@@ -48,6 +48,37 @@ class ModelIO(object):
"""
def update_hparams(hparams: dict, updates: dict) -> None:
"""
Overrides hparams with new values
>>> hparams = {'c': 4}
>>> update_hparams(hparams, {'a': {'b': 2}, 'c': 1})
>>> hparams['a']['b'], hparams['c']
(2, 1)
>>> update_hparams(hparams, {'a': {'b': 4}, 'c': 7})
>>> hparams['a']['b'], hparams['c']
(4, 7)
Args:
hparams: the original params and also target object
updates: new params to be used as update
"""
for k, v in updates.items():
# if missing, add the key
if k not in hparams:
hparams[k] = v
continue
# recurse if dictionary
if isinstance(v, dict):
update_hparams(hparams[k], updates[k])
else:
# update the value
hparams.update({k: v})
def load_hparams_from_tags_csv(tags_csv: str) -> Namespace:
if not os.path.isfile(tags_csv):
log.warning(f'Missing Tags: {tags_csv}.')
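A quick sketch of the merge semantics of `update_hparams` beyond the doctest above: nested dictionaries are merged key by key, new keys are added, and the update happens in place. The `hparams` dictionary and its keys are made up for illustration; the import path matches the one used in `lightning.py` above.

```python
from pytorch_lightning.core.saving import update_hparams

hparams = {'optimizer': {'name': 'adam', 'lr': 1e-3}, 'num_layers': 16}
update_hparams(hparams, {'optimizer': {'lr': 5e-4}, 'dropout': 0.1})

# nested keys are merged rather than replaced, unknown keys are added,
# and the function mutates its first argument (it returns None)
assert hparams == {
    'optimizer': {'name': 'adam', 'lr': 5e-4},
    'num_layers': 16,
    'dropout': 0.1,
}
```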