Replace `_DataModuleWrapper` with `__new__` [1/2] (#7289)

* Remove `_DataModuleWrapper`

* Update pytorch_lightning/core/datamodule.py

* Update pytorch_lightning/core/datamodule.py

* Replace `__reduce__` with `__getstate__`
Carlos Mocholí 2021-05-04 10:00:24 +02:00 committed by GitHub
parent 597b309f2e
commit 3fdb61ac1b
2 changed files with 66 additions and 80 deletions
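
For context, here is a minimal self-contained sketch of the two patterns, using hypothetical names (`_Wrapper`, `LegacyDM`, `NewStyleDM`) that are not part of pytorch_lightning: the metaclass has to patch the class and guard against re-wrapping the hooks, while `__new__` simply wraps the bound methods of each new instance.

# Illustrative sketch only; these names are hypothetical, not library code.
import functools


# Before: a metaclass patches the *class* on first instantiation, guarded by
# a flag so the hooks are not wrapped again on every instantiation (the
# cascade that caused the recursion error in #3652).
class _Wrapper(type):

    def __init__(cls, *args, **kwargs):
        super().__init__(*args, **kwargs)
        cls._wrapped = False

    def __call__(cls, *args, **kwargs):
        if not cls._wrapped:
            cls._wrapped = True
            fn = cls.prepare_data

            @functools.wraps(fn)
            def tracked(self, *a, **kw):
                self._has_prepared_data = True
                return fn(self, *a, **kw)

            cls.prepare_data = tracked
        return super().__call__(*args, **kwargs)


class LegacyDM(metaclass=_Wrapper):

    def prepare_data(self):
        pass


# After: __new__ wraps the *bound* method on each instance, so there is no
# metaclass and no class-level bookkeeping to keep consistent.
class NewStyleDM:

    def __new__(cls, *args, **kwargs):
        obj = super().__new__(cls)
        fn = obj.prepare_data

        @functools.wraps(fn)
        def tracked(*a, **kw):
            obj._has_prepared_data = True
            return fn(*a, **kw)

        obj.prepare_data = tracked
        return obj

    def prepare_data(self):
        pass


# both variants record the call on the instance
for dm in (LegacyDM(), NewStyleDM()):
    dm.prepare_data()
    assert dm._has_prepared_data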

pytorch_lightning/core/datamodule.py

@@ -24,80 +24,7 @@ from pytorch_lightning.utilities import rank_zero_only
 from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types
 
 
-class _DataModuleWrapper(type):
-
-    def __init__(self, *args: Any, **kwargs: Any) -> None:
-        super().__init__(*args, **kwargs)
-        self.__has_added_checks = False
-
-    def __call__(cls, *args, **kwargs):
-        """A wrapper for LightningDataModule that:
-
-        1. Runs user defined subclass's __init__
-        2. Assures prepare_data() runs on rank 0
-        3. Lets you check prepare_data and setup to see if they've been called
-        """
-        if not cls.__has_added_checks:
-            cls.__has_added_checks = True
-            # Track prepare_data calls and make sure it runs on rank zero
-            cls.prepare_data = track_data_hook_calls(rank_zero_only(cls.prepare_data))
-            # Track setup calls
-            cls.setup = track_data_hook_calls(cls.setup)
-            # Track teardown calls
-            cls.teardown = track_data_hook_calls(cls.teardown)
-
-        # Get instance of LightningDataModule by mocking its __init__ via __call__
-        obj = type.__call__(cls, *args, **kwargs)
-
-        return obj
-
-
-def track_data_hook_calls(fn):
-    """A decorator that checks if prepare_data/setup/teardown has been called.
-
-    - When ``dm.prepare_data()`` is called, ``dm.has_prepared_data`` gets set to True
-    - When ``dm.setup()``, ``dm.has_setup_{fit,validate,test}`` get set to True
-    - When ``dm.setup(stage)`` is called, where stage is any of ``{fit,validate,test,predict}``.
-      Its corresponding `dm_has_setup_{stage}` attribute gets set to True
-    - ``dm.teardown()`` and ``dm.teardown(stage)`` act exactly like ``dm.setup``
-
-    Args:
-        fn (function): Function that will be tracked to see if it has been called.
-
-    Returns:
-        function: Decorated function that tracks its call status and saves it to private attrs in its obj instance.
-    """
-
-    @functools.wraps(fn)
-    def wrapped_fn(*args, **kwargs):
-
-        # The object instance from which setup or prepare_data was called
-        obj = args[0]
-        name = fn.__name__
-
-        # If calling setup, we check the stage and assign stage-specific bool args
-        if name in ("setup", "teardown"):
-
-            # Get stage either by grabbing from args or checking kwargs.
-            # If not provided, set call status of 'fit', 'validate', and 'test' to True.
-            # We do this so __attach_datamodule in trainer.py doesn't mistakenly call setup('test') on trainer.test()
-            stage = args[1] if len(args) > 1 else kwargs.get("stage", None)
-
-            if stage is None:
-                for s in ("fit", "validate", "test"):
-                    setattr(obj, f"_has_{name}_{s}", True)
-            else:
-                setattr(obj, f"_has_{name}_{stage}", True)
-
-        elif name == "prepare_data":
-            obj._has_prepared_data = True
-
-        return fn(*args, **kwargs)
-
-    return wrapped_fn
-
-
-class LightningDataModule(CheckpointHooks, DataHooks, metaclass=_DataModuleWrapper):
+class LightningDataModule(CheckpointHooks, DataHooks):
     """
     A DataModule standardizes the training, val, test splits, data preparation and transforms.
     The main advantage is consistent data splits, data preparation and transforms across models.
@@ -398,3 +325,62 @@ class LightningDataModule(CheckpointHooks, DataHooks, metaclass=_DataModuleWrapp
         if test_dataset is not None:
             datamodule.test_dataloader = test_dataloader
         return datamodule
+
+    def __new__(cls, *args: Any, **kwargs: Any) -> 'LightningDataModule':
+        obj = super().__new__(cls)
+        # track `DataHooks` calls and run `prepare_data` only on rank zero
+        obj.prepare_data = cls._track_data_hook_calls(obj, rank_zero_only(obj.prepare_data))
+        obj.setup = cls._track_data_hook_calls(obj, obj.setup)
+        obj.teardown = cls._track_data_hook_calls(obj, obj.teardown)
+        return obj
+
+    @staticmethod
+    def _track_data_hook_calls(obj: 'LightningDataModule', fn: callable) -> callable:
+        """A decorator that checks if prepare_data/setup/teardown has been called.
+
+        - When ``dm.prepare_data()`` is called, ``dm.has_prepared_data`` gets set to True
+        - When ``dm.setup()``, ``dm.has_setup_{fit,validate,test}`` get set to True
+        - When ``dm.setup(stage)`` is called, where stage is any of ``{fit,validate,test,predict}``.
+          Its corresponding `dm_has_setup_{stage}` attribute gets set to True
+        - ``dm.teardown()`` and ``dm.teardown(stage)`` act exactly like ``dm.setup``
+
+        Args:
+            obj: Object whose function will be tracked
+            fn: Function that will be tracked to see if it has been called.
+
+        Returns:
+            Decorated function that tracks its call status and saves it to private attrs in its obj instance.
+        """
+
+        @functools.wraps(fn)
+        def wrapped_fn(*args: str, **kwargs: Optional[str]) -> Any:
+            name = fn.__name__
+
+            # If calling setup, we check the stage and assign stage-specific bool args
+            if name in ("setup", "teardown"):
+                # Get stage either by grabbing from args or checking kwargs.
+                # If not provided, set call status of 'fit', 'validate', and 'test' to True.
+                # We do this so __attach_datamodule in trainer.py doesn't mistakenly call
+                # setup('test') on trainer.test()
+                stage = args[0] if len(args) else kwargs.get("stage", None)
+
+                if stage is None:
+                    for s in ("fit", "validate", "test"):
+                        setattr(obj, f"_has_{name}_{s}", True)
+                else:
+                    setattr(obj, f"_has_{name}_{stage}", True)
+
+            elif name == "prepare_data":
+                obj._has_prepared_data = True
+
+            return fn(*args, **kwargs)
+
+        return wrapped_fn
+
+    def __getstate__(self) -> dict:
+        # avoids _pickle.PicklingError: Can't pickle <...>: it's not the same object as <...>
+        d = self.__dict__.copy()
+        for fn in ("prepare_data", "setup", "teardown"):
+            del d[fn]
+        return d
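
The `__getstate__` above is the companion piece: because `__new__` stores the wrapped *bound* methods in the instance `__dict__`, default pickling would find attributes that are "not the same object as" the class's functions. Dropping them from the state is safe because unpickling calls `__new__` again, which re-wraps the hooks on the fresh instance. A rough sketch of the resulting behaviour, with a hypothetical trivial subclass `MyDM`:

import pickle

from pytorch_lightning import LightningDataModule


class MyDM(LightningDataModule):  # hypothetical trivial subclass
    pass


dm = MyDM()
assert "prepare_data" in dm.__dict__  # the wrapped hooks live on the instance
clone = pickle.loads(pickle.dumps(dm))  # round-trips thanks to __getstate__
assert callable(clone.prepare_data)  # hooks re-wrapped by __new__ on load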

tests/core/test_datamodule.py

@@ -91,7 +91,7 @@ def test_can_prepare_data(local_rank, node_rank):
     assert trainer.data_connector.can_prepare_data()
 
 
-def test_hooks_no_recursion_error(tmpdir):
+def test_hooks_no_recursion_error():
     # hooks were appended in cascade every time a new data module was instantiated leading to a recursion error.
     # See https://github.com/PyTorchLightning/pytorch-lightning/issues/3652
     class DummyDM(LightningDataModule):
@@ -108,20 +108,20 @@ def test_hooks_no_recursion_error(tmpdir):
     dm.prepare_data()
 
 
-def test_helper_boringdatamodule(tmpdir):
+def test_helper_boringdatamodule():
     dm = BoringDataModule()
     dm.prepare_data()
     dm.setup()
 
 
-def test_helper_boringdatamodule_with_verbose_setup(tmpdir):
+def test_helper_boringdatamodule_with_verbose_setup():
     dm = BoringDataModule()
     dm.prepare_data()
     dm.setup('fit')
     dm.setup('test')
 
 
-def test_data_hooks_called(tmpdir):
+def test_data_hooks_called():
     dm = BoringDataModule()
     assert not dm.has_prepared_data
     assert not dm.has_setup_fit
@@ -168,7 +168,7 @@ def test_data_hooks_called(tmpdir):
 
 
 @pytest.mark.parametrize("use_kwarg", (False, True))
-def test_data_hooks_called_verbose(tmpdir, use_kwarg):
+def test_data_hooks_called_verbose(use_kwarg):
     dm = BoringDataModule()
     dm.prepare_data()
     assert not dm.has_setup_fit
@@ -246,7 +246,7 @@ def test_dm_init_from_argparse_args(tmpdir):
     assert dm.data_dir == args.data_dir == str(tmpdir)
 
 
-def test_dm_pickle_after_init(tmpdir):
+def test_dm_pickle_after_init():
     dm = BoringDataModule()
     pickle.dumps(dm)
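
For reference, the flag behaviour these tests pin down (`BoringDataModule` is the helper datamodule the tests import; any `LightningDataModule` subclass behaves the same way):

dm = BoringDataModule()
assert not dm.has_prepared_data
assert not dm.has_setup_fit

dm.prepare_data()  # wrapped with rank_zero_only, so it runs on rank zero only
assert dm.has_prepared_data

dm.setup('fit')  # explicit stage: only the 'fit' flag flips
assert dm.has_setup_fit
assert not dm.has_setup_test

dm.setup()  # no stage: 'fit', 'validate' and 'test' are all marked
assert dm.has_setup_test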