warn user when dropping unpicklable hparams (#2874)

* refactored clean_namespace

* Update try except to handle pickling error

* Consolidated clean_namespace. Added is_picklable

* PEP8

* Change warning to use rank_zero_warn. Added Test to ensure proper hparam filtering

* Updated imports

* Corrected Test Case
This commit is contained in:
monney 2020-08-28 03:07:43 -04:00 committed by GitHub
parent 85cd558a3f
commit d5254ff9df
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 40 additions and 18 deletions

View File

@ -6,7 +6,7 @@ import torch
from pytorch_lightning.utilities.apply_func import move_data_to_device from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, rank_zero_info from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, rank_zero_info
from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable
try: try:
from apex import amp from apex import amp

View File

@ -13,9 +13,12 @@
# limitations under the License. # limitations under the License.
import inspect import inspect
import pickle
from argparse import Namespace from argparse import Namespace
from typing import Dict from typing import Dict
from pytorch_lightning.utilities import rank_zero_warn
def str_to_bool(val): def str_to_bool(val):
"""Convert a string representation of truth to true (1) or false (0). """Convert a string representation of truth to true (1) or false (0).
@ -39,26 +42,28 @@ def str_to_bool(val):
raise ValueError(f'invalid truth value {val}') raise ValueError(f'invalid truth value {val}')
def is_picklable(obj: object) -> bool:
"""Tests if an object can be pickled"""
try:
pickle.dumps(obj)
return True
except pickle.PicklingError:
return False
def clean_namespace(hparams): def clean_namespace(hparams):
"""Removes all functions from hparams so we can pickle.""" """Removes all unpicklable entries from hparams"""
hparams_dict = hparams
if isinstance(hparams, Namespace): if isinstance(hparams, Namespace):
del_attrs = [] hparams_dict = hparams.__dict__
for k in hparams.__dict__:
if callable(getattr(hparams, k)): del_attrs = [k for k, v in hparams_dict.items() if not is_picklable(v)]
del_attrs.append(k)
for k in del_attrs: for k in del_attrs:
delattr(hparams, k) rank_zero_warn(f"attribute '{k}' removed from hparams because it cannot be pickled", UserWarning)
del hparams_dict[k]
elif isinstance(hparams, dict):
del_attrs = []
for k, v in hparams.items():
if callable(v):
del_attrs.append(k)
for k in del_attrs:
del hparams[k]
def get_init_args(frame) -> dict: def get_init_args(frame) -> dict:

View File

@ -11,7 +11,7 @@ from torch.utils.data import DataLoader
from pytorch_lightning import Trainer, LightningModule from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.core.saving import save_hparams_to_yaml, load_hparams_from_yaml from pytorch_lightning.core.saving import save_hparams_to_yaml, load_hparams_from_yaml
from pytorch_lightning.utilities import AttributeDict from pytorch_lightning.utilities import AttributeDict, is_picklable
from tests.base import EvalModelTemplate, TrialMNIST from tests.base import EvalModelTemplate, TrialMNIST
@ -282,7 +282,7 @@ def test_collect_init_arguments(tmpdir, cls):
assert model.hparams.batch_size == 179 assert model.hparams.batch_size == 179
if isinstance(model, AggSubClassEvalModel): if isinstance(model, AggSubClassEvalModel):
assert isinstance(model.hparams.my_loss, torch.nn.CrossEntropyLoss) assert isinstance(model.hparams.my_loss, torch.nn.CosineEmbeddingLoss)
if isinstance(model, DictConfSubClassEvalModel): if isinstance(model, DictConfSubClassEvalModel):
assert isinstance(model.hparams.dict_conf, Container) assert isinstance(model.hparams.dict_conf, Container)
@ -413,6 +413,23 @@ def test_hparams_pickle(tmpdir):
assert ad == pickle.loads(pkl) assert ad == pickle.loads(pkl)
class UnpickleableArgsEvalModel(EvalModelTemplate):
""" A model that has an attribute that cannot be pickled. """
def __init__(self, foo='bar', pickle_me=(lambda x: x + 1), **kwargs):
super().__init__(**kwargs)
assert not is_picklable(pickle_me)
self.save_hyperparameters()
def test_hparams_pickle_warning(tmpdir):
model = UnpickleableArgsEvalModel()
trainer = Trainer(default_root_dir=tmpdir, max_steps=1)
with pytest.warns(UserWarning, match="attribute 'pickle_me' removed from hparams because it cannot be pickled"):
trainer.fit(model)
assert 'pickle_me' not in model.hparams
def test_hparams_save_yaml(tmpdir): def test_hparams_save_yaml(tmpdir):
hparams = dict(batch_size=32, learning_rate=0.001, data_root='./any/path/here', hparams = dict(batch_size=32, learning_rate=0.001, data_root='./any/path/here',
nasted=dict(any_num=123, anystr='abcd')) nasted=dict(any_num=123, anystr='abcd'))