From af03f0a43462b637be8749b7678dc2c7c3c40a1d Mon Sep 17 00:00:00 2001 From: Minh Chien Vu <31467068+vumichien@users.noreply.github.com> Date: Tue, 19 Apr 2022 04:49:05 +0900 Subject: [PATCH] Remove the deprecated LightningDataModule.size, LightningDataModule.dims (#12780) --- CHANGELOG.md | 3 +++ pytorch_lightning/core/datamodule.py | 32 +------------------------ tests/deprecated_api/test_remove_1-7.py | 16 ------------- 3 files changed, 4 insertions(+), 47 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4f6a656aa1..3394388a6e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -100,6 +100,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Remove deprecated `pytorch_lightning.callbacks.progress.progress` ([#12658](https://github.com/PyTorchLightning/pytorch-lightning/pull/12658)) +- Removed the deprecated `dim` and `size` arguments from the `LightningDataModule` constructor([#12780](https://github.com/PyTorchLightning/pytorch-lightning/pull/12780)) + + - Removed the deprecated `train_transforms` argument from the `LightningDataModule` constructor([#12662](https://github.com/PyTorchLightning/pytorch-lightning/pull/12662)) diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index b3198a64cf..185eee6a9f 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -54,7 +54,7 @@ class LightningDataModule(CheckpointHooks, DataHooks, HyperparametersMixin): name: str = ... - def __init__(self, val_transforms=None, test_transforms=None, dims=None): + def __init__(self, val_transforms=None, test_transforms=None): super().__init__() if val_transforms is not None: rank_zero_deprecation( @@ -64,11 +64,8 @@ class LightningDataModule(CheckpointHooks, DataHooks, HyperparametersMixin): rank_zero_deprecation( "DataModule property `test_transforms` was deprecated in v1.5 and will be removed in v1.7." ) - if dims is not None: - rank_zero_deprecation("DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.") self._val_transforms = val_transforms self._test_transforms = test_transforms - self._dims = dims if dims is not None else () # Pointer to the trainer object self.trainer = None @@ -111,33 +108,6 @@ class LightningDataModule(CheckpointHooks, DataHooks, HyperparametersMixin): ) self._test_transforms = t - @property - def dims(self): - """A tuple describing the shape of your data. Extra functionality exposed in ``size``. - - .. deprecated:: v1.5 Will be removed in v1.7.0. - """ - rank_zero_deprecation("DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.") - return self._dims - - @dims.setter - def dims(self, d): - rank_zero_deprecation("DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.") - self._dims = d - - def size(self, dim=None) -> Union[Tuple, List[Tuple]]: - """Return the dimension of each input either as a tuple or list of tuples. You can index this just as you - would with a torch tensor. - - .. deprecated:: v1.5 Will be removed in v1.7.0. - """ - rank_zero_deprecation("DataModule property `size` was deprecated in v1.5 and will be removed in v1.7.") - - if dim is not None: - return self.dims[dim] - - return self.dims - @classmethod def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentParser: """Extends existing argparse by default `LightningDataModule` attributes.""" diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py index 88a8708ffe..d9ec403a7b 100644 --- a/tests/deprecated_api/test_remove_1-7.py +++ b/tests/deprecated_api/test_remove_1-7.py @@ -49,22 +49,6 @@ def test_v1_7_0_datamodule_transform_properties(tmpdir): _ = LightningDataModule(val_transforms="b") with pytest.deprecated_call(match=r"DataModule property `test_transforms` was deprecated in v1.5"): _ = LightningDataModule(test_transforms="c") - with pytest.deprecated_call(match=r"DataModule property `test_transforms` was deprecated in v1.5"): - _ = LightningDataModule(test_transforms="c", dims=(1, 1, 1)) - - -def test_v1_7_0_datamodule_size_property(tmpdir): - dm = MNISTDataModule() - with pytest.deprecated_call(match=r"DataModule property `size` was deprecated in v1.5"): - dm.size() - - -def test_v1_7_0_datamodule_dims_property(tmpdir): - dm = MNISTDataModule() - with pytest.deprecated_call(match=r"DataModule property `dims` was deprecated in v1.5"): - _ = dm.dims - with pytest.deprecated_call(match=r"DataModule property `dims` was deprecated in v1.5"): - _ = LightningDataModule(dims=(1, 1, 1)) def test_v1_7_0_moved_get_progress_bar_dict(tmpdir):