added override for hparams in load_from_ckpt (#1797)

* added override for hparams in load_from_ckpt
* override hparams
* Apply suggestions from code review
* update doctest
* typo
* chlog

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>

parent 10ce1c0256
commit 35fe2efe27
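In short, `LightningModule.load_from_checkpoint` gains an `hparam_overrides` argument: a dict whose keys replace matching entries in the loaded hyperparameters before the model is instantiated. A minimal sketch of the call, mirroring the docstring example added below (`MyLightningModule` and the hparam key are illustrative placeholders, not part of this commit):

```python
# Sketch only: MyLightningModule stands in for a user-defined
# LightningModule subclass, and 'num_layers' is a made-up hparam key.
# hparam_overrides replaces matching keys in the loaded hparams
# before the model is built.
model = MyLightningModule.load_from_checkpoint(
    'path/to/checkpoint.ckpt',
    hparam_overrides={'num_layers': 128},
)
```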
CHANGELOG.md

@@ -22,6 +22,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added option to provide seed to random generators to ensure reproducibility ([#1572](https://github.com/PyTorchLightning/pytorch-lightning/pull/1572))
 
+- Added override for hparams in `load_from_ckpt` ([#1797](https://github.com/PyTorchLightning/pytorch-lightning/pull/1797))
+
 ### Changed
 
 - Reduction when `batch_size < num_gpus` ([#1609](https://github.com/PyTorchLightning/pytorch-lightning/pull/1609))
pytorch_lightning/core/lightning.py

@@ -16,8 +16,8 @@ from pytorch_lightning import _logger as log
 from pytorch_lightning.core.grads import GradInformation
 from pytorch_lightning.core.hooks import ModelHooks
 from pytorch_lightning.core.memory import ModelSummary
+from pytorch_lightning.core.saving import ModelIO, load_hparams_from_tags_csv, update_hparams
 from pytorch_lightning.core.properties import DeviceDtypeModuleMixin
-from pytorch_lightning.core.saving import ModelIO, load_hparams_from_tags_csv
 from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities import rank_zero_warn
@@ -1439,6 +1439,7 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
             checkpoint_path: str,
             map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
             tags_csv: Optional[str] = None,
+            hparam_overrides: Optional[Dict] = None,
             *args, **kwargs
     ) -> 'LightningModule':
         r"""
@@ -1480,6 +1481,7 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
                 use this method to pass in a .csv file with the hparams you'd like to use.
                 These will be converted into a :class:`~argparse.Namespace` and passed into your
                 :class:`LightningModule` for use.
+            hparam_overrides: A dictionary with keys to override in the hparams
 
         Return:
             :class:`LightningModule` with loaded weights and hyperparameters (if available).
@@ -1503,6 +1505,12 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
                 tags_csv='/path/to/hparams_file.csv'
             )
 
+            # override some of the params with new values
+            MyLightningModule.load_from_checkpoint(
+                PATH,
+                hparam_overrides={'num_layers': 128, 'pretrained_ckpt_path': NEW_PATH}
+            )
+
             # or load passing whatever args the model takes to load
             MyLightningModule.load_from_checkpoint(
                 'path/to/checkpoint.ckpt',
@@ -1521,12 +1529,16 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
         else:
             checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
 
-        # add the hparams from csv file to checkpoint
         if tags_csv is not None:
+            # add the hparams from csv file to checkpoint
             hparams = load_hparams_from_tags_csv(tags_csv)
             hparams.__setattr__('on_gpu', False)
             checkpoint['hparams'] = vars(hparams)
 
+        # override the hparam keys that were passed in
+        if hparam_overrides is not None:
+            update_hparams(hparams, hparam_overrides)
+
         model = cls._load_model_state(checkpoint, *args, **kwargs)
         return model
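The override step delegates to the new `update_hparams` helper, added in the `pytorch_lightning/core/saving.py` hunk below, which merges the override dict into the existing hparams recursively. A standalone sketch of that merge on a plain dict (the keys and values are made up for illustration):

```python
from pytorch_lightning.core.saving import update_hparams

# Made-up hparams: keys present in both dicts are overwritten, nested
# dicts are merged recursively, and untouched keys keep their values.
hparams = {'num_layers': 64, 'optimizer': {'lr': 0.01, 'momentum': 0.9}}
update_hparams(hparams, {'num_layers': 128, 'optimizer': {'lr': 0.001}})
# hparams is now:
# {'num_layers': 128, 'optimizer': {'lr': 0.001, 'momentum': 0.9}}
```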
pytorch_lightning/core/saving.py

@@ -48,6 +48,37 @@ class ModelIO(object):
         """
 
 
+def update_hparams(hparams: dict, updates: dict) -> None:
+    """
+    Overrides hparams with new values
+
+    >>> hparams = {'c': 4}
+    >>> update_hparams(hparams, {'a': {'b': 2}, 'c': 1})
+    >>> hparams['a']['b'], hparams['c']
+    (2, 1)
+    >>> update_hparams(hparams, {'a': {'b': 4}, 'c': 7})
+    >>> hparams['a']['b'], hparams['c']
+    (4, 7)
+
+    Args:
+        hparams: the original params and also target object
+        updates: new params to be used as update
+
+    """
+    for k, v in updates.items():
+        # if missing, add the key
+        if k not in hparams:
+            hparams[k] = v
+            continue
+
+        # recurse if dictionary
+        if isinstance(v, dict):
+            update_hparams(hparams[k], updates[k])
+        else:
+            # update the value
+            hparams.update({k: v})
+
+
 def load_hparams_from_tags_csv(tags_csv: str) -> Namespace:
     if not os.path.isfile(tags_csv):
         log.warning(f'Missing Tags: {tags_csv}.')
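Because keys missing from `hparams` are inserted wholesale before any recursion happens, the helper can also extend an hparams dict rather than only overwrite it. A small sketch with illustrative keys:

```python
from pytorch_lightning.core.saving import update_hparams

# Illustrative keys: 'scheduler' is absent from hparams, so its whole
# nested dict is added as-is; 'optimizer' exists, so it is merged.
hparams = {'optimizer': {'lr': 0.01}}
update_hparams(hparams, {'scheduler': {'step_size': 10}, 'optimizer': {'lr': 0.001}})
# hparams == {'optimizer': {'lr': 0.001}, 'scheduler': {'step_size': 10}}
```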