warn user when dropping unpicklable hparams (#2874)
* refactored clean_namespace
* Update try except to handle pickling error
* Consolidated clean_namespace. Added is_picklable
* PEP8
* Change warning to use rank_zero_warn. Added Test to ensure proper hparam filtering
* Updated imports
* Corrected Test Case
commit d5254ff9df (parent 85cd558a3f)
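In short: clean_namespace() used to drop only callable entries from hparams; it now drops anything that fails to pickle and emits a UserWarning through rank_zero_warn for each removed key. A minimal sketch of the new behavior, assuming the is_picklable and clean_namespace helpers from the diff below (the hparams values here are made up):

    from argparse import Namespace

    from pytorch_lightning.utilities import is_picklable
    from pytorch_lightning.utilities.parsing import clean_namespace

    # A module-level lambda fails with pickle.PicklingError, which is
    # exactly what is_picklable() catches.
    hparams = Namespace(batch_size=32, transform=lambda x: x + 1)
    assert not is_picklable(hparams.transform)

    # Warns: attribute 'transform' removed from hparams because it cannot be pickled
    clean_namespace(hparams)
    assert not hasattr(hparams, 'transform')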
@@ -6,7 +6,7 @@ import torch
 from pytorch_lightning.utilities.apply_func import move_data_to_device
 from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, rank_zero_info
-from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict
+from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable

 try:
     from apex import amp
@@ -13,9 +13,12 @@
 # limitations under the License.

 import inspect
+import pickle
 from argparse import Namespace
 from typing import Dict

+from pytorch_lightning.utilities import rank_zero_warn
+

 def str_to_bool(val):
     """Convert a string representation of truth to true (1) or false (0).
@@ -39,26 +42,28 @@ def str_to_bool(val):
     raise ValueError(f'invalid truth value {val}')


+def is_picklable(obj: object) -> bool:
+    """Tests if an object can be pickled"""
+
+    try:
+        pickle.dumps(obj)
+        return True
+    except pickle.PicklingError:
+        return False
+
+
 def clean_namespace(hparams):
-    """Removes all functions from hparams so we can pickle."""
+    """Removes all unpicklable entries from hparams"""

+    hparams_dict = hparams
     if isinstance(hparams, Namespace):
-        del_attrs = []
-        for k in hparams.__dict__:
-            if callable(getattr(hparams, k)):
-                del_attrs.append(k)
-
-        for k in del_attrs:
-            delattr(hparams, k)
-
-    elif isinstance(hparams, dict):
-        del_attrs = []
-        for k, v in hparams.items():
-            if callable(v):
-                del_attrs.append(k)
-
-        for k in del_attrs:
-            del hparams[k]
+        hparams_dict = hparams.__dict__
+
+    del_attrs = [k for k, v in hparams_dict.items() if not is_picklable(v)]
+
+    for k in del_attrs:
+        rank_zero_warn(f"attribute '{k}' removed from hparams because it cannot be pickled", UserWarning)
+        del hparams_dict[k]


 def get_init_args(frame) -> dict:
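The refactor folds the old Namespace and dict branches into a single path: hparams_dict aliases either the dict itself or the Namespace's __dict__, so deleting a key from hparams_dict cleans the caller's object in place in both cases. A small sketch of the dict path (values are hypothetical):

    import pickle

    from pytorch_lightning.utilities.parsing import clean_namespace

    hparams = {'lr': 0.02, 'sched': lambda epoch: 0.95 ** epoch}
    clean_namespace(hparams)  # warns about 'sched' and deletes it in place
    assert hparams == {'lr': 0.02}
    pickle.dumps(hparams)     # the cleaned dict now pickles without error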
@@ -11,7 +11,7 @@ from torch.utils.data import DataLoader

 from pytorch_lightning import Trainer, LightningModule
 from pytorch_lightning.core.saving import save_hparams_to_yaml, load_hparams_from_yaml
-from pytorch_lightning.utilities import AttributeDict
+from pytorch_lightning.utilities import AttributeDict, is_picklable
 from tests.base import EvalModelTemplate, TrialMNIST

@@ -282,7 +282,7 @@ def test_collect_init_arguments(tmpdir, cls):
     assert model.hparams.batch_size == 179

     if isinstance(model, AggSubClassEvalModel):
-        assert isinstance(model.hparams.my_loss, torch.nn.CrossEntropyLoss)
+        assert isinstance(model.hparams.my_loss, torch.nn.CosineEmbeddingLoss)

     if isinstance(model, DictConfSubClassEvalModel):
         assert isinstance(model.hparams.dict_conf, Container)
@@ -413,6 +413,23 @@ def test_hparams_pickle(tmpdir):
     assert ad == pickle.loads(pkl)


+class UnpickleableArgsEvalModel(EvalModelTemplate):
+    """ A model that has an attribute that cannot be pickled. """
+
+    def __init__(self, foo='bar', pickle_me=(lambda x: x + 1), **kwargs):
+        super().__init__(**kwargs)
+        assert not is_picklable(pickle_me)
+        self.save_hyperparameters()
+
+
+def test_hparams_pickle_warning(tmpdir):
+    model = UnpickleableArgsEvalModel()
+    trainer = Trainer(default_root_dir=tmpdir, max_steps=1)
+    with pytest.warns(UserWarning, match="attribute 'pickle_me' removed from hparams because it cannot be pickled"):
+        trainer.fit(model)
+    assert 'pickle_me' not in model.hparams
+
+
 def test_hparams_save_yaml(tmpdir):
     hparams = dict(batch_size=32, learning_rate=0.001, data_root='./any/path/here',
                    nasted=dict(any_num=123, anystr='abcd'))
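The new test drives the whole path: save_hyperparameters() records the unpicklable pickle_me default, and fit() triggers both the warning and the removal. Outside of pytest, the same assertion could be written with the standard warnings machinery (a hypothetical equivalent of the pytest.warns check above; UnpickleableArgsEvalModel is the test class just shown, and the scratch directory stands in for pytest's tmpdir):

    import warnings

    from pytorch_lightning import Trainer

    model = UnpickleableArgsEvalModel()
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        # max_steps=1 keeps the run minimal, mirroring the test above
        Trainer(default_root_dir='/tmp/hparams_test', max_steps=1).fit(model)
    assert any('pickle_me' in str(w.message) for w in caught)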