Use PickleError base class to detect all pickle errors (#6917)
* Use PickleError base class to detect all pickle errors * Update changelog with #6917 * Add pickle test for torch ScriptModule Co-authored-by: Ken Witham <k.witham@kri.neu.edu> Co-authored-by: Akihiro Nitta <nitta@akihironitta.com>
This commit is contained in:
parent
03a73b37bc
commit
dcff5036a8
|
@ -243,7 +243,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
||||||
- Fixed `--gpus` default for parser returned by `Trainer.add_argparse_args` ([#6898](https://github.com/PyTorchLightning/pytorch-lightning/pull/6898))
|
- Fixed `--gpus` default for parser returned by `Trainer.add_argparse_args` ([#6898](https://github.com/PyTorchLightning/pytorch-lightning/pull/6898))
|
||||||
|
|
||||||
|
|
||||||
- Fixed `AttributeError for `require_backward_grad_sync` when running manual optimization with sharded plugin ([#6915](https://github.com/PyTorchLightning/pytorch-lightning/pull/6915))
|
- Fixed pickle error checker to now check for `pickle.PickleError` to catch all pickle errors ([#6917](https://github.com/PyTorchLightning/pytorch-lightning/pull/6917))
|
||||||
|
|
||||||
|
|
||||||
|
- Fixed `AttributeError` for `require_backward_grad_sync` when running manual optimization with sharded plugin ([#6915](https://github.com/PyTorchLightning/pytorch-lightning/pull/6915))
|
||||||
|
|
||||||
|
|
||||||
- Fixed multi-gpu join for Horovod ([#6954](https://github.com/PyTorchLightning/pytorch-lightning/pull/6954))
|
- Fixed multi-gpu join for Horovod ([#6954](https://github.com/PyTorchLightning/pytorch-lightning/pull/6954))
|
||||||
|
|
|
@ -61,7 +61,7 @@ def is_picklable(obj: object) -> bool:
|
||||||
try:
|
try:
|
||||||
pickle.dumps(obj)
|
pickle.dumps(obj)
|
||||||
return True
|
return True
|
||||||
except (pickle.PicklingError, AttributeError):
|
except (pickle.PickleError, AttributeError):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -15,6 +15,8 @@ import inspect
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from torch.jit import ScriptModule
|
||||||
|
|
||||||
from pytorch_lightning.utilities.parsing import (
|
from pytorch_lightning.utilities.parsing import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
clean_namespace,
|
clean_namespace,
|
||||||
|
@ -203,7 +205,7 @@ def test_is_picklable(tmpdir):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
true_cases = [None, True, 123, "str", (123, "str"), max]
|
true_cases = [None, True, 123, "str", (123, "str"), max]
|
||||||
false_cases = [unpicklable_function, UnpicklableClass]
|
false_cases = [unpicklable_function, UnpicklableClass, ScriptModule()]
|
||||||
|
|
||||||
for case in true_cases:
|
for case in true_cases:
|
||||||
assert is_picklable(case) is True
|
assert is_picklable(case) is True
|
||||||
|
|
Loading…
Reference in New Issue