Replace `_DataModuleWrapper` with `__new__` [1/2] (#7289)
* Remove `_DataModuleWrapper`
* Update pytorch_lightning/core/datamodule.py
* Update pytorch_lightning/core/datamodule.py
* Replace `__reduce__` with `__getstate__`
This commit is contained in:
parent 597b309f2e
commit 3fdb61ac1b
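For orientation, here is a minimal, self-contained sketch of the pattern this commit moves to (toy `Base`/`_track`/`setup` names, not the Lightning API): instead of a metaclass whose `__call__` patches hook methods on the class, `__new__` wraps the bound hooks on each new instance.

```python
import functools


class Base:
    """Toy stand-in for LightningDataModule; all names here are illustrative."""

    def __new__(cls, *args, **kwargs):
        obj = super().__new__(cls)
        # Wrap the *bound* hook on the instance, so subclass overrides are
        # tracked without any metaclass and without re-patching class attributes.
        obj.setup = Base._track(obj, obj.setup)
        return obj

    @staticmethod
    def _track(obj, fn):
        @functools.wraps(fn)
        def wrapped(*args, **kwargs):
            obj._has_setup = True  # record that the hook was called
            return fn(*args, **kwargs)
        return wrapped

    def setup(self):
        pass


class Child(Base):
    def setup(self):
        print("child setup")


c = Child()
c.setup()            # prints "child setup"
print(c._has_setup)  # True: the override was still tracked via the instance wrapper
```

Because the wrapper is stored as an instance attribute that shadows the class method, it cannot be pickled by reference, which is what the `__getstate__` added at the end of this diff handles.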
@@ -24,80 +24,7 @@ from pytorch_lightning.utilities import rank_zero_only
 from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types
 
 
-class _DataModuleWrapper(type):
-
-    def __init__(self, *args: Any, **kwargs: Any) -> None:
-        super().__init__(*args, **kwargs)
-        self.__has_added_checks = False
-
-    def __call__(cls, *args, **kwargs):
-        """A wrapper for LightningDataModule that:
-
-        1. Runs user defined subclass's __init__
-        2. Assures prepare_data() runs on rank 0
-        3. Lets you check prepare_data and setup to see if they've been called
-        """
-        if not cls.__has_added_checks:
-            cls.__has_added_checks = True
-            # Track prepare_data calls and make sure it runs on rank zero
-            cls.prepare_data = track_data_hook_calls(rank_zero_only(cls.prepare_data))
-            # Track setup calls
-            cls.setup = track_data_hook_calls(cls.setup)
-            # Track teardown calls
-            cls.teardown = track_data_hook_calls(cls.teardown)
-
-        # Get instance of LightningDataModule by mocking its __init__ via __call__
-        obj = type.__call__(cls, *args, **kwargs)
-
-        return obj
-
-
-def track_data_hook_calls(fn):
-    """A decorator that checks if prepare_data/setup/teardown has been called.
-
-    - When ``dm.prepare_data()`` is called, ``dm.has_prepared_data`` gets set to True
-    - When ``dm.setup()``, ``dm.has_setup_{fit,validate,test}`` get set to True
-    - When ``dm.setup(stage)`` is called, where stage is any of ``{fit,validate,test,predict}``.
-      Its corresponding `dm_has_setup_{stage}` attribute gets set to True
-    - ``dm.teardown()`` and ``dm.teardown(stage)`` act exactly like ``dm.setup``
-
-    Args:
-        fn (function): Function that will be tracked to see if it has been called.
-
-    Returns:
-        function: Decorated function that tracks its call status and saves it to private attrs in its obj instance.
-    """
-
-    @functools.wraps(fn)
-    def wrapped_fn(*args, **kwargs):
-
-        # The object instance from which setup or prepare_data was called
-        obj = args[0]
-        name = fn.__name__
-
-        # If calling setup, we check the stage and assign stage-specific bool args
-        if name in ("setup", "teardown"):
-
-            # Get stage either by grabbing from args or checking kwargs.
-            # If not provided, set call status of 'fit', 'validate', and 'test' to True.
-            # We do this so __attach_datamodule in trainer.py doesn't mistakenly call setup('test') on trainer.test()
-            stage = args[1] if len(args) > 1 else kwargs.get("stage", None)
-
-            if stage is None:
-                for s in ("fit", "validate", "test"):
-                    setattr(obj, f"_has_{name}_{s}", True)
-            else:
-                setattr(obj, f"_has_{name}_{stage}", True)
-
-        elif name == "prepare_data":
-            obj._has_prepared_data = True
-
-        return fn(*args, **kwargs)
-
-    return wrapped_fn
-
-
-class LightningDataModule(CheckpointHooks, DataHooks, metaclass=_DataModuleWrapper):
+class LightningDataModule(CheckpointHooks, DataHooks):
     """
     A DataModule standardizes the training, val, test splits, data preparation and transforms.
     The main advantage is consistent data splits, data preparation and transforms across models.
@@ -398,3 +325,62 @@ class LightningDataModule(CheckpointHooks, DataHooks, metaclass=_DataModuleWrapper):
         if test_dataset is not None:
             datamodule.test_dataloader = test_dataloader
         return datamodule
+
+    def __new__(cls, *args: Any, **kwargs: Any) -> 'LightningDataModule':
+        obj = super().__new__(cls)
+        # track `DataHooks` calls and run `prepare_data` only on rank zero
+        obj.prepare_data = cls._track_data_hook_calls(obj, rank_zero_only(obj.prepare_data))
+        obj.setup = cls._track_data_hook_calls(obj, obj.setup)
+        obj.teardown = cls._track_data_hook_calls(obj, obj.teardown)
+        return obj
+
+    @staticmethod
+    def _track_data_hook_calls(obj: 'LightningDataModule', fn: callable) -> callable:
+        """A decorator that checks if prepare_data/setup/teardown has been called.
+
+        - When ``dm.prepare_data()`` is called, ``dm.has_prepared_data`` gets set to True
+        - When ``dm.setup()``, ``dm.has_setup_{fit,validate,test}`` get set to True
+        - When ``dm.setup(stage)`` is called, where stage is any of ``{fit,validate,test,predict}``.
+          Its corresponding `dm_has_setup_{stage}` attribute gets set to True
+        - ``dm.teardown()`` and ``dm.teardown(stage)`` act exactly like ``dm.setup``
+
+        Args:
+            obj: Object whose function will be tracked
+            fn: Function that will be tracked to see if it has been called.
+
+        Returns:
+            Decorated function that tracks its call status and saves it to private attrs in its obj instance.
+        """
+
+        @functools.wraps(fn)
+        def wrapped_fn(*args: str, **kwargs: Optional[str]) -> Any:
+            name = fn.__name__
+
+            # If calling setup, we check the stage and assign stage-specific bool args
+            if name in ("setup", "teardown"):
+
+                # Get stage either by grabbing from args or checking kwargs.
+                # If not provided, set call status of 'fit', 'validate', and 'test' to True.
+                # We do this so __attach_datamodule in trainer.py doesn't mistakenly call
+                # setup('test') on trainer.test()
+                stage = args[0] if len(args) else kwargs.get("stage", None)
+
+                if stage is None:
+                    for s in ("fit", "validate", "test"):
+                        setattr(obj, f"_has_{name}_{s}", True)
+                else:
+                    setattr(obj, f"_has_{name}_{stage}", True)
+
+            elif name == "prepare_data":
+                obj._has_prepared_data = True
+
+            return fn(*args, **kwargs)
+
+        return wrapped_fn
+
+    def __getstate__(self) -> dict:
+        # avoids _pickle.PicklingError: Can't pickle <...>: it's not the same object as <...>
+        d = self.__dict__.copy()
+        for fn in ("prepare_data", "setup", "teardown"):
+            del d[fn]
+        return d
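Why the pickling hook is needed at all: the wrappers created in `__new__` live in the instance `__dict__` and shadow the class methods, and pickle serializes plain functions by qualified name, so it refuses them with the "not the same object as ..." error quoted in the comment above. Dropping them in `__getstate__` is enough, because the class-level hooks remain reachable after unpickling. A hedged toy reproduction (not the Lightning code):

```python
import functools
import pickle


class Tracked:
    def __new__(cls, *args, **kwargs):
        obj = super().__new__(cls)
        original = obj.setup  # bound class method, captured before shadowing

        @functools.wraps(original)
        def wrapped(*a, **kw):
            obj._has_setup = True
            return original(*a, **kw)

        # The instance attribute shadows the class method. pickle serializes
        # functions by qualified name and finds Tracked.setup instead of this
        # closure, hence the PicklingError without __getstate__.
        obj.setup = wrapped
        return obj

    def setup(self):
        pass

    def __getstate__(self):
        # Drop the unpicklable wrapper; the class-level setup is still available
        # on the unpickled object.
        d = self.__dict__.copy()
        d.pop("setup", None)
        return d


t = Tracked()
t.setup()
restored = pickle.loads(pickle.dumps(t))
print(restored._has_setup)  # True: the tracked state survives the round trip
```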
@@ -91,7 +91,7 @@ def test_can_prepare_data(local_rank, node_rank):
     assert trainer.data_connector.can_prepare_data()
 
 
-def test_hooks_no_recursion_error(tmpdir):
+def test_hooks_no_recursion_error():
     # hooks were appended in cascade every time a new data module was instantiated leading to a recursion error.
     # See https://github.com/PyTorchLightning/pytorch-lightning/issues/3652
     class DummyDM(LightningDataModule):
@@ -108,20 +108,20 @@ def test_hooks_no_recursion_error(tmpdir):
     dm.prepare_data()
 
 
-def test_helper_boringdatamodule(tmpdir):
+def test_helper_boringdatamodule():
     dm = BoringDataModule()
     dm.prepare_data()
     dm.setup()
 
 
-def test_helper_boringdatamodule_with_verbose_setup(tmpdir):
+def test_helper_boringdatamodule_with_verbose_setup():
    dm = BoringDataModule()
    dm.prepare_data()
    dm.setup('fit')
    dm.setup('test')
 
 
-def test_data_hooks_called(tmpdir):
+def test_data_hooks_called():
    dm = BoringDataModule()
    assert not dm.has_prepared_data
    assert not dm.has_setup_fit
@@ -168,7 +168,7 @@ def test_data_hooks_called(tmpdir):
 
 
 @pytest.mark.parametrize("use_kwarg", (False, True))
-def test_data_hooks_called_verbose(tmpdir, use_kwarg):
+def test_data_hooks_called_verbose(use_kwarg):
    dm = BoringDataModule()
    dm.prepare_data()
    assert not dm.has_setup_fit
@@ -246,7 +246,7 @@ def test_dm_init_from_argparse_args(tmpdir):
    assert dm.data_dir == args.data_dir == str(tmpdir)
 
 
-def test_dm_pickle_after_init(tmpdir):
+def test_dm_pickle_after_init():
    dm = BoringDataModule()
    pickle.dumps(dm)
 