This reverts commit 6429de8944
.
This commit is contained in:
parent
6a9adf26f7
commit
dbfadedfe7
|
@ -187,9 +187,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Added support for `torch.autograd.set_detect_anomaly` through `Trainer` constructor argument `detect_anomaly` ([#9848](https://github.com/PyTorchLightning/pytorch-lightning/pull/9848))
|
||||
|
||||
|
||||
- Added a `len` method to `LightningDataModule` ([#9895](https://github.com/PyTorchLightning/pytorch-lightning/pull/9895))
|
||||
|
||||
|
||||
- Added `enable_model_summary` flag to Trainer ([#9699](https://github.com/PyTorchLightning/pytorch-lightning/pull/9699))
|
||||
|
||||
|
||||
|
|
|
@ -22,10 +22,7 @@ from torch.utils.data import DataLoader, Dataset, IterableDataset
|
|||
from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks
|
||||
from pytorch_lightning.core.mixins import HyperparametersMixin
|
||||
from pytorch_lightning.utilities import rank_zero_deprecation
|
||||
from pytorch_lightning.utilities.apply_func import apply_to_collection
|
||||
from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types
|
||||
from pytorch_lightning.utilities.data import has_len
|
||||
from pytorch_lightning.utilities.warnings import rank_zero_warn
|
||||
|
||||
|
||||
class LightningDataModule(CheckpointHooks, DataHooks, HyperparametersMixin):
|
||||
|
@ -484,40 +481,3 @@ class LightningDataModule(CheckpointHooks, DataHooks, HyperparametersMixin):
|
|||
for fn in ("prepare_data", "setup", "teardown"):
|
||||
del d[fn]
|
||||
return d
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Returns the total number of batches in all dataloaders defined in the datamodule."""
|
||||
|
||||
from pytorch_lightning.trainer.supporters import CombinedLoader
|
||||
|
||||
num_batches = 0
|
||||
not_implemented_count = 0
|
||||
|
||||
def get_num_batches(dataloader: DataLoader, name: str) -> None:
|
||||
nonlocal num_batches
|
||||
if not has_len(dataloader):
|
||||
rank_zero_warn(
|
||||
f"The number of batches for a dataloader in `{name}` is counted as 0 "
|
||||
"because it does not have `__len__` defined."
|
||||
)
|
||||
else:
|
||||
num_batches += len(dataloader)
|
||||
|
||||
for method_name in ("train_dataloader", "val_dataloader", "test_dataloader", "predict_dataloader"):
|
||||
dataloader_method = getattr(self, method_name)
|
||||
if not callable(dataloader_method):
|
||||
not_implemented_count += 1
|
||||
continue
|
||||
try:
|
||||
dataloader = dataloader_method()
|
||||
except NotImplementedError:
|
||||
not_implemented_count += 1
|
||||
continue
|
||||
if isinstance(dataloader, CombinedLoader):
|
||||
dataloader = dataloader.loaders
|
||||
apply_to_collection(dataloader, DataLoader, get_num_batches, method_name)
|
||||
|
||||
if not_implemented_count == 4:
|
||||
rank_zero_warn("You datamodule does not have any valid dataloader so `__len__` will be returned as 0.")
|
||||
|
||||
return num_batches
|
||||
|
|
|
@ -21,15 +21,13 @@ from unittest.mock import call, PropertyMock
|
|||
import pytest
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from pytorch_lightning import LightningDataModule, Trainer
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
from pytorch_lightning.trainer.supporters import CombinedLoader
|
||||
from pytorch_lightning.utilities import AttributeDict
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.model_helpers import is_overridden
|
||||
from tests.helpers import BoringDataModule, BoringModel, RandomDataset
|
||||
from tests.helpers import BoringDataModule, BoringModel
|
||||
from tests.helpers.datamodules import ClassifDataModule
|
||||
from tests.helpers.runif import RunIf
|
||||
from tests.helpers.simple_models import ClassificationModel
|
||||
|
@ -566,14 +564,13 @@ def test_define_as_dataclass():
|
|||
batch_size: int
|
||||
dims: int = 2
|
||||
|
||||
def train_dataloader(self):
|
||||
return DataLoader(torch.randn(self.batch_size * 2, 10), batch_size=self.batch_size)
|
||||
def __post_init__(self):
|
||||
super().__init__(dims=self.dims)
|
||||
|
||||
# asserts for the different dunder methods added by dataclass, when __init__ is implemented, i.e.
|
||||
# __repr__, __eq__, __lt__, __le__, etc.
|
||||
assert BoringDataModule1(batch_size=64).dims == 2
|
||||
assert BoringDataModule1(batch_size=32)
|
||||
assert len(BoringDataModule1(batch_size=32)) == 2
|
||||
assert hasattr(BoringDataModule1, "__repr__")
|
||||
assert BoringDataModule1(batch_size=32) == BoringDataModule1(batch_size=32)
|
||||
|
||||
|
@ -584,9 +581,7 @@ def test_define_as_dataclass():
|
|||
|
||||
# asserts for the different dunder methods added by dataclass, when super class is inherently initialized, i.e.
|
||||
# __init__, __repr__, __eq__, __lt__, __le__, etc.
|
||||
assert BoringDataModule2(batch_size=32) is not None
|
||||
assert BoringDataModule2(batch_size=32).batch_size == 32
|
||||
assert len(BoringDataModule2(batch_size=32)) == 0
|
||||
assert BoringDataModule2(batch_size=32)
|
||||
assert hasattr(BoringDataModule2, "__repr__")
|
||||
assert BoringDataModule2(batch_size=32).prepare_data() is None
|
||||
assert BoringDataModule2(batch_size=32) == BoringDataModule2(batch_size=32)
|
||||
|
@ -630,76 +625,3 @@ def test_inconsistent_prepare_data_per_node(tmpdir):
|
|||
trainer.model = model
|
||||
trainer.datamodule = dm
|
||||
trainer._data_connector.prepare_data()
|
||||
|
||||
|
||||
DATALOADER = DataLoader(RandomDataset(1, 32))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("method_name", ["train_dataloader", "val_dataloader", "test_dataloader", "predict_dataloader"])
|
||||
@pytest.mark.parametrize(
|
||||
["dataloader", "expected"],
|
||||
[
|
||||
[DATALOADER, 32],
|
||||
[[DATALOADER, DATALOADER], 64],
|
||||
[[[DATALOADER], [DATALOADER, DATALOADER]], 96],
|
||||
[[{"foo": DATALOADER}, {"foo": DATALOADER, "bar": DATALOADER}], 96],
|
||||
[{"foo": DATALOADER, "bar": DATALOADER}, 64],
|
||||
[{"foo": {"foo": DATALOADER}, "bar": {"foo": DATALOADER, "bar": DATALOADER}}, 96],
|
||||
[{"foo": [DATALOADER], "bar": [DATALOADER, DATALOADER]}, 96],
|
||||
[CombinedLoader({"foo": DATALOADER, "bar": DATALOADER}), 64],
|
||||
],
|
||||
)
|
||||
def test_len_different_types(method_name, dataloader, expected):
|
||||
dm = LightningDataModule()
|
||||
setattr(dm, method_name, lambda: dataloader)
|
||||
assert len(dm) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize("method_name", ["train_dataloader", "val_dataloader", "test_dataloader", "predict_dataloader"])
|
||||
def test_len_dataloader_no_len(method_name):
|
||||
class CustomNotImplementedErrorDataloader(DataLoader):
|
||||
def __len__(self):
|
||||
raise NotImplementedError
|
||||
|
||||
dataloader = CustomNotImplementedErrorDataloader(RandomDataset(1, 32))
|
||||
dm = LightningDataModule()
|
||||
setattr(dm, method_name, lambda: dataloader)
|
||||
with pytest.warns(UserWarning, match=f"The number of batches for a dataloader in `{method_name}` is counted as 0"):
|
||||
assert len(dm) == 0
|
||||
|
||||
|
||||
def test_len_all_dataloader_methods_implemented():
|
||||
class BoringDataModule(LightningDataModule):
|
||||
def __init__(self, dataloader):
|
||||
super().__init__()
|
||||
self.dataloader = dataloader
|
||||
|
||||
def train_dataloader(self):
|
||||
return {"foo": self.dataloader, "bar": self.dataloader}
|
||||
|
||||
def val_dataloader(self):
|
||||
return self.dataloader
|
||||
|
||||
def test_dataloader(self):
|
||||
return [self.dataloader]
|
||||
|
||||
def predict_dataloader(self):
|
||||
return [self.dataloader, self.dataloader]
|
||||
|
||||
dm = BoringDataModule(DATALOADER)
|
||||
|
||||
# 6 dataloaders each producing 32 batches: 6 * 32 = 192
|
||||
assert len(dm) == 192
|
||||
|
||||
|
||||
def test_len_no_dataloader_methods_implemented():
|
||||
dm = LightningDataModule()
|
||||
with pytest.warns(UserWarning, match="You datamodule does not have any valid dataloader"):
|
||||
assert len(dm) == 0
|
||||
|
||||
dm.train_dataloader = None
|
||||
dm.val_dataloader = None
|
||||
dm.test_dataloader = None
|
||||
dm.predict_dataloader = None
|
||||
with pytest.warns(UserWarning, match="You datamodule does not have any valid dataloader"):
|
||||
assert len(dm) == 0
|
||||
|
|
Loading…
Reference in New Issue