diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py
index 9550ceae4a..4eaed6c7b2 100644
--- a/pytorch_lightning/core/datamodule.py
+++ b/pytorch_lightning/core/datamodule.py
@@ -24,80 +24,7 @@
 from pytorch_lightning.utilities import rank_zero_only
 from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types
 
-class _DataModuleWrapper(type):
-
-    def __init__(self, *args: Any, **kwargs: Any) -> None:
-        super().__init__(*args, **kwargs)
-        self.__has_added_checks = False
-
-    def __call__(cls, *args, **kwargs):
-        """A wrapper for LightningDataModule that:
-
-        1. Runs user defined subclass's __init__
-        2. Assures prepare_data() runs on rank 0
-        3. Lets you check prepare_data and setup to see if they've been called
-        """
-        if not cls.__has_added_checks:
-            cls.__has_added_checks = True
-            # Track prepare_data calls and make sure it runs on rank zero
-            cls.prepare_data = track_data_hook_calls(rank_zero_only(cls.prepare_data))
-            # Track setup calls
-            cls.setup = track_data_hook_calls(cls.setup)
-            # Track teardown calls
-            cls.teardown = track_data_hook_calls(cls.teardown)
-
-        # Get instance of LightningDataModule by mocking its __init__ via __call__
-        obj = type.__call__(cls, *args, **kwargs)
-
-        return obj
-
-
-def track_data_hook_calls(fn):
-    """A decorator that checks if prepare_data/setup/teardown has been called.
-
-    - When ``dm.prepare_data()`` is called, ``dm.has_prepared_data`` gets set to True
-    - When ``dm.setup()``, ``dm.has_setup_{fit,validate,test}`` get set to True
-    - When ``dm.setup(stage)`` is called, where stage is any of ``{fit,validate,test,predict}``.
-      Its corresponding `dm_has_setup_{stage}` attribute gets set to True
-    - ``dm.teardown()`` and ``dm.teardown(stage)`` act exactly like ``dm.setup``
-
-    Args:
-        fn (function): Function that will be tracked to see if it has been called.
-
-    Returns:
-        function: Decorated function that tracks its call status and saves it to private attrs in its obj instance.
-    """
-
-    @functools.wraps(fn)
-    def wrapped_fn(*args, **kwargs):
-
-        # The object instance from which setup or prepare_data was called
-        obj = args[0]
-        name = fn.__name__
-
-        # If calling setup, we check the stage and assign stage-specific bool args
-        if name in ("setup", "teardown"):
-
-            # Get stage either by grabbing from args or checking kwargs.
-            # If not provided, set call status of 'fit', 'validate', and 'test' to True.
-            # We do this so __attach_datamodule in trainer.py doesn't mistakenly call setup('test') on trainer.test()
-            stage = args[1] if len(args) > 1 else kwargs.get("stage", None)
-
-            if stage is None:
-                for s in ("fit", "validate", "test"):
-                    setattr(obj, f"_has_{name}_{s}", True)
-            else:
-                setattr(obj, f"_has_{name}_{stage}", True)
-
-        elif name == "prepare_data":
-            obj._has_prepared_data = True
-
-        return fn(*args, **kwargs)
-
-    return wrapped_fn
-
-
-class LightningDataModule(CheckpointHooks, DataHooks, metaclass=_DataModuleWrapper):
+class LightningDataModule(CheckpointHooks, DataHooks):
     """
     A DataModule standardizes the training, val, test splits, data preparation and transforms. The main
     advantage is consistent data splits, data preparation and transforms across models.
@@ -398,3 +325,62 @@ class LightningDataModule(CheckpointHooks, DataHooks, metaclass=_DataModuleWrapp
         if test_dataset is not None:
             datamodule.test_dataloader = test_dataloader
         return datamodule
+
+    def __new__(cls, *args: Any, **kwargs: Any) -> 'LightningDataModule':
+        obj = super().__new__(cls)
+        # track `DataHooks` calls and run `prepare_data` only on rank zero
+        obj.prepare_data = cls._track_data_hook_calls(obj, rank_zero_only(obj.prepare_data))
+        obj.setup = cls._track_data_hook_calls(obj, obj.setup)
+        obj.teardown = cls._track_data_hook_calls(obj, obj.teardown)
+        return obj
+
+    @staticmethod
+    def _track_data_hook_calls(obj: 'LightningDataModule', fn: Callable) -> Callable:
+        """A decorator that checks if prepare_data/setup/teardown has been called.
+
+        - When ``dm.prepare_data()`` is called, ``dm.has_prepared_data`` gets set to True
+        - When ``dm.setup()`` is called, ``dm.has_setup_{fit,validate,test}`` get set to True
+        - When ``dm.setup(stage)`` is called, where ``stage`` is any of ``{fit,validate,test,predict}``,
+          its corresponding ``dm.has_setup_{stage}`` attribute gets set to True
+        - ``dm.teardown()`` and ``dm.teardown(stage)`` act exactly like ``dm.setup``
+
+        Args:
+            obj: Object whose function will be tracked.
+            fn: Function that will be tracked to see if it has been called.
+
+        Returns:
+            Decorated function that tracks its call status and saves it to private attrs in its obj instance.
+        """
+
+        @functools.wraps(fn)
+        def wrapped_fn(*args: str, **kwargs: Optional[str]) -> Any:
+            name = fn.__name__
+
+            # If calling setup, we check the stage and assign stage-specific bool args
+            if name in ("setup", "teardown"):
+
+                # Get stage either by grabbing from args or checking kwargs.
+                # If not provided, set call status of 'fit', 'validate', and 'test' to True.
+                # We do this so __attach_datamodule in trainer.py doesn't mistakenly call
+                # setup('test') on trainer.test()
+                stage = args[0] if len(args) else kwargs.get("stage", None)
+
+                if stage is None:
+                    for s in ("fit", "validate", "test"):
+                        setattr(obj, f"_has_{name}_{s}", True)
+                else:
+                    setattr(obj, f"_has_{name}_{stage}", True)
+
+            elif name == "prepare_data":
+                obj._has_prepared_data = True
+
+            return fn(*args, **kwargs)
+
+        return wrapped_fn
+
+    def __getstate__(self) -> dict:
+        # avoids _pickle.PicklingError: Can't pickle <...>: it's not the same object as <...>
+        d = self.__dict__.copy()
+        for fn in ("prepare_data", "setup", "teardown"):
+            del d[fn]
+        return d
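Note for reviewers: the metaclass is gone, but the observable behavior is meant to be identical. The tracking wrappers now live on each instance (installed in `__new__`) rather than being patched onto the class, so repeated subclassing and instantiation can no longer stack decorators (the recursion from issue #3652), and `__getstate__` strips the bound wrappers so instances still pickle. A minimal sketch of that behavior, mirroring `test_data_hooks_called` and `test_dm_pickle_after_init` below (`MyDataModule` is an illustrative stand-in, not part of this patch):

```python
import pickle

from pytorch_lightning import LightningDataModule


class MyDataModule(LightningDataModule):
    def setup(self, stage=None):
        pass


dm = MyDataModule()
assert not dm.has_setup_fit

# A stage-specific call flips only the matching flag...
dm.setup("fit")
assert dm.has_setup_fit and not dm.has_setup_test

# ...while a bare setup() marks fit/validate/test as done in one go.
dm2 = MyDataModule()
dm2.setup()
assert dm2.has_setup_fit and dm2.has_setup_validate and dm2.has_setup_test

# __getstate__ removes the per-instance wrappers, so pickling works.
pickle.dumps(dm)
```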
diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py
index 40c38b9d3a..c4eb076e04 100644
--- a/tests/core/test_datamodules.py
+++ b/tests/core/test_datamodules.py
@@ -91,7 +91,7 @@ def test_can_prepare_data(local_rank, node_rank):
     assert trainer.data_connector.can_prepare_data()
 
 
-def test_hooks_no_recursion_error(tmpdir):
+def test_hooks_no_recursion_error():
     # hooks were appended in cascade every time a new data module was instantiated leading to a recursion error.
     # See https://github.com/PyTorchLightning/pytorch-lightning/issues/3652
     class DummyDM(LightningDataModule):
@@ -108,20 +108,20 @@ def test_hooks_no_recursion_error(tmpdir):
         dm.prepare_data()
 
 
-def test_helper_boringdatamodule(tmpdir):
+def test_helper_boringdatamodule():
     dm = BoringDataModule()
     dm.prepare_data()
     dm.setup()
 
 
-def test_helper_boringdatamodule_with_verbose_setup(tmpdir):
+def test_helper_boringdatamodule_with_verbose_setup():
     dm = BoringDataModule()
     dm.prepare_data()
     dm.setup('fit')
     dm.setup('test')
 
 
-def test_data_hooks_called(tmpdir):
+def test_data_hooks_called():
     dm = BoringDataModule()
     assert not dm.has_prepared_data
     assert not dm.has_setup_fit
@@ -168,7 +168,7 @@ def test_data_hooks_called(tmpdir):
 
 
 @pytest.mark.parametrize("use_kwarg", (False, True))
-def test_data_hooks_called_verbose(tmpdir, use_kwarg):
+def test_data_hooks_called_verbose(use_kwarg):
     dm = BoringDataModule()
     dm.prepare_data()
     assert not dm.has_setup_fit
@@ -246,7 +246,7 @@ def test_dm_init_from_argparse_args(tmpdir):
     assert dm.data_dir == args.data_dir == str(tmpdir)
 
 
-def test_dm_pickle_after_init(tmpdir):
+def test_dm_pickle_after_init():
     dm = BoringDataModule()
     pickle.dumps(dm)
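One subtlety behind the datamodule diff above: the stage lookup changed from `args[1]` to `args[0]` because the old metaclass wrapped the *unbound* class attribute (so the wrapper received `self` first), whereas `__new__` wraps the already-bound instance method. A toy illustration of the difference, with made-up names (`wrap`, `Old`, `New` are not from this patch):

```python
import functools


def wrap(fn):
    # Print the positional args a tracking wrapper would see, then delegate.
    @functools.wraps(fn)
    def wrapped(*args, **kwargs):
        print("wrapper sees:", args)
        return fn(*args, **kwargs)
    return wrapped


class Old:
    def setup(self, stage=None):
        return stage


class New:
    def setup(self, stage=None):
        return stage


# Old metaclass style: wrap the unbound class attribute once per class.
# The instance arrives as args[0], so the stage had to be read from args[1].
Old.setup = wrap(Old.setup)
Old().setup("fit")  # wrapper sees: (<__main__.Old object at ...>, 'fit')

# New style: wrap the bound method on each instance inside __new__.
# There is no self in args, so the stage is args[0].
obj = New()
obj.setup = wrap(obj.setup)  # instance attribute shadows the class method
obj.setup("fit")  # wrapper sees: ('fit',)
```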