Automatically check `DataModule.has_{setup,teardown,prepare_data}` [2/2] (#7238)
* Automatically check `DataModule.has_{setup,teardown,prepare_data}` * Use variable * Spacing * Docs * Update CHANGELOG * Remove `_DataModuleWrapper` * Add test * Update docs/source/extensions/datamodules.rst * Bad merge * add test for invalid name * Remove ValueError Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
parent
8660d8cf03
commit
b65ae79478
|
@ -19,6 +19,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Changed `clip_grad_norm` to use `torch.nn.utils.clip_grad_norm_` ([#7025](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025))
|
||||
|
||||
|
||||
- `DataModule`s now avoid duplicate `{setup,teardown,prepare_data}` calls for the same stage ([#7238](https://github.com/PyTorchLightning/pytorch-lightning/pull/7238))
|
||||
|
||||
|
||||
- Moved `ignore_scalar_return_in_dp` warning suppression to the DataParallelPlugin class ([#7421](https://github.com/PyTorchLightning/pytorch-lightning/pull/7421/))
|
||||
|
||||
|
||||
|
|
|
@ -168,10 +168,6 @@ Here's a more realistic, complex DataModule that shows how much more reusable th
|
|||
def test_dataloader(self):
|
||||
return DataLoader(self.mnist_test, batch_size=32)
|
||||
|
||||
|
||||
.. note:: ``setup`` expects a string arg ``stage``. It is used to separate setup logic for ``trainer.fit`` and ``trainer.test``.
|
||||
|
||||
|
||||
---------------
|
||||
|
||||
LightningDataModule API
|
||||
|
@ -228,7 +224,7 @@ There are also data operations you might want to perform on every GPU. Use setup
|
|||
def setup(self, stage: Optional[str] = None):
|
||||
|
||||
# Assign Train/val split(s) for use in Dataloaders
|
||||
if stage == 'fit' or stage is None:
|
||||
if stage in (None, 'fit'):
|
||||
mnist_full = MNIST(
|
||||
self.data_dir,
|
||||
train=True,
|
||||
|
@ -239,7 +235,7 @@ There are also data operations you might want to perform on every GPU. Use setup
|
|||
self.dims = self.mnist_train[0][0].shape
|
||||
|
||||
# Assign Test split(s) for use in Dataloaders
|
||||
if stage == 'test' or stage is None:
|
||||
if stage in (None, 'test'):
|
||||
self.mnist_test = MNIST(
|
||||
self.data_dir,
|
||||
train=False,
|
||||
|
@ -249,10 +245,17 @@ There are also data operations you might want to perform on every GPU. Use setup
|
|||
self.dims = getattr(self, 'dims', self.mnist_test[0][0].shape)
|
||||
|
||||
|
||||
.. warning:: ``setup`` is called from every process. Setting state here is okay.
|
||||
|
||||
:meth:`~pytorch_lightning.core.datamodule.LightningDataModule.setup` expects an ``stage: Optional[str]`` argument.
|
||||
It is used to separate setup logic for ``trainer.{fit,validate,test}``. If ``setup`` is called with ``stage = None``,
|
||||
we assume all stages have been set-up.
|
||||
|
||||
.. note:: ``setup`` is called from every process. Setting state here is okay.
|
||||
.. note:: ``teardown`` can be used to clean up the state. It is also called from every process
|
||||
.. note::
|
||||
``{setup,teardown,prepare_data}`` call will be only called once for a specific stage.
|
||||
If the stage was ``None`` then we assume ``{fit,validate,test}`` have been called. For example, this means that
|
||||
any duplicate ``dm.setup('fit')`` calls will be a no-op. To avoid this, you can overwrite
|
||||
``dm._has_setup_fit = False``
|
||||
|
||||
|
||||
train_dataloader
|
||||
|
@ -396,11 +399,12 @@ The recommended way to use a DataModule is simply:
|
|||
dm = MNISTDataModule()
|
||||
model = Model()
|
||||
trainer.fit(model, dm)
|
||||
|
||||
trainer.test(datamodule=dm)
|
||||
|
||||
If you need information from the dataset to build your model, then run `prepare_data` and `setup` manually (Lightning
|
||||
still ensures the method runs on the correct devices)
|
||||
If you need information from the dataset to build your model, then run
|
||||
:meth:`~pytorch_lightning.core.datamodule.LightningDataModule.prepare_data` and
|
||||
:meth:`~pytorch_lightning.core.datamodule.LightningDataModule.setup` manually (Lightning ensures
|
||||
the method runs on the correct devices).
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
@ -416,7 +420,7 @@ still ensures the method runs on the correct devices)
|
|||
|
||||
----------------
|
||||
|
||||
Datamodules without Lightning
|
||||
DataModules without Lightning
|
||||
-----------------------------
|
||||
You can of course use DataModules in plain PyTorch code as well.
|
||||
|
||||
|
|
|
@ -295,8 +295,6 @@ When your models need to know about the data, it's best to process the data befo
|
|||
1. use ``prepare_data()`` to download and process the dataset.
|
||||
2. use ``setup()`` to do splits, and build your model internals
|
||||
|
||||
|
|
||||
|
||||
An alternative to using a DataModule is to defer initialization of the models modules to the ``setup`` method of your LightningModule as follows:
|
||||
|
||||
.. testcode::
|
||||
|
|
|
@ -658,10 +658,10 @@ Make your data code reusable by organizing it into a :class:`~pytorch_lightning.
|
|||
transforms.Normalize((0.1307,), (0.3081,))
|
||||
])
|
||||
# split dataset
|
||||
if stage == 'fit':
|
||||
if stage in (None, 'fit'):
|
||||
mnist_train = MNIST(os.getcwd(), train=True, transform=transform)
|
||||
self.mnist_train, self.mnist_val = random_split(mnist_train, [55000, 5000])
|
||||
if stage == 'test':
|
||||
if stage == (None, 'test'):
|
||||
self.mnist_test = MNIST(os.getcwd(), train=False, transform=transform)
|
||||
|
||||
# return the dataloader for each split
|
||||
|
|
|
@ -355,6 +355,7 @@ class LightningDataModule(CheckpointHooks, DataHooks):
|
|||
@functools.wraps(fn)
|
||||
def wrapped_fn(*args: str, **kwargs: Optional[str]) -> Any:
|
||||
name = fn.__name__
|
||||
has_run = False
|
||||
|
||||
# If calling setup, we check the stage and assign stage-specific bool args
|
||||
if name in ("setup", "teardown"):
|
||||
|
@ -366,15 +367,22 @@ class LightningDataModule(CheckpointHooks, DataHooks):
|
|||
stage = args[0] if len(args) else kwargs.get("stage", None)
|
||||
|
||||
if stage is None:
|
||||
has_run = True
|
||||
for s in ("fit", "validate", "test"):
|
||||
setattr(obj, f"_has_{name}_{s}", True)
|
||||
attr = f"_has_{name}_{s}"
|
||||
has_run &= getattr(obj, attr)
|
||||
setattr(obj, attr, True)
|
||||
else:
|
||||
setattr(obj, f"_has_{name}_{stage}", True)
|
||||
attr = f"_has_{name}_{stage}"
|
||||
has_run = getattr(obj, attr)
|
||||
setattr(obj, attr, True)
|
||||
|
||||
elif name == "prepare_data":
|
||||
has_run = obj._has_prepared_data
|
||||
obj._has_prepared_data = True
|
||||
|
||||
return fn(*args, **kwargs)
|
||||
if not has_run:
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return wrapped_fn
|
||||
|
||||
|
|
|
@ -394,7 +394,7 @@ class DataHooks:
|
|||
|
||||
def setup(self, stage: Optional[str] = None) -> None:
|
||||
"""
|
||||
Called at the beginning of fit (train + validate), validate, test, predict, or tune.
|
||||
Called at the beginning of fit (train + validate), validate, test, and predict.
|
||||
This is a good hook when you need to build models dynamically or adjust something about them.
|
||||
This hook is called on every process when using DDP.
|
||||
|
||||
|
|
|
@ -1156,10 +1156,7 @@ class Trainer(
|
|||
self.accelerator.barrier("pre_setup")
|
||||
|
||||
if self.datamodule is not None:
|
||||
called = getattr(self.datamodule, f'has_setup_{fn}')
|
||||
if not called:
|
||||
self.datamodule.setup(stage=fn)
|
||||
|
||||
self.datamodule.setup(stage=fn)
|
||||
self.setup(model, stage=fn)
|
||||
model.setup(stage=fn)
|
||||
|
||||
|
@ -1182,10 +1179,7 @@ class Trainer(
|
|||
fn = self.state.fn._setup_fn
|
||||
|
||||
if self.datamodule is not None:
|
||||
called = getattr(self.datamodule, f'has_teardown_{fn}')
|
||||
if not called:
|
||||
self.datamodule.teardown(stage=fn)
|
||||
|
||||
self.datamodule.teardown(stage=fn)
|
||||
self.profiler.teardown(stage=fn)
|
||||
self.teardown(stage=fn)
|
||||
model.teardown(stage=fn)
|
||||
|
|
|
@ -521,3 +521,46 @@ def test_dm_init_from_datasets_dataloaders(iterable):
|
|||
call(test_dss[0], batch_size=4, shuffle=False, num_workers=0, pin_memory=True),
|
||||
call(test_dss[1], batch_size=4, shuffle=False, num_workers=0, pin_memory=True)
|
||||
])
|
||||
|
||||
|
||||
def test_datamodule_hooks_calls(tmpdir):
|
||||
"""Test that repeated calls to DataHooks' hooks have no effect"""
|
||||
|
||||
class TestDataModule(BoringDataModule):
|
||||
setup_calls = []
|
||||
teardown_calls = []
|
||||
prepare_data_calls = 0
|
||||
|
||||
def setup(self, stage=None):
|
||||
super().setup(stage=stage)
|
||||
self.setup_calls.append(stage)
|
||||
|
||||
def teardown(self, stage=None):
|
||||
super().teardown(stage=stage)
|
||||
self.teardown_calls.append(stage)
|
||||
|
||||
def prepare_data(self):
|
||||
super().prepare_data()
|
||||
self.prepare_data_calls += 1
|
||||
|
||||
dm = TestDataModule()
|
||||
dm.prepare_data()
|
||||
dm.prepare_data()
|
||||
dm.setup('fit')
|
||||
dm.setup('fit')
|
||||
dm.setup()
|
||||
dm.setup()
|
||||
dm.teardown('validate')
|
||||
dm.teardown('validate')
|
||||
|
||||
assert dm.prepare_data_calls == 1
|
||||
assert dm.setup_calls == ['fit', None]
|
||||
assert dm.teardown_calls == ['validate']
|
||||
|
||||
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
|
||||
trainer.test(BoringModel(), datamodule=dm)
|
||||
|
||||
# same number of calls
|
||||
assert dm.prepare_data_calls == 1
|
||||
assert dm.setup_calls == ['fit', None]
|
||||
assert dm.teardown_calls == ['validate', 'test']
|
||||
|
|
Loading…
Reference in New Issue