warn user when dropping unpicklable hparams (#2874)
* refactored clean_namespace * Update try except to handle pickling error * Consolidated clean_namespace. Added is_picklable * PEP8 * Change warning to use rank_zero_warn. Added Test to ensure proper hparam filtering * Updated imports * Corrected Test Case
This commit is contained in:
parent
85cd558a3f
commit
d5254ff9df
|
@ -6,7 +6,7 @@ import torch
|
|||
|
||||
from pytorch_lightning.utilities.apply_func import move_data_to_device
|
||||
from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, rank_zero_info
|
||||
from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict
|
||||
from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable
|
||||
|
||||
try:
|
||||
from apex import amp
|
||||
|
|
|
@ -13,9 +13,12 @@
|
|||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import pickle
|
||||
from argparse import Namespace
|
||||
from typing import Dict
|
||||
|
||||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
|
||||
|
||||
def str_to_bool(val):
|
||||
"""Convert a string representation of truth to true (1) or false (0).
|
||||
|
@ -39,26 +42,28 @@ def str_to_bool(val):
|
|||
raise ValueError(f'invalid truth value {val}')
|
||||
|
||||
|
||||
def is_picklable(obj: object) -> bool:
|
||||
"""Tests if an object can be pickled"""
|
||||
|
||||
try:
|
||||
pickle.dumps(obj)
|
||||
return True
|
||||
except pickle.PicklingError:
|
||||
return False
|
||||
|
||||
|
||||
def clean_namespace(hparams):
|
||||
"""Removes all functions from hparams so we can pickle."""
|
||||
"""Removes all unpicklable entries from hparams"""
|
||||
|
||||
hparams_dict = hparams
|
||||
if isinstance(hparams, Namespace):
|
||||
del_attrs = []
|
||||
for k in hparams.__dict__:
|
||||
if callable(getattr(hparams, k)):
|
||||
del_attrs.append(k)
|
||||
hparams_dict = hparams.__dict__
|
||||
|
||||
for k in del_attrs:
|
||||
delattr(hparams, k)
|
||||
del_attrs = [k for k, v in hparams_dict.items() if not is_picklable(v)]
|
||||
|
||||
elif isinstance(hparams, dict):
|
||||
del_attrs = []
|
||||
for k, v in hparams.items():
|
||||
if callable(v):
|
||||
del_attrs.append(k)
|
||||
|
||||
for k in del_attrs:
|
||||
del hparams[k]
|
||||
for k in del_attrs:
|
||||
rank_zero_warn(f"attribute '{k}' removed from hparams because it cannot be pickled", UserWarning)
|
||||
del hparams_dict[k]
|
||||
|
||||
|
||||
def get_init_args(frame) -> dict:
|
||||
|
|
|
@ -11,7 +11,7 @@ from torch.utils.data import DataLoader
|
|||
|
||||
from pytorch_lightning import Trainer, LightningModule
|
||||
from pytorch_lightning.core.saving import save_hparams_to_yaml, load_hparams_from_yaml
|
||||
from pytorch_lightning.utilities import AttributeDict
|
||||
from pytorch_lightning.utilities import AttributeDict, is_picklable
|
||||
from tests.base import EvalModelTemplate, TrialMNIST
|
||||
|
||||
|
||||
|
@ -282,7 +282,7 @@ def test_collect_init_arguments(tmpdir, cls):
|
|||
assert model.hparams.batch_size == 179
|
||||
|
||||
if isinstance(model, AggSubClassEvalModel):
|
||||
assert isinstance(model.hparams.my_loss, torch.nn.CrossEntropyLoss)
|
||||
assert isinstance(model.hparams.my_loss, torch.nn.CosineEmbeddingLoss)
|
||||
|
||||
if isinstance(model, DictConfSubClassEvalModel):
|
||||
assert isinstance(model.hparams.dict_conf, Container)
|
||||
|
@ -413,6 +413,23 @@ def test_hparams_pickle(tmpdir):
|
|||
assert ad == pickle.loads(pkl)
|
||||
|
||||
|
||||
class UnpickleableArgsEvalModel(EvalModelTemplate):
|
||||
""" A model that has an attribute that cannot be pickled. """
|
||||
|
||||
def __init__(self, foo='bar', pickle_me=(lambda x: x + 1), **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
assert not is_picklable(pickle_me)
|
||||
self.save_hyperparameters()
|
||||
|
||||
|
||||
def test_hparams_pickle_warning(tmpdir):
|
||||
model = UnpickleableArgsEvalModel()
|
||||
trainer = Trainer(default_root_dir=tmpdir, max_steps=1)
|
||||
with pytest.warns(UserWarning, match="attribute 'pickle_me' removed from hparams because it cannot be pickled"):
|
||||
trainer.fit(model)
|
||||
assert 'pickle_me' not in model.hparams
|
||||
|
||||
|
||||
def test_hparams_save_yaml(tmpdir):
|
||||
hparams = dict(batch_size=32, learning_rate=0.001, data_root='./any/path/here',
|
||||
nasted=dict(any_num=123, anystr='abcd'))
|
||||
|
|
Loading…
Reference in New Issue