diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index 770745b322..e2ba05eee3 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -184,7 +184,7 @@ class LightningDataModule(CheckpointHooks, DataHooks, HyperparametersMixin): rank_zero_deprecation("DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.") self._dims = d - def size(self, dim=None) -> Union[Tuple, int]: + def size(self, dim=None) -> Union[Tuple, List[Tuple]]: """ Return the dimension of each input either as a tuple or list of tuples. You can index this just as you would with a torch tensor.