Deprecate `LightningDataModule` lifecycle properties (#7657)

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
This commit is contained in:
ananthsub 2021-06-09 05:30:40 -07:00 committed by GitHub
parent 764d2c775e
commit 6fee9262ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 137 additions and 47 deletions

View File

@ -155,6 +155,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Deprecated
- Deprecated `DataModule` properties: `has_prepared_data`, `has_setup_fit`, `has_setup_validate`, `has_setup_test`, `has_setup_predict`, `has_teardown_fit`, `has_teardown_validate`, `has_teardown_test`, `has_teardown_predict` ([#7657](https://github.com/PyTorchLightning/pytorch-lightning/pull/7657/))
- Deprecated `TrainerModelHooksMixin` in favor of `pytorch_lightning.utilities.signature_utils` ([#7422](https://github.com/PyTorchLightning/pytorch-lightning/pull/7422))

View File

@ -20,8 +20,8 @@ from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union
from torch.utils.data import DataLoader, Dataset, IterableDataset
from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types
from pytorch_lightning.utilities.distributed import rank_zero_deprecation, rank_zero_only
class LightningDataModule(CheckpointHooks, DataHooks):
@ -160,7 +160,13 @@ class LightningDataModule(CheckpointHooks, DataHooks):
Returns:
bool: True if ``datamodule.prepare_data()`` has been called. False by default.
.. deprecated:: v1.4
Will be removed in v1.6.0.
"""
rank_zero_deprecation(
'DataModule property `has_prepared_data` was deprecated in v1.4 and will be removed in v1.6.'
)
return self._has_prepared_data
@property
@ -169,7 +175,11 @@ class LightningDataModule(CheckpointHooks, DataHooks):
Returns:
bool: True ``if datamodule.setup(stage='fit')`` has been called. False by default.
.. deprecated:: v1.4
Will be removed in v1.6.0.
"""
rank_zero_deprecation('DataModule property `has_setup_fit` was deprecated in v1.4 and will be removed in v1.6.')
return self._has_setup_fit
@property
@ -178,7 +188,13 @@ class LightningDataModule(CheckpointHooks, DataHooks):
Returns:
bool: True if ``datamodule.setup(stage='validate')`` has been called. False by default.
.. deprecated:: v1.4
Will be removed in v1.6.0.
"""
rank_zero_deprecation(
'DataModule property `has_setup_validate` was deprecated in v1.4 and will be removed in v1.6.'
)
return self._has_setup_validate
@property
@ -187,7 +203,13 @@ class LightningDataModule(CheckpointHooks, DataHooks):
Returns:
bool: True if ``datamodule.setup(stage='test')`` has been called. False by default.
.. deprecated:: v1.4
Will be removed in v1.6.0.
"""
rank_zero_deprecation(
'DataModule property `has_setup_test` was deprecated in v1.4 and will be removed in v1.6.'
)
return self._has_setup_test
@property
@ -196,7 +218,13 @@ class LightningDataModule(CheckpointHooks, DataHooks):
Returns:
bool: True if ``datamodule.setup(stage='predict')`` has been called. False by default.
.. deprecated:: v1.4
Will be removed in v1.6.0.
"""
rank_zero_deprecation(
'DataModule property `has_setup_predict` was deprecated in v1.4 and will be removed in v1.6.'
)
return self._has_setup_predict
@property
@ -205,7 +233,13 @@ class LightningDataModule(CheckpointHooks, DataHooks):
Returns:
bool: True ``if datamodule.teardown(stage='fit')`` has been called. False by default.
.. deprecated:: v1.4
Will be removed in v1.6.0.
"""
rank_zero_deprecation(
'DataModule property `has_teardown_fit` was deprecated in v1.4 and will be removed in v1.6.'
)
return self._has_teardown_fit
@property
@ -214,7 +248,13 @@ class LightningDataModule(CheckpointHooks, DataHooks):
Returns:
bool: True if ``datamodule.teardown(stage='validate')`` has been called. False by default.
.. deprecated:: v1.4
Will be removed in v1.6.0.
"""
rank_zero_deprecation(
'DataModule property `has_teardown_validate` was deprecated in v1.4 and will be removed in v1.6.'
)
return self._has_teardown_validate
@property
@ -223,7 +263,13 @@ class LightningDataModule(CheckpointHooks, DataHooks):
Returns:
bool: True if ``datamodule.teardown(stage='test')`` has been called. False by default.
.. deprecated:: v1.4
Will be removed in v1.6.0.
"""
rank_zero_deprecation(
'DataModule property `has_teardown_test` was deprecated in v1.4 and will be removed in v1.6.'
)
return self._has_teardown_test
@property
@ -232,7 +278,13 @@ class LightningDataModule(CheckpointHooks, DataHooks):
Returns:
bool: True if ``datamodule.teardown(stage='predict')`` has been called. False by default.
.. deprecated:: v1.4
Will be removed in v1.6.0.
"""
rank_zero_deprecation(
'DataModule property `has_teardown_predict` was deprecated in v1.4 and will be removed in v1.6.'
)
return self._has_teardown_predict
@classmethod
@ -381,8 +433,13 @@ class LightningDataModule(CheckpointHooks, DataHooks):
has_run = obj._has_prepared_data
obj._has_prepared_data = True
if not has_run:
return fn(*args, **kwargs)
if has_run:
rank_zero_deprecation(
f"DataModule.{name} has already been called, so it will not be called again. "
f"In v1.6 this behavior will change to always call DataModule.{name}."
)
else:
fn(*args, **kwargs)
return wrapped_fn

View File

@ -524,46 +524,3 @@ def test_dm_init_from_datasets_dataloaders(iterable):
call(test_dss[0], batch_size=4, shuffle=False, num_workers=0, pin_memory=True),
call(test_dss[1], batch_size=4, shuffle=False, num_workers=0, pin_memory=True)
])
def test_datamodule_hooks_calls(tmpdir):
"""Test that repeated calls to DataHooks' hooks have no effect"""
class TestDataModule(BoringDataModule):
setup_calls = []
teardown_calls = []
prepare_data_calls = 0
def setup(self, stage=None):
super().setup(stage=stage)
self.setup_calls.append(stage)
def teardown(self, stage=None):
super().teardown(stage=stage)
self.teardown_calls.append(stage)
def prepare_data(self):
super().prepare_data()
self.prepare_data_calls += 1
dm = TestDataModule()
dm.prepare_data()
dm.prepare_data()
dm.setup('fit')
dm.setup('fit')
dm.setup()
dm.setup()
dm.teardown('validate')
dm.teardown('validate')
assert dm.prepare_data_calls == 1
assert dm.setup_calls == ['fit', None]
assert dm.teardown_calls == ['validate']
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
trainer.test(BoringModel(), datamodule=dm)
# same number of calls
assert dm.prepare_data_calls == 1
assert dm.setup_calls == ['fit', None]
assert dm.teardown_calls == ['validate', 'test']

View File

@ -16,7 +16,7 @@ import pytest
from pytorch_lightning import Trainer
from pytorch_lightning.plugins.training_type import DDPPlugin, DDPSpawnPlugin
from tests.helpers import BoringModel
from tests.helpers import BoringDataModule, BoringModel
def test_v1_6_0_trainer_model_hook_mixin(tmpdir):
@ -86,3 +86,76 @@ def test_v1_6_0_tbptt_pad_token(tmpdir):
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
with pytest.deprecated_call(match=r"tbptt_pad_token=...\)` is no longer supported"):
trainer.fit(TestModel())
def test_v1_6_0_datamodule_lifecycle_properties(tmpdir):
dm = BoringDataModule()
with pytest.deprecated_call(match=r"DataModule property `has_prepared_data` was deprecated in v1.4"):
dm.has_prepared_data
with pytest.deprecated_call(match=r"DataModule property `has_setup_fit` was deprecated in v1.4"):
dm.has_setup_fit
with pytest.deprecated_call(match=r"DataModule property `has_setup_validate` was deprecated in v1.4"):
dm.has_setup_validate
with pytest.deprecated_call(match=r"DataModule property `has_setup_test` was deprecated in v1.4"):
dm.has_setup_test
with pytest.deprecated_call(match=r"DataModule property `has_setup_predict` was deprecated in v1.4"):
dm.has_setup_predict
with pytest.deprecated_call(match=r"DataModule property `has_teardown_fit` was deprecated in v1.4"):
dm.has_teardown_fit
with pytest.deprecated_call(match=r"DataModule property `has_teardown_validate` was deprecated in v1.4"):
dm.has_teardown_validate
with pytest.deprecated_call(match=r"DataModule property `has_teardown_test` was deprecated in v1.4"):
dm.has_teardown_test
with pytest.deprecated_call(match=r"DataModule property `has_teardown_predict` was deprecated in v1.4"):
dm.has_teardown_predict
def test_v1_6_0_datamodule_hooks_calls(tmpdir):
"""Test that repeated calls to DataHooks' hooks show a warning about the coming API change."""
class TestDataModule(BoringDataModule):
setup_calls = []
teardown_calls = []
prepare_data_calls = 0
def setup(self, stage=None):
super().setup(stage=stage)
self.setup_calls.append(stage)
def teardown(self, stage=None):
super().teardown(stage=stage)
self.teardown_calls.append(stage)
def prepare_data(self):
super().prepare_data()
self.prepare_data_calls += 1
dm = TestDataModule()
dm.prepare_data()
dm.prepare_data()
dm.setup('fit')
with pytest.deprecated_call(
match=r"DataModule.setup has already been called, so it will not be called again. "
"In v1.6 this behavior will change to always call DataModule.setup"
):
dm.setup('fit')
dm.setup()
dm.setup()
dm.teardown('validate')
with pytest.deprecated_call(
match=r"DataModule.teardown has already been called, so it will not be called again. "
"In v1.6 this behavior will change to always call DataModule.teardown"
):
dm.teardown('validate')
assert dm.prepare_data_calls == 1
assert dm.setup_calls == ['fit', None]
assert dm.teardown_calls == ['validate']
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
trainer.test(BoringModel(), datamodule=dm)
# same number of calls
assert dm.prepare_data_calls == 1
assert dm.setup_calls == ['fit', None]
assert dm.teardown_calls == ['validate', 'test']