Revert "Add support for `len(datamodule)` (#9895)" (#10072)

This reverts commit 6429de8944.
This commit is contained in:
Ning 2021-10-29 04:33:51 -07:00 committed by GitHub
parent 6a9adf26f7
commit dbfadedfe7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 4 additions and 125 deletions

View File

@ -187,9 +187,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for `torch.autograd.set_detect_anomaly` through `Trainer` constructor argument `detect_anomaly` ([#9848](https://github.com/PyTorchLightning/pytorch-lightning/pull/9848))
- Added a `len` method to `LightningDataModule` ([#9895](https://github.com/PyTorchLightning/pytorch-lightning/pull/9895))
- Added `enable_model_summary` flag to Trainer ([#9699](https://github.com/PyTorchLightning/pytorch-lightning/pull/9699))

View File

@ -22,10 +22,7 @@ from torch.utils.data import DataLoader, Dataset, IterableDataset
from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks
from pytorch_lightning.core.mixins import HyperparametersMixin
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types
from pytorch_lightning.utilities.data import has_len
from pytorch_lightning.utilities.warnings import rank_zero_warn
class LightningDataModule(CheckpointHooks, DataHooks, HyperparametersMixin):
@ -484,40 +481,3 @@ class LightningDataModule(CheckpointHooks, DataHooks, HyperparametersMixin):
for fn in ("prepare_data", "setup", "teardown"):
del d[fn]
return d
def __len__(self) -> int:
"""Returns the total number of batches in all dataloaders defined in the datamodule."""
from pytorch_lightning.trainer.supporters import CombinedLoader
num_batches = 0
not_implemented_count = 0
def get_num_batches(dataloader: DataLoader, name: str) -> None:
nonlocal num_batches
if not has_len(dataloader):
rank_zero_warn(
f"The number of batches for a dataloader in `{name}` is counted as 0 "
"because it does not have `__len__` defined."
)
else:
num_batches += len(dataloader)
for method_name in ("train_dataloader", "val_dataloader", "test_dataloader", "predict_dataloader"):
dataloader_method = getattr(self, method_name)
if not callable(dataloader_method):
not_implemented_count += 1
continue
try:
dataloader = dataloader_method()
except NotImplementedError:
not_implemented_count += 1
continue
if isinstance(dataloader, CombinedLoader):
dataloader = dataloader.loaders
apply_to_collection(dataloader, DataLoader, get_num_batches, method_name)
if not_implemented_count == 4:
rank_zero_warn("You datamodule does not have any valid dataloader so `__len__` will be returned as 0.")
return num_batches

View File

@ -21,15 +21,13 @@ from unittest.mock import call, PropertyMock
import pytest
import torch
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from pytorch_lightning import LightningDataModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities import AttributeDict
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from tests.helpers import BoringDataModule, BoringModel, RandomDataset
from tests.helpers import BoringDataModule, BoringModel
from tests.helpers.datamodules import ClassifDataModule
from tests.helpers.runif import RunIf
from tests.helpers.simple_models import ClassificationModel
@ -566,14 +564,13 @@ def test_define_as_dataclass():
batch_size: int
dims: int = 2
def train_dataloader(self):
return DataLoader(torch.randn(self.batch_size * 2, 10), batch_size=self.batch_size)
def __post_init__(self):
super().__init__(dims=self.dims)
# asserts for the different dunder methods added by dataclass, when __init__ is implemented, i.e.
# __repr__, __eq__, __lt__, __le__, etc.
assert BoringDataModule1(batch_size=64).dims == 2
assert BoringDataModule1(batch_size=32)
assert len(BoringDataModule1(batch_size=32)) == 2
assert hasattr(BoringDataModule1, "__repr__")
assert BoringDataModule1(batch_size=32) == BoringDataModule1(batch_size=32)
@ -584,9 +581,7 @@ def test_define_as_dataclass():
# asserts for the different dunder methods added by dataclass, when super class is inherently initialized, i.e.
# __init__, __repr__, __eq__, __lt__, __le__, etc.
assert BoringDataModule2(batch_size=32) is not None
assert BoringDataModule2(batch_size=32).batch_size == 32
assert len(BoringDataModule2(batch_size=32)) == 0
assert BoringDataModule2(batch_size=32)
assert hasattr(BoringDataModule2, "__repr__")
assert BoringDataModule2(batch_size=32).prepare_data() is None
assert BoringDataModule2(batch_size=32) == BoringDataModule2(batch_size=32)
@ -630,76 +625,3 @@ def test_inconsistent_prepare_data_per_node(tmpdir):
trainer.model = model
trainer.datamodule = dm
trainer._data_connector.prepare_data()
DATALOADER = DataLoader(RandomDataset(1, 32))
@pytest.mark.parametrize("method_name", ["train_dataloader", "val_dataloader", "test_dataloader", "predict_dataloader"])
@pytest.mark.parametrize(
["dataloader", "expected"],
[
[DATALOADER, 32],
[[DATALOADER, DATALOADER], 64],
[[[DATALOADER], [DATALOADER, DATALOADER]], 96],
[[{"foo": DATALOADER}, {"foo": DATALOADER, "bar": DATALOADER}], 96],
[{"foo": DATALOADER, "bar": DATALOADER}, 64],
[{"foo": {"foo": DATALOADER}, "bar": {"foo": DATALOADER, "bar": DATALOADER}}, 96],
[{"foo": [DATALOADER], "bar": [DATALOADER, DATALOADER]}, 96],
[CombinedLoader({"foo": DATALOADER, "bar": DATALOADER}), 64],
],
)
def test_len_different_types(method_name, dataloader, expected):
dm = LightningDataModule()
setattr(dm, method_name, lambda: dataloader)
assert len(dm) == expected
@pytest.mark.parametrize("method_name", ["train_dataloader", "val_dataloader", "test_dataloader", "predict_dataloader"])
def test_len_dataloader_no_len(method_name):
class CustomNotImplementedErrorDataloader(DataLoader):
def __len__(self):
raise NotImplementedError
dataloader = CustomNotImplementedErrorDataloader(RandomDataset(1, 32))
dm = LightningDataModule()
setattr(dm, method_name, lambda: dataloader)
with pytest.warns(UserWarning, match=f"The number of batches for a dataloader in `{method_name}` is counted as 0"):
assert len(dm) == 0
def test_len_all_dataloader_methods_implemented():
class BoringDataModule(LightningDataModule):
def __init__(self, dataloader):
super().__init__()
self.dataloader = dataloader
def train_dataloader(self):
return {"foo": self.dataloader, "bar": self.dataloader}
def val_dataloader(self):
return self.dataloader
def test_dataloader(self):
return [self.dataloader]
def predict_dataloader(self):
return [self.dataloader, self.dataloader]
dm = BoringDataModule(DATALOADER)
# 6 dataloaders each producing 32 batches: 6 * 32 = 192
assert len(dm) == 192
def test_len_no_dataloader_methods_implemented():
dm = LightningDataModule()
with pytest.warns(UserWarning, match="You datamodule does not have any valid dataloader"):
assert len(dm) == 0
dm.train_dataloader = None
dm.val_dataloader = None
dm.test_dataloader = None
dm.predict_dataloader = None
with pytest.warns(UserWarning, match="You datamodule does not have any valid dataloader"):
assert len(dm) == 0