Make dims a property in datamodule (#3547)
* 🐛 make dims a property * 🐛 fix
This commit is contained in:
parent
07006238ea
commit
00ab67edc3
|
@ -131,13 +131,13 @@ class LightningDataModule(DataHooks, metaclass=_DataModuleWrapper):
|
|||
name: str = ...
|
||||
|
||||
def __init__(
|
||||
self, train_transforms=None, val_transforms=None, test_transforms=None,
|
||||
self, train_transforms=None, val_transforms=None, test_transforms=None, dims=None
|
||||
):
|
||||
super().__init__()
|
||||
self._train_transforms = train_transforms
|
||||
self._val_transforms = val_transforms
|
||||
self._test_transforms = test_transforms
|
||||
self.dims = ()
|
||||
self._dims = dims if dims is not None else ()
|
||||
|
||||
# Private attrs to keep track of whether or not data hooks have been called yet
|
||||
self._has_prepared_data = False
|
||||
|
@ -177,9 +177,21 @@ class LightningDataModule(DataHooks, metaclass=_DataModuleWrapper):
|
|||
def test_transforms(self, t):
|
||||
self._test_transforms = t
|
||||
|
||||
@property
|
||||
def dims(self):
|
||||
"""
|
||||
A tuple describing the shape of your data. Extra functionality exposed in ``size``.
|
||||
"""
|
||||
return self._dims
|
||||
|
||||
@dims.setter
|
||||
def dims(self, d):
|
||||
self._dims = d
|
||||
|
||||
def size(self, dim=None) -> Union[Tuple, int]:
|
||||
"""
|
||||
Return the dimension of each input either as a tuple or list of tuples.
|
||||
Return the dimension of each input either as a tuple or list of tuples. You can index this
|
||||
just as you would with a torch tensor.
|
||||
"""
|
||||
|
||||
if dim is not None:
|
||||
|
|
Loading…
Reference in New Issue