diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py
index f1e6f3e04e..660834b489 100644
--- a/pytorch_lightning/core/datamodule.py
+++ b/pytorch_lightning/core/datamodule.py
@@ -131,13 +131,13 @@ class LightningDataModule(DataHooks, metaclass=_DataModuleWrapper):
     name: str = ...
 
     def __init__(
-        self, train_transforms=None, val_transforms=None, test_transforms=None,
+        self, train_transforms=None, val_transforms=None, test_transforms=None, dims=None
     ):
         super().__init__()
         self._train_transforms = train_transforms
         self._val_transforms = val_transforms
         self._test_transforms = test_transforms
-        self.dims = ()
+        self._dims = dims if dims is not None else ()
 
         # Private attrs to keep track of whether or not data hooks have been called yet
         self._has_prepared_data = False
@@ -177,9 +177,21 @@ class LightningDataModule(DataHooks, metaclass=_DataModuleWrapper):
     def test_transforms(self, t):
         self._test_transforms = t
 
+    @property
+    def dims(self):
+        """
+        A tuple describing the shape of your data. Extra functionality exposed in ``size``.
+        """
+        return self._dims
+
+    @dims.setter
+    def dims(self, d):
+        self._dims = d
+
     def size(self, dim=None) -> Union[Tuple, int]:
         """
-        Return the dimension of each input either as a tuple or list of tuples.
+        Return the dimension of each input either as a tuple or list of tuples. You can index this
+        just as you would with a torch tensor.
         """
         if dim is not None:
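
For context, a minimal usage sketch of the API this patch introduces: the data shape can now be passed through the new ``dims`` keyword argument and read back via the ``dims`` property, with indexed access through ``size``. The ``MNISTDataModule`` name and the ``(1, 28, 28)`` shape below are illustrative assumptions, and the ``size`` outputs assume it behaves as its updated docstring describes.

from pytorch_lightning import LightningDataModule

class MNISTDataModule(LightningDataModule):
    """Hypothetical datamodule, used only to demonstrate the new ``dims`` kwarg."""

    def __init__(self):
        # The shape is passed at construction time instead of assigning
        # ``self.dims = (1, 28, 28)`` after ``super().__init__()``.
        super().__init__(dims=(1, 28, 28))

dm = MNISTDataModule()
print(dm.dims)     # (1, 28, 28) -- served by the new ``dims`` property
print(dm.size())   # (1, 28, 28) -- full shape when no index is given
print(dm.size(1))  # 28 -- indexes into ``dims`` like a torch tensor shape

dm.dims = (3, 32, 32)  # the ``dims.setter`` keeps plain assignment working

Keeping ``self.dims = ...`` assignment working through the setter means existing datamodules that set the attribute directly are unaffected by the move to a private ``_dims`` backing field.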