From b65ae7947803e7dc2a254af1045c1bcc6e4cf147 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Tue, 11 May 2021 10:53:00 +0200
Subject: [PATCH] Automatically check `DataModule.has_{setup,teardown,prepare_data}` [2/2] (#7238)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Automatically check `DataModule.has_{setup,teardown,prepare_data}`
* Use variable
* Spacing
* Docs
* Update CHANGELOG
* Remove `_DataModuleWrapper`
* Add test
* Update docs/source/extensions/datamodules.rst
* Bad merge
* add test for invalid name
* Remove ValueError

Co-authored-by: Adrian Wälchli
---
 CHANGELOG.md                               |  3 ++
 docs/source/extensions/datamodules.rst     | 28 ++++++++------
 docs/source/starter/introduction_guide.rst |  2 -
 docs/source/starter/new-project.rst        |  4 +-
 pytorch_lightning/core/datamodule.py       | 14 +++++--
 pytorch_lightning/core/hooks.py            |  2 +-
 pytorch_lightning/trainer/trainer.py       | 10 +----
 tests/core/test_datamodules.py             | 43 ++++++++++++++++++++++
 8 files changed, 78 insertions(+), 28 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index c4cf25d533..6c9228de1c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -19,6 +19,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Changed `clip_grad_norm` to use `torch.nn.utils.clip_grad_norm_` ([#7025](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025))
 
+- `DataModule`s now avoid duplicate `{setup,teardown,prepare_data}` calls for the same stage ([#7238](https://github.com/PyTorchLightning/pytorch-lightning/pull/7238))
+
+
 - Moved `ignore_scalar_return_in_dp` warning suppression to the DataParallelPlugin class ([#7421](https://github.com/PyTorchLightning/pytorch-lightning/pull/7421/))
 
 
diff --git a/docs/source/extensions/datamodules.rst b/docs/source/extensions/datamodules.rst
index a602a75b0f..fbb19e10a8 100644
--- a/docs/source/extensions/datamodules.rst
+++ b/docs/source/extensions/datamodules.rst
@@ -168,10 +168,6 @@ Here's a more realistic, complex DataModule that shows how much more reusable th
         def test_dataloader(self):
             return DataLoader(self.mnist_test, batch_size=32)
 
-
-.. note:: ``setup`` expects a string arg ``stage``. It is used to separate setup logic for ``trainer.fit`` and ``trainer.test``.
-
-
 ---------------
 
 LightningDataModule API
@@ -228,7 +224,7 @@ There are also data operations you might want to perform on every GPU. Use setup
         def setup(self, stage: Optional[str] = None):
 
             # Assign Train/val split(s) for use in Dataloaders
-            if stage == 'fit' or stage is None:
+            if stage in (None, 'fit'):
                 mnist_full = MNIST(
                     self.data_dir,
                     train=True,
@@ -239,7 +235,7 @@ There are also data operations you might want to perform on every GPU. Use setup
                 self.dims = self.mnist_train[0][0].shape
 
             # Assign Test split(s) for use in Dataloaders
-            if stage == 'test' or stage is None:
+            if stage in (None, 'test'):
                 self.mnist_test = MNIST(
                     self.data_dir,
                     train=False,
@@ -249,10 +245,17 @@ There are also data operations you might want to perform on every GPU. Use setup
 
                 self.dims = getattr(self, 'dims', self.mnist_test[0][0].shape)
 
-.. warning:: ``setup`` is called from every process. Setting state here is okay.
-
+:meth:`~pytorch_lightning.core.datamodule.LightningDataModule.setup` expects a ``stage: Optional[str]`` argument.
+It is used to separate setup logic for ``trainer.{fit,validate,test}``. If ``setup`` is called with ``stage = None``,
+we assume all stages have been set up.
+
+.. note:: ``setup`` is called from every process. Setting state here is okay.
+.. note:: ``teardown`` can be used to clean up the state. It is also called from every process.
+.. note::
+    ``{setup,teardown,prepare_data}`` will only be called once for a given stage. If the stage was
+    ``None``, we assume all of ``{fit,validate,test}`` have been called. For example, this means that
+    any duplicate ``dm.setup('fit')`` calls will be a no-op. To force a re-run, you can set
+    ``dm._has_setup_fit = False``.
 
 
 train_dataloader
@@ -396,11 +399,12 @@ The recommended way to use a DataModule is simply:
 
     dm = MNISTDataModule()
     model = Model()
     trainer.fit(model, dm)
-
     trainer.test(datamodule=dm)
 
-If you need information from the dataset to build your model, then run `prepare_data` and `setup` manually (Lightning
-still ensures the method runs on the correct devices)
+If you need information from the dataset to build your model, then run
+:meth:`~pytorch_lightning.core.datamodule.LightningDataModule.prepare_data` and
+:meth:`~pytorch_lightning.core.datamodule.LightningDataModule.setup` manually (Lightning ensures
+the method runs on the correct devices).
 
 .. code-block:: python
 
@@ -416,7 +420,7 @@ still ensures the method runs on the correct devices)
 
 ----------------
 
-Datamodules without Lightning
+DataModules without Lightning
 -----------------------------
 
 You can of course use DataModules in plain PyTorch code as well.
diff --git a/docs/source/starter/introduction_guide.rst b/docs/source/starter/introduction_guide.rst
index 3a19d0f346..3286495660 100644
--- a/docs/source/starter/introduction_guide.rst
+++ b/docs/source/starter/introduction_guide.rst
@@ -295,8 +295,6 @@ When your models need to know about the data, it's best to process the data befo
 1. use ``prepare_data()`` to download and process the dataset.
 2. use ``setup()`` to do splits, and build your model internals
 
-|
-
 An alternative to using a DataModule is to defer initialization of the models modules to the ``setup`` method of your LightningModule as follows:
 
 .. testcode::
diff --git a/docs/source/starter/new-project.rst b/docs/source/starter/new-project.rst
index fb3e0152c4..74ad30102b 100644
--- a/docs/source/starter/new-project.rst
+++ b/docs/source/starter/new-project.rst
@@ -658,10 +658,10 @@ Make your data code reusable by organizing it into a :class:`~pytorch_lightning.
                transforms.Normalize((0.1307,), (0.3081,))
            ])
            # split dataset
-           if stage == 'fit':
+           if stage in (None, 'fit'):
                mnist_train = MNIST(os.getcwd(), train=True, transform=transform)
                self.mnist_train, self.mnist_val = random_split(mnist_train, [55000, 5000])
-           if stage == 'test':
+           if stage in (None, 'test'):
                self.mnist_test = MNIST(os.getcwd(), train=False, transform=transform)
 
        # return the dataloader for each split
diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py
index 4eaed6c7b2..23626ed9cb 100644
--- a/pytorch_lightning/core/datamodule.py
+++ b/pytorch_lightning/core/datamodule.py
@@ -355,6 +355,7 @@ class LightningDataModule(CheckpointHooks, DataHooks):
         @functools.wraps(fn)
         def wrapped_fn(*args: str, **kwargs: Optional[str]) -> Any:
             name = fn.__name__
+            has_run = False
 
             # If calling setup, we check the stage and assign stage-specific bool args
             if name in ("setup", "teardown"):
@@ -366,15 +367,22 @@ class LightningDataModule(CheckpointHooks, DataHooks):
                 stage = args[0] if len(args) else kwargs.get("stage", None)
 
                 if stage is None:
+                    has_run = True
                     for s in ("fit", "validate", "test"):
-                        setattr(obj, f"_has_{name}_{s}", True)
+                        attr = f"_has_{name}_{s}"
+                        has_run &= getattr(obj, attr)
+                        setattr(obj, attr, True)
                 else:
-                    setattr(obj, f"_has_{name}_{stage}", True)
+                    attr = f"_has_{name}_{stage}"
+                    has_run = getattr(obj, attr)
+                    setattr(obj, attr, True)
 
             elif name == "prepare_data":
+                has_run = obj._has_prepared_data
                 obj._has_prepared_data = True
 
-            return fn(*args, **kwargs)
+            if not has_run:
+                return fn(*args, **kwargs)
 
         return wrapped_fn
diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py
index 7ab0c8acbe..7cc74f3d04 100644
--- a/pytorch_lightning/core/hooks.py
+++ b/pytorch_lightning/core/hooks.py
@@ -394,7 +394,7 @@ class DataHooks:
 
     def setup(self, stage: Optional[str] = None) -> None:
         """
-        Called at the beginning of fit (train + validate), validate, test, predict, or tune.
+        Called at the beginning of fit (train + validate), validate, test, and predict.
         This is a good hook when you need to build models dynamically or adjust something about them.
         This hook is called on every process when using DDP.
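
The ``wrapped_fn`` change above is the heart of the patch: each hook call flips the per-stage ``_has_{name}_{stage}`` flag, and a call whose targeted flags are already set is skipped. As a review aid (not part of the patch itself), here is a simplified, standalone sketch of that tracking pattern; ``ToyDataModule`` and ``call_once_per_stage`` are illustrative names rather than Lightning API, and the real wrapper additionally covers ``teardown`` and stage-name validation.

.. code-block:: python

    import functools
    from typing import Any, Callable, Optional


    class ToyDataModule:
        """Minimal stand-in for a DataModule with per-stage tracking flags."""

        def __init__(self) -> None:
            for stage in ("fit", "validate", "test"):
                setattr(self, f"_has_setup_{stage}", False)
            self._has_prepared_data = False

        def setup(self, stage: Optional[str] = None) -> None:
            print(f"setup ran with stage={stage!r}")

        def prepare_data(self) -> None:
            print("prepare_data ran")


    def call_once_per_stage(obj: ToyDataModule, fn: Callable) -> Callable:
        """Wrap ``fn`` so repeating a call for an already-seen stage is a no-op."""

        @functools.wraps(fn)
        def wrapped_fn(*args: Any, **kwargs: Any) -> Any:
            name = fn.__name__
            has_run = False
            if name == "setup":
                stage = args[0] if args else kwargs.get("stage")
                stages = ("fit", "validate", "test") if stage is None else (stage,)
                # skip only if every targeted stage has already been set up
                has_run = all(getattr(obj, f"_has_{name}_{s}") for s in stages)
                for s in stages:
                    setattr(obj, f"_has_{name}_{s}", True)
            elif name == "prepare_data":
                has_run = obj._has_prepared_data
                obj._has_prepared_data = True
            if not has_run:
                return fn(*args, **kwargs)

        return wrapped_fn


    dm = ToyDataModule()
    dm.setup = call_once_per_stage(dm, dm.setup)
    dm.prepare_data = call_once_per_stage(dm, dm.prepare_data)

    dm.prepare_data()  # runs
    dm.prepare_data()  # no-op
    dm.setup("fit")    # runs
    dm.setup("fit")    # no-op
    dm.setup()         # runs: "validate" and "test" were never set up
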
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 8e37d73ef7..3e700c9a3c 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1156,10 +1156,7 @@ class Trainer( self.accelerator.barrier("pre_setup") if self.datamodule is not None: - called = getattr(self.datamodule, f'has_setup_{fn}') - if not called: - self.datamodule.setup(stage=fn) - + self.datamodule.setup(stage=fn) self.setup(model, stage=fn) model.setup(stage=fn) @@ -1182,10 +1179,7 @@ class Trainer( fn = self.state.fn._setup_fn if self.datamodule is not None: - called = getattr(self.datamodule, f'has_teardown_{fn}') - if not called: - self.datamodule.teardown(stage=fn) - + self.datamodule.teardown(stage=fn) self.profiler.teardown(stage=fn) self.teardown(stage=fn) model.teardown(stage=fn) diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 0041ccb52c..7cfa569115 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -521,3 +521,46 @@ def test_dm_init_from_datasets_dataloaders(iterable): call(test_dss[0], batch_size=4, shuffle=False, num_workers=0, pin_memory=True), call(test_dss[1], batch_size=4, shuffle=False, num_workers=0, pin_memory=True) ]) + + +def test_datamodule_hooks_calls(tmpdir): + """Test that repeated calls to DataHooks' hooks have no effect""" + + class TestDataModule(BoringDataModule): + setup_calls = [] + teardown_calls = [] + prepare_data_calls = 0 + + def setup(self, stage=None): + super().setup(stage=stage) + self.setup_calls.append(stage) + + def teardown(self, stage=None): + super().teardown(stage=stage) + self.teardown_calls.append(stage) + + def prepare_data(self): + super().prepare_data() + self.prepare_data_calls += 1 + + dm = TestDataModule() + dm.prepare_data() + dm.prepare_data() + dm.setup('fit') + dm.setup('fit') + dm.setup() + dm.setup() + dm.teardown('validate') + dm.teardown('validate') + + assert dm.prepare_data_calls == 1 + assert dm.setup_calls == ['fit', None] + assert dm.teardown_calls == ['validate'] + + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1) + trainer.test(BoringModel(), datamodule=dm) + + # same number of calls + assert dm.prepare_data_calls == 1 + assert dm.setup_calls == ['fit', None] + assert dm.teardown_calls == ['validate', 'test']
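
For reviewers who want to see the new semantics end to end, the sketch below mirrors what ``test_datamodule_hooks_calls`` asserts, but from the user's side. It assumes this patch is applied; ``ExampleDataModule`` is a made-up subclass used only for illustration.

.. code-block:: python

    from typing import Optional

    import pytorch_lightning as pl


    class ExampleDataModule(pl.LightningDataModule):
        # illustrative subclass, not part of the patch
        def prepare_data(self) -> None:
            print("prepare_data ran")

        def setup(self, stage: Optional[str] = None) -> None:
            print(f"setup ran with stage={stage!r}")


    dm = ExampleDataModule()

    dm.prepare_data()  # runs
    dm.prepare_data()  # no-op: the duplicate call is skipped

    dm.setup("fit")    # runs and marks the "fit" stage as set up
    dm.setup("fit")    # no-op for the same stage
    dm.setup()         # stage=None still runs and marks "validate"/"test" as set up too

    # As the datamodules.rst note above mentions, resetting the private flag forces a re-run:
    dm._has_setup_fit = False
    dm.setup("fit")    # runs again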