Remove the deprecated LightningDataModule.size, LightningDataModule.dims (#12780)
This commit is contained in:
parent
0b22e51462
commit
af03f0a434
|
@ -100,6 +100,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Remove deprecated `pytorch_lightning.callbacks.progress.progress` ([#12658](https://github.com/PyTorchLightning/pytorch-lightning/pull/12658))
|
||||
|
||||
|
||||
- Removed the deprecated `dim` and `size` arguments from the `LightningDataModule` constructor([#12780](https://github.com/PyTorchLightning/pytorch-lightning/pull/12780))
|
||||
|
||||
|
||||
- Removed the deprecated `train_transforms` argument from the `LightningDataModule` constructor([#12662](https://github.com/PyTorchLightning/pytorch-lightning/pull/12662))
|
||||
|
||||
|
||||
|
|
|
@ -54,7 +54,7 @@ class LightningDataModule(CheckpointHooks, DataHooks, HyperparametersMixin):
|
|||
|
||||
name: str = ...
|
||||
|
||||
def __init__(self, val_transforms=None, test_transforms=None, dims=None):
|
||||
def __init__(self, val_transforms=None, test_transforms=None):
|
||||
super().__init__()
|
||||
if val_transforms is not None:
|
||||
rank_zero_deprecation(
|
||||
|
@ -64,11 +64,8 @@ class LightningDataModule(CheckpointHooks, DataHooks, HyperparametersMixin):
|
|||
rank_zero_deprecation(
|
||||
"DataModule property `test_transforms` was deprecated in v1.5 and will be removed in v1.7."
|
||||
)
|
||||
if dims is not None:
|
||||
rank_zero_deprecation("DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.")
|
||||
self._val_transforms = val_transforms
|
||||
self._test_transforms = test_transforms
|
||||
self._dims = dims if dims is not None else ()
|
||||
|
||||
# Pointer to the trainer object
|
||||
self.trainer = None
|
||||
|
@ -111,33 +108,6 @@ class LightningDataModule(CheckpointHooks, DataHooks, HyperparametersMixin):
|
|||
)
|
||||
self._test_transforms = t
|
||||
|
||||
@property
|
||||
def dims(self):
|
||||
"""A tuple describing the shape of your data. Extra functionality exposed in ``size``.
|
||||
|
||||
.. deprecated:: v1.5 Will be removed in v1.7.0.
|
||||
"""
|
||||
rank_zero_deprecation("DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.")
|
||||
return self._dims
|
||||
|
||||
@dims.setter
|
||||
def dims(self, d):
|
||||
rank_zero_deprecation("DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.")
|
||||
self._dims = d
|
||||
|
||||
def size(self, dim=None) -> Union[Tuple, List[Tuple]]:
|
||||
"""Return the dimension of each input either as a tuple or list of tuples. You can index this just as you
|
||||
would with a torch tensor.
|
||||
|
||||
.. deprecated:: v1.5 Will be removed in v1.7.0.
|
||||
"""
|
||||
rank_zero_deprecation("DataModule property `size` was deprecated in v1.5 and will be removed in v1.7.")
|
||||
|
||||
if dim is not None:
|
||||
return self.dims[dim]
|
||||
|
||||
return self.dims
|
||||
|
||||
@classmethod
|
||||
def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentParser:
|
||||
"""Extends existing argparse by default `LightningDataModule` attributes."""
|
||||
|
|
|
@ -49,22 +49,6 @@ def test_v1_7_0_datamodule_transform_properties(tmpdir):
|
|||
_ = LightningDataModule(val_transforms="b")
|
||||
with pytest.deprecated_call(match=r"DataModule property `test_transforms` was deprecated in v1.5"):
|
||||
_ = LightningDataModule(test_transforms="c")
|
||||
with pytest.deprecated_call(match=r"DataModule property `test_transforms` was deprecated in v1.5"):
|
||||
_ = LightningDataModule(test_transforms="c", dims=(1, 1, 1))
|
||||
|
||||
|
||||
def test_v1_7_0_datamodule_size_property(tmpdir):
|
||||
dm = MNISTDataModule()
|
||||
with pytest.deprecated_call(match=r"DataModule property `size` was deprecated in v1.5"):
|
||||
dm.size()
|
||||
|
||||
|
||||
def test_v1_7_0_datamodule_dims_property(tmpdir):
|
||||
dm = MNISTDataModule()
|
||||
with pytest.deprecated_call(match=r"DataModule property `dims` was deprecated in v1.5"):
|
||||
_ = dm.dims
|
||||
with pytest.deprecated_call(match=r"DataModule property `dims` was deprecated in v1.5"):
|
||||
_ = LightningDataModule(dims=(1, 1, 1))
|
||||
|
||||
|
||||
def test_v1_7_0_moved_get_progress_bar_dict(tmpdir):
|
||||
|
|
Loading…
Reference in New Issue