diff --git a/CHANGELOG.md b/CHANGELOG.md
index 40435d6650..89fcc8ace1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -187,9 +187,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for `torch.autograd.set_detect_anomaly` through `Trainer` constructor argument `detect_anomaly` ([#9848](https://github.com/PyTorchLightning/pytorch-lightning/pull/9848))
 
 
-- Added a `len` method to `LightningDataModule` ([#9895](https://github.com/PyTorchLightning/pytorch-lightning/pull/9895))
-
-
 - Added `enable_model_summary` flag to Trainer ([#9699](https://github.com/PyTorchLightning/pytorch-lightning/pull/9699))
 
 
diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py
index 56a97bd45f..f3a5c855fe 100644
--- a/pytorch_lightning/core/datamodule.py
+++ b/pytorch_lightning/core/datamodule.py
@@ -22,10 +22,7 @@ from torch.utils.data import DataLoader, Dataset, IterableDataset
 from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks
 from pytorch_lightning.core.mixins import HyperparametersMixin
 from pytorch_lightning.utilities import rank_zero_deprecation
-from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types
-from pytorch_lightning.utilities.data import has_len
-from pytorch_lightning.utilities.warnings import rank_zero_warn
 
 
 class LightningDataModule(CheckpointHooks, DataHooks, HyperparametersMixin):
@@ -484,40 +481,3 @@ class LightningDataModule(CheckpointHooks, DataHooks, HyperparametersMixin):
         for fn in ("prepare_data", "setup", "teardown"):
             del d[fn]
         return d
-
-    def __len__(self) -> int:
-        """Returns the total number of batches in all dataloaders defined in the datamodule."""
-
-        from pytorch_lightning.trainer.supporters import CombinedLoader
-
-        num_batches = 0
-        not_implemented_count = 0
-
-        def get_num_batches(dataloader: DataLoader, name: str) -> None:
-            nonlocal num_batches
-            if not has_len(dataloader):
-                rank_zero_warn(
-                    f"The number of batches for a dataloader in `{name}` is counted as 0 "
-                    "because it does not have `__len__` defined."
-                )
-            else:
-                num_batches += len(dataloader)
-
-        for method_name in ("train_dataloader", "val_dataloader", "test_dataloader", "predict_dataloader"):
-            dataloader_method = getattr(self, method_name)
-            if not callable(dataloader_method):
-                not_implemented_count += 1
-                continue
-            try:
-                dataloader = dataloader_method()
-            except NotImplementedError:
-                not_implemented_count += 1
-                continue
-            if isinstance(dataloader, CombinedLoader):
-                dataloader = dataloader.loaders
-            apply_to_collection(dataloader, DataLoader, get_num_batches, method_name)
-
-        if not_implemented_count == 4:
-            rank_zero_warn("You datamodule does not have any valid dataloader so `__len__` will be returned as 0.")
-
-        return num_batches
diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py
index 868c13bcc1..539767ac2d 100644
--- a/tests/core/test_datamodules.py
+++ b/tests/core/test_datamodules.py
@@ -21,15 +21,13 @@ from unittest.mock import call, PropertyMock
 import pytest
 import torch
 from omegaconf import OmegaConf
-from torch.utils.data import DataLoader
 
 from pytorch_lightning import LightningDataModule, Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
-from pytorch_lightning.trainer.supporters import CombinedLoader
 from pytorch_lightning.utilities import AttributeDict
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden
-from tests.helpers import BoringDataModule, BoringModel, RandomDataset
+from tests.helpers import BoringDataModule, BoringModel
 from tests.helpers.datamodules import ClassifDataModule
 from tests.helpers.runif import RunIf
 from tests.helpers.simple_models import ClassificationModel
@@ -566,14 +564,13 @@ def test_define_as_dataclass():
         batch_size: int
         dims: int = 2
 
-        def train_dataloader(self):
-            return DataLoader(torch.randn(self.batch_size * 2, 10), batch_size=self.batch_size)
+        def __post_init__(self):
+            super().__init__(dims=self.dims)
 
     # asserts for the different dunder methods added by dataclass, when __init__ is implemented, i.e.
     # __repr__, __eq__, __lt__, __le__, etc.
     assert BoringDataModule1(batch_size=64).dims == 2
     assert BoringDataModule1(batch_size=32)
-    assert len(BoringDataModule1(batch_size=32)) == 2
     assert hasattr(BoringDataModule1, "__repr__")
     assert BoringDataModule1(batch_size=32) == BoringDataModule1(batch_size=32)
 
@@ -584,9 +581,7 @@ def test_define_as_dataclass():
 
     # asserts for the different dunder methods added by dataclass, when super class is inherently initialized, i.e.
     # __init__, __repr__, __eq__, __lt__, __le__, etc.
-    assert BoringDataModule2(batch_size=32) is not None
-    assert BoringDataModule2(batch_size=32).batch_size == 32
-    assert len(BoringDataModule2(batch_size=32)) == 0
+    assert BoringDataModule2(batch_size=32)
     assert hasattr(BoringDataModule2, "__repr__")
     assert BoringDataModule2(batch_size=32).prepare_data() is None
     assert BoringDataModule2(batch_size=32) == BoringDataModule2(batch_size=32)
@@ -630,76 +625,3 @@ def test_inconsistent_prepare_data_per_node(tmpdir):
     trainer.model = model
     trainer.datamodule = dm
     trainer._data_connector.prepare_data()
-
-
-DATALOADER = DataLoader(RandomDataset(1, 32))
-
-
-@pytest.mark.parametrize("method_name", ["train_dataloader", "val_dataloader", "test_dataloader", "predict_dataloader"])
-@pytest.mark.parametrize(
-    ["dataloader", "expected"],
-    [
-        [DATALOADER, 32],
-        [[DATALOADER, DATALOADER], 64],
-        [[[DATALOADER], [DATALOADER, DATALOADER]], 96],
-        [[{"foo": DATALOADER}, {"foo": DATALOADER, "bar": DATALOADER}], 96],
-        [{"foo": DATALOADER, "bar": DATALOADER}, 64],
-        [{"foo": {"foo": DATALOADER}, "bar": {"foo": DATALOADER, "bar": DATALOADER}}, 96],
-        [{"foo": [DATALOADER], "bar": [DATALOADER, DATALOADER]}, 96],
-        [CombinedLoader({"foo": DATALOADER, "bar": DATALOADER}), 64],
-    ],
-)
-def test_len_different_types(method_name, dataloader, expected):
-    dm = LightningDataModule()
-    setattr(dm, method_name, lambda: dataloader)
-    assert len(dm) == expected
-
-
-@pytest.mark.parametrize("method_name", ["train_dataloader", "val_dataloader", "test_dataloader", "predict_dataloader"])
-def test_len_dataloader_no_len(method_name):
-    class CustomNotImplementedErrorDataloader(DataLoader):
-        def __len__(self):
-            raise NotImplementedError
-
-    dataloader = CustomNotImplementedErrorDataloader(RandomDataset(1, 32))
-    dm = LightningDataModule()
-    setattr(dm, method_name, lambda: dataloader)
-    with pytest.warns(UserWarning, match=f"The number of batches for a dataloader in `{method_name}` is counted as 0"):
-        assert len(dm) == 0
-
-
-def test_len_all_dataloader_methods_implemented():
-    class BoringDataModule(LightningDataModule):
-        def __init__(self, dataloader):
-            super().__init__()
-            self.dataloader = dataloader
-
-        def train_dataloader(self):
-            return {"foo": self.dataloader, "bar": self.dataloader}
-
-        def val_dataloader(self):
-            return self.dataloader
-
-        def test_dataloader(self):
-            return [self.dataloader]
-
-        def predict_dataloader(self):
-            return [self.dataloader, self.dataloader]
-
-    dm = BoringDataModule(DATALOADER)
-
-    # 6 dataloaders each producing 32 batches: 6 * 32 = 192
-    assert len(dm) == 192
-
-
-def test_len_no_dataloader_methods_implemented():
-    dm = LightningDataModule()
-    with pytest.warns(UserWarning, match="You datamodule does not have any valid dataloader"):
-        assert len(dm) == 0
-
-    dm.train_dataloader = None
-    dm.val_dataloader = None
-    dm.test_dataloader = None
-    dm.predict_dataloader = None
-    with pytest.warns(UserWarning, match="You datamodule does not have any valid dataloader"):
-        assert len(dm) == 0
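
Note: downstream code that relied on the removed `len(datamodule)` can reproduce the same count in user code with the utilities the deleted method imported (`has_len`, `apply_to_collection`, `CombinedLoader`). A minimal sketch, assuming those utilities keep the import paths shown in the diff above; the helper name `count_batches` is hypothetical, not part of the library API:

from torch.utils.data import DataLoader

from pytorch_lightning import LightningDataModule
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.data import has_len


def count_batches(dm: LightningDataModule) -> int:
    """Total number of batches across all dataloaders defined on the datamodule."""
    num_batches = 0

    def add_len(dataloader: DataLoader) -> None:
        nonlocal num_batches
        # Loaders without a usable `__len__` (e.g. over an `IterableDataset`)
        # contribute 0, matching the removed method's behaviour.
        if has_len(dataloader):
            num_batches += len(dataloader)

    for name in ("train_dataloader", "val_dataloader", "test_dataloader", "predict_dataloader"):
        method = getattr(dm, name, None)
        if not callable(method):
            continue
        try:
            loaders = method()
        except NotImplementedError:
            continue
        if isinstance(loaders, CombinedLoader):
            loaders = loaders.loaders
        # Walk arbitrarily nested lists/dicts of dataloaders.
        apply_to_collection(loaders, DataLoader, add_len)

    return num_batches

For the `BoringDataModule` defined in the removed `test_len_all_dataloader_methods_implemented` test, this helper returns 6 * 32 = 192, the same total the deleted `__len__` reported.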