Make dims a property in datamodule (#3547)

* 🐛 make dims a property

* 🐛 fix
This commit is contained in:
Nathan Raw 2020-09-18 15:30:49 -06:00 committed by GitHub
parent 07006238ea
commit 00ab67edc3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 15 additions and 3 deletions

View File

@ -131,13 +131,13 @@ class LightningDataModule(DataHooks, metaclass=_DataModuleWrapper):
name: str = ...
def __init__(
self, train_transforms=None, val_transforms=None, test_transforms=None,
self, train_transforms=None, val_transforms=None, test_transforms=None, dims=None
):
super().__init__()
self._train_transforms = train_transforms
self._val_transforms = val_transforms
self._test_transforms = test_transforms
self.dims = ()
self._dims = dims if dims is not None else ()
# Private attrs to keep track of whether or not data hooks have been called yet
self._has_prepared_data = False
@ -177,9 +177,21 @@ class LightningDataModule(DataHooks, metaclass=_DataModuleWrapper):
def test_transforms(self, t):
self._test_transforms = t
@property
def dims(self):
"""
A tuple describing the shape of your data. Extra functionality exposed in ``size``.
"""
return self._dims
@dims.setter
def dims(self, d):
self._dims = d
def size(self, dim=None) -> Union[Tuple, int]:
"""
Return the dimension of each input either as a tuple or list of tuples.
Return the dimension of each input either as a tuple or list of tuples. You can index this
just as you would with a torch tensor.
"""
if dim is not None: