Deprecate `LightningDataModule` lifecycle properties (#7657)
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
This commit is contained in:
parent
764d2c775e
commit
6fee9262ff
|
@ -155,6 +155,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
### Deprecated
|
||||
|
||||
|
||||
- Deprecated `DataModule` properties: `has_prepared_data`, `has_setup_fit`, `has_setup_validate`, `has_setup_test`, `has_setup_predict`, `has_teardown_fit`, `has_teardown_validate`, `has_teardown_test`, `has_teardown_predict` ([#7657](https://github.com/PyTorchLightning/pytorch-lightning/pull/7657/))
|
||||
|
||||
|
||||
- Deprecated `TrainerModelHooksMixin` in favor of `pytorch_lightning.utilities.signature_utils` ([#7422](https://github.com/PyTorchLightning/pytorch-lightning/pull/7422))
|
||||
|
||||
|
||||
|
|
|
@ -20,8 +20,8 @@ from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union
|
|||
from torch.utils.data import DataLoader, Dataset, IterableDataset
|
||||
|
||||
from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks
|
||||
from pytorch_lightning.utilities import rank_zero_only
|
||||
from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types
|
||||
from pytorch_lightning.utilities.distributed import rank_zero_deprecation, rank_zero_only
|
||||
|
||||
|
||||
class LightningDataModule(CheckpointHooks, DataHooks):
|
||||
|
@ -160,7 +160,13 @@ class LightningDataModule(CheckpointHooks, DataHooks):
|
|||
|
||||
Returns:
|
||||
bool: True if ``datamodule.prepare_data()`` has been called. False by default.
|
||||
|
||||
.. deprecated:: v1.4
|
||||
Will be removed in v1.6.0.
|
||||
"""
|
||||
rank_zero_deprecation(
|
||||
'DataModule property `has_prepared_data` was deprecated in v1.4 and will be removed in v1.6.'
|
||||
)
|
||||
return self._has_prepared_data
|
||||
|
||||
@property
|
||||
|
@ -169,7 +175,11 @@ class LightningDataModule(CheckpointHooks, DataHooks):
|
|||
|
||||
Returns:
|
||||
bool: True ``if datamodule.setup(stage='fit')`` has been called. False by default.
|
||||
|
||||
.. deprecated:: v1.4
|
||||
Will be removed in v1.6.0.
|
||||
"""
|
||||
rank_zero_deprecation('DataModule property `has_setup_fit` was deprecated in v1.4 and will be removed in v1.6.')
|
||||
return self._has_setup_fit
|
||||
|
||||
@property
|
||||
|
@ -178,7 +188,13 @@ class LightningDataModule(CheckpointHooks, DataHooks):
|
|||
|
||||
Returns:
|
||||
bool: True if ``datamodule.setup(stage='validate')`` has been called. False by default.
|
||||
|
||||
.. deprecated:: v1.4
|
||||
Will be removed in v1.6.0.
|
||||
"""
|
||||
rank_zero_deprecation(
|
||||
'DataModule property `has_setup_validate` was deprecated in v1.4 and will be removed in v1.6.'
|
||||
)
|
||||
return self._has_setup_validate
|
||||
|
||||
@property
|
||||
|
@ -187,7 +203,13 @@ class LightningDataModule(CheckpointHooks, DataHooks):
|
|||
|
||||
Returns:
|
||||
bool: True if ``datamodule.setup(stage='test')`` has been called. False by default.
|
||||
|
||||
.. deprecated:: v1.4
|
||||
Will be removed in v1.6.0.
|
||||
"""
|
||||
rank_zero_deprecation(
|
||||
'DataModule property `has_setup_test` was deprecated in v1.4 and will be removed in v1.6.'
|
||||
)
|
||||
return self._has_setup_test
|
||||
|
||||
@property
|
||||
|
@ -196,7 +218,13 @@ class LightningDataModule(CheckpointHooks, DataHooks):
|
|||
|
||||
Returns:
|
||||
bool: True if ``datamodule.setup(stage='predict')`` has been called. False by default.
|
||||
|
||||
.. deprecated:: v1.4
|
||||
Will be removed in v1.6.0.
|
||||
"""
|
||||
rank_zero_deprecation(
|
||||
'DataModule property `has_setup_predict` was deprecated in v1.4 and will be removed in v1.6.'
|
||||
)
|
||||
return self._has_setup_predict
|
||||
|
||||
@property
|
||||
|
@ -205,7 +233,13 @@ class LightningDataModule(CheckpointHooks, DataHooks):
|
|||
|
||||
Returns:
|
||||
bool: True ``if datamodule.teardown(stage='fit')`` has been called. False by default.
|
||||
|
||||
.. deprecated:: v1.4
|
||||
Will be removed in v1.6.0.
|
||||
"""
|
||||
rank_zero_deprecation(
|
||||
'DataModule property `has_teardown_fit` was deprecated in v1.4 and will be removed in v1.6.'
|
||||
)
|
||||
return self._has_teardown_fit
|
||||
|
||||
@property
|
||||
|
@ -214,7 +248,13 @@ class LightningDataModule(CheckpointHooks, DataHooks):
|
|||
|
||||
Returns:
|
||||
bool: True if ``datamodule.teardown(stage='validate')`` has been called. False by default.
|
||||
|
||||
.. deprecated:: v1.4
|
||||
Will be removed in v1.6.0.
|
||||
"""
|
||||
rank_zero_deprecation(
|
||||
'DataModule property `has_teardown_validate` was deprecated in v1.4 and will be removed in v1.6.'
|
||||
)
|
||||
return self._has_teardown_validate
|
||||
|
||||
@property
|
||||
|
@ -223,7 +263,13 @@ class LightningDataModule(CheckpointHooks, DataHooks):
|
|||
|
||||
Returns:
|
||||
bool: True if ``datamodule.teardown(stage='test')`` has been called. False by default.
|
||||
|
||||
.. deprecated:: v1.4
|
||||
Will be removed in v1.6.0.
|
||||
"""
|
||||
rank_zero_deprecation(
|
||||
'DataModule property `has_teardown_test` was deprecated in v1.4 and will be removed in v1.6.'
|
||||
)
|
||||
return self._has_teardown_test
|
||||
|
||||
@property
|
||||
|
@ -232,7 +278,13 @@ class LightningDataModule(CheckpointHooks, DataHooks):
|
|||
|
||||
Returns:
|
||||
bool: True if ``datamodule.teardown(stage='predict')`` has been called. False by default.
|
||||
|
||||
.. deprecated:: v1.4
|
||||
Will be removed in v1.6.0.
|
||||
"""
|
||||
rank_zero_deprecation(
|
||||
'DataModule property `has_teardown_predict` was deprecated in v1.4 and will be removed in v1.6.'
|
||||
)
|
||||
return self._has_teardown_predict
|
||||
|
||||
@classmethod
|
||||
|
@ -381,8 +433,13 @@ class LightningDataModule(CheckpointHooks, DataHooks):
|
|||
has_run = obj._has_prepared_data
|
||||
obj._has_prepared_data = True
|
||||
|
||||
if not has_run:
|
||||
return fn(*args, **kwargs)
|
||||
if has_run:
|
||||
rank_zero_deprecation(
|
||||
f"DataModule.{name} has already been called, so it will not be called again. "
|
||||
f"In v1.6 this behavior will change to always call DataModule.{name}."
|
||||
)
|
||||
else:
|
||||
fn(*args, **kwargs)
|
||||
|
||||
return wrapped_fn
|
||||
|
||||
|
|
|
@ -524,46 +524,3 @@ def test_dm_init_from_datasets_dataloaders(iterable):
|
|||
call(test_dss[0], batch_size=4, shuffle=False, num_workers=0, pin_memory=True),
|
||||
call(test_dss[1], batch_size=4, shuffle=False, num_workers=0, pin_memory=True)
|
||||
])
|
||||
|
||||
|
||||
def test_datamodule_hooks_calls(tmpdir):
|
||||
"""Test that repeated calls to DataHooks' hooks have no effect"""
|
||||
|
||||
class TestDataModule(BoringDataModule):
|
||||
setup_calls = []
|
||||
teardown_calls = []
|
||||
prepare_data_calls = 0
|
||||
|
||||
def setup(self, stage=None):
|
||||
super().setup(stage=stage)
|
||||
self.setup_calls.append(stage)
|
||||
|
||||
def teardown(self, stage=None):
|
||||
super().teardown(stage=stage)
|
||||
self.teardown_calls.append(stage)
|
||||
|
||||
def prepare_data(self):
|
||||
super().prepare_data()
|
||||
self.prepare_data_calls += 1
|
||||
|
||||
dm = TestDataModule()
|
||||
dm.prepare_data()
|
||||
dm.prepare_data()
|
||||
dm.setup('fit')
|
||||
dm.setup('fit')
|
||||
dm.setup()
|
||||
dm.setup()
|
||||
dm.teardown('validate')
|
||||
dm.teardown('validate')
|
||||
|
||||
assert dm.prepare_data_calls == 1
|
||||
assert dm.setup_calls == ['fit', None]
|
||||
assert dm.teardown_calls == ['validate']
|
||||
|
||||
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
|
||||
trainer.test(BoringModel(), datamodule=dm)
|
||||
|
||||
# same number of calls
|
||||
assert dm.prepare_data_calls == 1
|
||||
assert dm.setup_calls == ['fit', None]
|
||||
assert dm.teardown_calls == ['validate', 'test']
|
||||
|
|
|
@ -16,7 +16,7 @@ import pytest
|
|||
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.plugins.training_type import DDPPlugin, DDPSpawnPlugin
|
||||
from tests.helpers import BoringModel
|
||||
from tests.helpers import BoringDataModule, BoringModel
|
||||
|
||||
|
||||
def test_v1_6_0_trainer_model_hook_mixin(tmpdir):
|
||||
|
@ -86,3 +86,76 @@ def test_v1_6_0_tbptt_pad_token(tmpdir):
|
|||
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
|
||||
with pytest.deprecated_call(match=r"tbptt_pad_token=...\)` is no longer supported"):
|
||||
trainer.fit(TestModel())
|
||||
|
||||
|
||||
def test_v1_6_0_datamodule_lifecycle_properties(tmpdir):
|
||||
dm = BoringDataModule()
|
||||
with pytest.deprecated_call(match=r"DataModule property `has_prepared_data` was deprecated in v1.4"):
|
||||
dm.has_prepared_data
|
||||
with pytest.deprecated_call(match=r"DataModule property `has_setup_fit` was deprecated in v1.4"):
|
||||
dm.has_setup_fit
|
||||
with pytest.deprecated_call(match=r"DataModule property `has_setup_validate` was deprecated in v1.4"):
|
||||
dm.has_setup_validate
|
||||
with pytest.deprecated_call(match=r"DataModule property `has_setup_test` was deprecated in v1.4"):
|
||||
dm.has_setup_test
|
||||
with pytest.deprecated_call(match=r"DataModule property `has_setup_predict` was deprecated in v1.4"):
|
||||
dm.has_setup_predict
|
||||
with pytest.deprecated_call(match=r"DataModule property `has_teardown_fit` was deprecated in v1.4"):
|
||||
dm.has_teardown_fit
|
||||
with pytest.deprecated_call(match=r"DataModule property `has_teardown_validate` was deprecated in v1.4"):
|
||||
dm.has_teardown_validate
|
||||
with pytest.deprecated_call(match=r"DataModule property `has_teardown_test` was deprecated in v1.4"):
|
||||
dm.has_teardown_test
|
||||
with pytest.deprecated_call(match=r"DataModule property `has_teardown_predict` was deprecated in v1.4"):
|
||||
dm.has_teardown_predict
|
||||
|
||||
|
||||
def test_v1_6_0_datamodule_hooks_calls(tmpdir):
|
||||
"""Test that repeated calls to DataHooks' hooks show a warning about the coming API change."""
|
||||
|
||||
class TestDataModule(BoringDataModule):
|
||||
setup_calls = []
|
||||
teardown_calls = []
|
||||
prepare_data_calls = 0
|
||||
|
||||
def setup(self, stage=None):
|
||||
super().setup(stage=stage)
|
||||
self.setup_calls.append(stage)
|
||||
|
||||
def teardown(self, stage=None):
|
||||
super().teardown(stage=stage)
|
||||
self.teardown_calls.append(stage)
|
||||
|
||||
def prepare_data(self):
|
||||
super().prepare_data()
|
||||
self.prepare_data_calls += 1
|
||||
|
||||
dm = TestDataModule()
|
||||
dm.prepare_data()
|
||||
dm.prepare_data()
|
||||
dm.setup('fit')
|
||||
with pytest.deprecated_call(
|
||||
match=r"DataModule.setup has already been called, so it will not be called again. "
|
||||
"In v1.6 this behavior will change to always call DataModule.setup"
|
||||
):
|
||||
dm.setup('fit')
|
||||
dm.setup()
|
||||
dm.setup()
|
||||
dm.teardown('validate')
|
||||
with pytest.deprecated_call(
|
||||
match=r"DataModule.teardown has already been called, so it will not be called again. "
|
||||
"In v1.6 this behavior will change to always call DataModule.teardown"
|
||||
):
|
||||
dm.teardown('validate')
|
||||
|
||||
assert dm.prepare_data_calls == 1
|
||||
assert dm.setup_calls == ['fit', None]
|
||||
assert dm.teardown_calls == ['validate']
|
||||
|
||||
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
|
||||
trainer.test(BoringModel(), datamodule=dm)
|
||||
|
||||
# same number of calls
|
||||
assert dm.prepare_data_calls == 1
|
||||
assert dm.setup_calls == ['fit', None]
|
||||
assert dm.teardown_calls == ['validate', 'test']
|
||||
|
|
Loading…
Reference in New Issue