Automatically check `DataModule.has_{setup,teardown,prepare_data}` [2/2] (#7238)

* Automatically check `DataModule.has_{setup,teardown,prepare_data}`

* Use variable

* Spacing

* Docs

* Update CHANGELOG

* Remove `_DataModuleWrapper`

* Add test

* Update docs/source/extensions/datamodules.rst

* Bad merge

* add test for invalid name

* Remove ValueError

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Carlos Mocholí 2021-05-11 10:53:00 +02:00 committed by GitHub
parent 8660d8cf03
commit b65ae79478
8 changed files with 78 additions and 28 deletions

View File

@@ -19,6 +19,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed `clip_grad_norm` to use `torch.nn.utils.clip_grad_norm_` ([#7025](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025))
- `DataModule`s now avoid duplicate `{setup,teardown,prepare_data}` calls for the same stage ([#7238](https://github.com/PyTorchLightning/pytorch-lightning/pull/7238))
- Moved `ignore_scalar_return_in_dp` warning suppression to the DataParallelPlugin class ([#7421](https://github.com/PyTorchLightning/pytorch-lightning/pull/7421/))

View File

@@ -168,10 +168,6 @@ Here's a more realistic, complex DataModule that shows how much more reusable th
def test_dataloader(self):
return DataLoader(self.mnist_test, batch_size=32)
.. note:: ``setup`` expects a string arg ``stage``. It is used to separate setup logic for ``trainer.fit`` and ``trainer.test``.
---------------
LightningDataModule API
@@ -228,7 +224,7 @@ There are also data operations you might want to perform on every GPU. Use setup
def setup(self, stage: Optional[str] = None):
# Assign Train/val split(s) for use in Dataloaders
if stage == 'fit' or stage is None:
if stage in (None, 'fit'):
mnist_full = MNIST(
self.data_dir,
train=True,
@@ -239,7 +235,7 @@ There are also data operations you might want to perform on every GPU. Use setup
self.dims = self.mnist_train[0][0].shape
# Assign Test split(s) for use in Dataloaders
if stage == 'test' or stage is None:
if stage in (None, 'test'):
self.mnist_test = MNIST(
self.data_dir,
train=False,
@@ -249,10 +245,17 @@ There are also data operations you might want to perform on every GPU. Use setup
self.dims = getattr(self, 'dims', self.mnist_test[0][0].shape)
.. warning:: ``setup`` is called from every process. Setting state here is okay.
:meth:`~pytorch_lightning.core.datamodule.LightningDataModule.setup` expects a ``stage: Optional[str]`` argument.
It is used to separate setup logic for ``trainer.{fit,validate,test}``. If ``setup`` is called with ``stage = None``,
we assume all stages have been set up.
.. note:: ``setup`` is called from every process. Setting state here is okay.
.. note:: ``teardown`` can be used to clean up the state. It is also called from every process.
.. note::
    ``{setup,teardown,prepare_data}`` will only be called once for a given stage.
    If the stage was ``None``, we assume all of ``{fit,validate,test}`` have been called. For example, this means
    that any duplicate ``dm.setup('fit')`` calls will be a no-op. To avoid this, you can set
    ``dm._has_setup_fit = False``.
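As a quick illustration of the behaviour described in this note (a sketch only, using the ``MNISTDataModule`` from the examples above):

    dm = MNISTDataModule()
    dm.setup('fit')              # runs setup for the 'fit' stage
    dm.setup('fit')              # no-op: the 'fit' stage was already set up
    dm._has_setup_fit = False    # reset the internal flag to force a re-run
    dm.setup('fit')              # runs again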
train_dataloader
@@ -396,11 +399,12 @@ The recommended way to use a DataModule is simply:
dm = MNISTDataModule()
model = Model()
trainer.fit(model, dm)
trainer.test(datamodule=dm)
If you need information from the dataset to build your model, then run `prepare_data` and `setup` manually (Lightning
still ensures the method runs on the correct devices)
If you need information from the dataset to build your model, then run
:meth:`~pytorch_lightning.core.datamodule.LightningDataModule.prepare_data` and
:meth:`~pytorch_lightning.core.datamodule.LightningDataModule.setup` manually (Lightning ensures
these methods run on the correct devices).
.. code-block:: python
@@ -416,7 +420,7 @@ still ensures the method runs on the correct devices)
----------------
Datamodules without Lightning
DataModules without Lightning
-----------------------------
You can of course use DataModules in plain PyTorch code as well.

View File

@@ -295,8 +295,6 @@ When your models need to know about the data, it's best to process the data befo
1. use ``prepare_data()`` to download and process the dataset.
2. use ``setup()`` to do splits, and build your model internals
|
An alternative to using a DataModule is to defer initialization of the model's modules to the ``setup`` method of your LightningModule as follows:
.. testcode::

View File

@@ -658,10 +658,10 @@ Make your data code reusable by organizing it into a :class:`~pytorch_lightning.
transforms.Normalize((0.1307,), (0.3081,))
])
# split dataset
if stage == 'fit':
if stage in (None, 'fit'):
mnist_train = MNIST(os.getcwd(), train=True, transform=transform)
self.mnist_train, self.mnist_val = random_split(mnist_train, [55000, 5000])
if stage == 'test':
if stage in (None, 'test'):
self.mnist_test = MNIST(os.getcwd(), train=False, transform=transform)
# return the dataloader for each split

View File

@@ -355,6 +355,7 @@ class LightningDataModule(CheckpointHooks, DataHooks):
@functools.wraps(fn)
def wrapped_fn(*args: str, **kwargs: Optional[str]) -> Any:
name = fn.__name__
has_run = False
# If calling setup or teardown, we check the stage and assign stage-specific bool args
if name in ("setup", "teardown"):
@@ -366,15 +367,22 @@ class LightningDataModule(CheckpointHooks, DataHooks):
stage = args[0] if len(args) else kwargs.get("stage", None)
if stage is None:
has_run = True
for s in ("fit", "validate", "test"):
setattr(obj, f"_has_{name}_{s}", True)
attr = f"_has_{name}_{s}"
has_run &= getattr(obj, attr)
setattr(obj, attr, True)
else:
setattr(obj, f"_has_{name}_{stage}", True)
attr = f"_has_{name}_{stage}"
has_run = getattr(obj, attr)
setattr(obj, attr, True)
elif name == "prepare_data":
has_run = obj._has_prepared_data
obj._has_prepared_data = True
return fn(*args, **kwargs)
if not has_run:
return fn(*args, **kwargs)
return wrapped_fn
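A standalone, simplified sketch of the tracking pattern implemented by this wrapper; the names and the toy class are illustrative, not the library's internals:

    import functools
    from typing import Any, Callable, Optional

    def track_calls(fn: Callable) -> Callable:
        """Skip the wrapped hook if it already ran for the requested stage(s)."""
        @functools.wraps(fn)
        def wrapped(obj: Any, stage: Optional[str] = None) -> Any:
            name = fn.__name__
            if stage is None:
                # stage=None covers all stages: skip only if every stage already ran
                has_run = True
                for s in ("fit", "validate", "test"):
                    attr = f"_has_{name}_{s}"
                    has_run &= getattr(obj, attr, False)
                    setattr(obj, attr, True)
            else:
                attr = f"_has_{name}_{stage}"
                has_run = getattr(obj, attr, False)
                setattr(obj, attr, True)
            if not has_run:
                return fn(obj, stage)
        return wrapped

    class ToyDataModule:
        @track_calls
        def setup(self, stage: Optional[str] = None) -> None:
            print(f"setup({stage!r}) ran")

    dm = ToyDataModule()
    dm.setup("fit")   # prints: setup('fit') ran
    dm.setup("fit")   # skipped, the 'fit' flag is already set
    dm.setup()        # runs once more: 'validate' and 'test' had not run yet
    dm.setup()        # skipped, every stage flag is now True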

View File

@@ -394,7 +394,7 @@ class DataHooks:
def setup(self, stage: Optional[str] = None) -> None:
"""
Called at the beginning of fit (train + validate), validate, test, predict, or tune.
Called at the beginning of fit (train + validate), validate, test, and predict.
This is a good hook when you need to build models dynamically or adjust something about them.
This hook is called on every process when using DDP.

View File

@@ -1156,10 +1156,7 @@ class Trainer(
self.accelerator.barrier("pre_setup")
if self.datamodule is not None:
called = getattr(self.datamodule, f'has_setup_{fn}')
if not called:
self.datamodule.setup(stage=fn)
self.datamodule.setup(stage=fn)
self.setup(model, stage=fn)
model.setup(stage=fn)
@@ -1182,10 +1179,7 @@ class Trainer(
fn = self.state.fn._setup_fn
if self.datamodule is not None:
called = getattr(self.datamodule, f'has_teardown_{fn}')
if not called:
self.datamodule.teardown(stage=fn)
self.datamodule.teardown(stage=fn)
self.profiler.teardown(stage=fn)
self.teardown(stage=fn)
model.teardown(stage=fn)

View File

@@ -521,3 +521,46 @@ def test_dm_init_from_datasets_dataloaders(iterable):
call(test_dss[0], batch_size=4, shuffle=False, num_workers=0, pin_memory=True),
call(test_dss[1], batch_size=4, shuffle=False, num_workers=0, pin_memory=True)
])
def test_datamodule_hooks_calls(tmpdir):
    """Test that repeated calls to DataHooks' hooks have no effect"""

    class TestDataModule(BoringDataModule):
        setup_calls = []
        teardown_calls = []
        prepare_data_calls = 0

        def setup(self, stage=None):
            super().setup(stage=stage)
            self.setup_calls.append(stage)

        def teardown(self, stage=None):
            super().teardown(stage=stage)
            self.teardown_calls.append(stage)

        def prepare_data(self):
            super().prepare_data()
            self.prepare_data_calls += 1

    dm = TestDataModule()
    dm.prepare_data()
    dm.prepare_data()
    dm.setup('fit')
    dm.setup('fit')
    dm.setup()
    dm.setup()
    dm.teardown('validate')
    dm.teardown('validate')

    assert dm.prepare_data_calls == 1
    assert dm.setup_calls == ['fit', None]
    assert dm.teardown_calls == ['validate']

    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
    trainer.test(BoringModel(), datamodule=dm)

    # same number of calls
    assert dm.prepare_data_calls == 1
    assert dm.setup_calls == ['fit', None]
    assert dm.teardown_calls == ['validate', 'test']