From 35fe2efe270d21059727fefa5df149d99e4ce33c Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Wed, 13 May 2020 04:27:22 -0400
Subject: [PATCH] added override for hparams in load_from_ckpt (#1797)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* added override for hparams in load_from_ckpt

* override hparams

* override hparams

* Apply suggestions from code review

Co-authored-by: Adrian Wälchli

* update doctest

* typo

* chlog

Co-authored-by: Jirka Borovec
Co-authored-by: Adrian Wälchli
Co-authored-by: Jirka
---
 CHANGELOG.md                        |  2 ++
 pytorch_lightning/core/lightning.py | 16 +++++++++++++--
 pytorch_lightning/core/saving.py    | 31 +++++++++++++++++++++++++++++
 3 files changed, 47 insertions(+), 2 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index ce4c8dbae5..f2fcf48e25 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -22,6 +22,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added option to provide seed to random generators to ensure reproducibility ([#1572](https://github.com/PyTorchLightning/pytorch-lightning/pull/1572))
 
+- Added override for hparams in `load_from_ckpt` ([#1797](https://github.com/PyTorchLightning/pytorch-lightning/pull/1797))
+
 ### Changed
 
 - Reduction when `batch_size < num_gpus` ([#1609](https://github.com/PyTorchLightning/pytorch-lightning/pull/1609))
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index fac41bda15..d9c2e18620 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -16,8 +16,8 @@ from pytorch_lightning import _logger as log
 from pytorch_lightning.core.grads import GradInformation
 from pytorch_lightning.core.hooks import ModelHooks
 from pytorch_lightning.core.memory import ModelSummary
+from pytorch_lightning.core.saving import ModelIO, load_hparams_from_tags_csv, update_hparams
 from pytorch_lightning.core.properties import DeviceDtypeModuleMixin
-from pytorch_lightning.core.saving import ModelIO, load_hparams_from_tags_csv
 from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities import rank_zero_warn
@@ -1439,6 +1439,7 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
             checkpoint_path: str,
             map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
             tags_csv: Optional[str] = None,
+            hparam_overrides: Optional[Dict] = None,
             *args, **kwargs
     ) -> 'LightningModule':
         r"""
@@ -1480,6 +1481,7 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
                 use this method to pass in a .csv file with the hparams you'd like to use.
                 These will be converted into a :class:`~argparse.Namespace`
                 and passed into your :class:`LightningModule` for use.
+            hparam_overrides: A dictionary with keys to override in the hparams
 
         Return:
             :class:`LightningModule` with loaded weights and hyperparameters (if available).
@@ -1503,6 +1505,12 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
                     tags_csv='/path/to/hparams_file.csv'
                 )
 
+                # override some of the params with new values
+                MyLightningModule.load_from_checkpoint(
+                    PATH,
+                    hparam_overrides={'num_layers': 128, 'pretrained_ckpt_path': NEW_PATH}
+                )
+
                 # or load passing whatever args the model takes to load
                 MyLightningModule.load_from_checkpoint(
                     'path/to/checkpoint.ckpt',
@@ -1521,12 +1529,16 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
         else:
             checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
 
+        # add the hparams from csv file to checkpoint
         if tags_csv is not None:
-            # add the hparams from csv file to checkpoint
             hparams = load_hparams_from_tags_csv(tags_csv)
             hparams.__setattr__('on_gpu', False)
             checkpoint['hparams'] = vars(hparams)
 
+        # override the hparam keys that were passed in
+        if hparam_overrides is not None:
+            update_hparams(hparams, hparam_overrides)
+
         model = cls._load_model_state(checkpoint, *args, **kwargs)
         return model
 
diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py
index d5e8d2a600..7a28a40d4e 100644
--- a/pytorch_lightning/core/saving.py
+++ b/pytorch_lightning/core/saving.py
@@ -48,6 +48,37 @@ class ModelIO(object):
         """
 
 
+def update_hparams(hparams: dict, updates: dict) -> None:
+    """
+    Overrides hparams with new values
+
+    >>> hparams = {'c': 4}
+    >>> update_hparams(hparams, {'a': {'b': 2}, 'c': 1})
+    >>> hparams['a']['b'], hparams['c']
+    (2, 1)
+    >>> update_hparams(hparams, {'a': {'b': 4}, 'c': 7})
+    >>> hparams['a']['b'], hparams['c']
+    (4, 7)
+
+    Args:
+        hparams: the original params and also target object
+        updates: new params to be used as update
+
+    """
+    for k, v in updates.items():
+        # if missing, add the key
+        if k not in hparams:
+            hparams[k] = v
+            continue
+
+        # recurse if dictionary
+        if isinstance(v, dict):
+            update_hparams(hparams[k], updates[k])
+        else:
+            # update the value
+            hparams.update({k: v})
+
+
 def load_hparams_from_tags_csv(tags_csv: str) -> Namespace:
     if not os.path.isfile(tags_csv):
         log.warning(f'Missing Tags: {tags_csv}.')
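
Note on the new helper: the snippet below is a minimal, self-contained sketch of the
recursive merge that update_hparams performs, with the helper body copied from the
pytorch_lightning/core/saving.py hunk above. The example hparams dictionary and the
override keys ('num_layers', 'optimizer', 'lr', 'momentum') are purely illustrative
and are not taken from the library.

    def update_hparams(hparams: dict, updates: dict) -> None:
        """Recursively merge `updates` into `hparams` in place."""
        for k, v in updates.items():
            # if missing, add the key
            if k not in hparams:
                hparams[k] = v
                continue
            # recurse if dictionary
            if isinstance(v, dict):
                update_hparams(hparams[k], updates[k])
            else:
                # update the value
                hparams.update({k: v})

    # hypothetical hparams restored from a checkpoint
    hparams = {'num_layers': 64, 'optimizer': {'lr': 0.001, 'momentum': 0.9}}

    # only the listed keys change; sibling keys such as 'momentum' survive,
    # unlike a plain dict.update(), which would replace the nested dict wholesale
    update_hparams(hparams, {'num_layers': 128, 'optimizer': {'lr': 0.01}})
    assert hparams == {'num_layers': 128, 'optimizer': {'lr': 0.01, 'momentum': 0.9}}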