This reverts commit 6429de8944.

This commit is contained in:
parent 6a9adf26f7
commit dbfadedfe7
CHANGELOG.md

@@ -187,9 +187,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added support for `torch.autograd.set_detect_anomaly` through `Trainer` constructor argument `detect_anomaly` ([#9848](https://github.com/PyTorchLightning/pytorch-lightning/pull/9848))
 
-- Added a `len` method to `LightningDataModule` ([#9895](https://github.com/PyTorchLightning/pytorch-lightning/pull/9895))
-
-
 - Added `enable_model_summary` flag to Trainer ([#9699](https://github.com/PyTorchLightning/pytorch-lightning/pull/9699))
 
 
pytorch_lightning/core/datamodule.py

@@ -22,10 +22,7 @@ from torch.utils.data import DataLoader, Dataset, IterableDataset
 from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks
 from pytorch_lightning.core.mixins import HyperparametersMixin
 from pytorch_lightning.utilities import rank_zero_deprecation
-from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types
-from pytorch_lightning.utilities.data import has_len
-from pytorch_lightning.utilities.warnings import rank_zero_warn
 
 
 class LightningDataModule(CheckpointHooks, DataHooks, HyperparametersMixin):
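The three imports removed here existed only for the deleted `__len__` method: `apply_to_collection` walks nested dataloader collections, `has_len` guards `len()` calls, and `rank_zero_warn` emits the fallback warnings. A rough sketch of the kind of check `has_len` performs (illustrative only; the real `pytorch_lightning.utilities.data.has_len` is more involved):

    from torch.utils.data import DataLoader

    def has_len_like(dataloader: DataLoader) -> bool:
        # A dataloader "has len" if calling len() is supported and does not
        # raise NotImplementedError (see CustomNotImplementedErrorDataloader
        # in the removed tests below).
        try:
            len(dataloader)
            return True
        except (TypeError, NotImplementedError):
            return False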
@@ -484,40 +481,3 @@ class LightningDataModule(CheckpointHooks, DataHooks, HyperparametersMixin):
         for fn in ("prepare_data", "setup", "teardown"):
             del d[fn]
         return d
-
-    def __len__(self) -> int:
-        """Returns the total number of batches in all dataloaders defined in the datamodule."""
-
-        from pytorch_lightning.trainer.supporters import CombinedLoader
-
-        num_batches = 0
-        not_implemented_count = 0
-
-        def get_num_batches(dataloader: DataLoader, name: str) -> None:
-            nonlocal num_batches
-            if not has_len(dataloader):
-                rank_zero_warn(
-                    f"The number of batches for a dataloader in `{name}` is counted as 0 "
-                    "because it does not have `__len__` defined."
-                )
-            else:
-                num_batches += len(dataloader)
-
-        for method_name in ("train_dataloader", "val_dataloader", "test_dataloader", "predict_dataloader"):
-            dataloader_method = getattr(self, method_name)
-            if not callable(dataloader_method):
-                not_implemented_count += 1
-                continue
-            try:
-                dataloader = dataloader_method()
-            except NotImplementedError:
-                not_implemented_count += 1
-                continue
-            if isinstance(dataloader, CombinedLoader):
-                dataloader = dataloader.loaders
-            apply_to_collection(dataloader, DataLoader, get_num_batches, method_name)
-
-        if not_implemented_count == 4:
-            rank_zero_warn("Your datamodule does not have any valid dataloader so `__len__` will be returned as 0.")
-
-        return num_batches
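The deleted `__len__` above leans on `apply_to_collection` to visit every `DataLoader` inside the arbitrarily nested lists and dicts that the four dataloader hooks may return. A rough standalone sketch of that traversal (hypothetical `count_batches` helper, not Lightning API):

    from typing import Any

    from torch.utils.data import DataLoader

    def count_batches(loaders: Any) -> int:
        # Recursively sum len() over every DataLoader in a nested structure,
        # mirroring what apply_to_collection(dataloader, DataLoader, ...)
        # achieved in the removed method.
        if isinstance(loaders, DataLoader):
            return len(loaders)
        if isinstance(loaders, (list, tuple)):
            return sum(count_batches(item) for item in loaders)
        if isinstance(loaders, dict):
            return sum(count_batches(item) for item in loaders.values())
        return 0

For example, `count_batches({"foo": [dl], "bar": [dl, dl]})` is `3 * len(dl)`, consistent with the `96` expected for the analogous cases in the removed tests further down.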
tests/core/test_datamodule.py

@@ -21,15 +21,13 @@ from unittest.mock import call, PropertyMock
 import pytest
 import torch
 from omegaconf import OmegaConf
-from torch.utils.data import DataLoader
 
 from pytorch_lightning import LightningDataModule, Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
-from pytorch_lightning.trainer.supporters import CombinedLoader
 from pytorch_lightning.utilities import AttributeDict
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden
-from tests.helpers import BoringDataModule, BoringModel, RandomDataset
+from tests.helpers import BoringDataModule, BoringModel
 from tests.helpers.datamodules import ClassifDataModule
 from tests.helpers.runif import RunIf
 from tests.helpers.simple_models import ClassificationModel
@@ -566,14 +564,13 @@ def test_define_as_dataclass():
         batch_size: int
         dims: int = 2
 
-        def train_dataloader(self):
-            return DataLoader(torch.randn(self.batch_size * 2, 10), batch_size=self.batch_size)
+        def __post_init__(self):
+            super().__init__(dims=self.dims)
 
     # asserts for the different dunder methods added by dataclass, when __init__ is implemented, i.e.
     # __repr__, __eq__, __lt__, __le__, etc.
     assert BoringDataModule1(batch_size=64).dims == 2
     assert BoringDataModule1(batch_size=32)
-    assert len(BoringDataModule1(batch_size=32)) == 2
     assert hasattr(BoringDataModule1, "__repr__")
     assert BoringDataModule1(batch_size=32) == BoringDataModule1(batch_size=32)
 
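The restored `__post_init__` matters because a `@dataclass`-generated `__init__` does not invoke the base class initializer; `__post_init__` is the hook where that call can be made explicitly. A minimal self-contained sketch of the pattern (generic `Base`/`Child` names, not the test's classes):

    from dataclasses import dataclass

    class Base:
        def __init__(self, dims: int):
            self.dims = dims

    @dataclass
    class Child(Base):
        batch_size: int
        dims: int = 2

        def __post_init__(self):
            # The generated __init__ only assigns the dataclass fields;
            # the base initializer must be called explicitly here.
            super().__init__(dims=self.dims)

    assert Child(batch_size=64).dims == 2  # mirrors the assert kept above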
@@ -584,9 +581,7 @@ def test_define_as_dataclass():
 
     # asserts for the different dunder methods added by dataclass, when super class is inherently initialized, i.e.
     # __init__, __repr__, __eq__, __lt__, __le__, etc.
-    assert BoringDataModule2(batch_size=32) is not None
-    assert BoringDataModule2(batch_size=32).batch_size == 32
-    assert len(BoringDataModule2(batch_size=32)) == 0
+    assert BoringDataModule2(batch_size=32)
     assert hasattr(BoringDataModule2, "__repr__")
     assert BoringDataModule2(batch_size=32).prepare_data() is None
     assert BoringDataModule2(batch_size=32) == BoringDataModule2(batch_size=32)
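Reverting from `assert BoringDataModule2(batch_size=32) is not None` back to the bare truthiness assert is only safe once `__len__` is gone: an object whose `__len__` returns 0 is falsy, so the bare assert would have failed on a datamodule with no dataloaders. A minimal illustration:

    class WithLen:
        def __len__(self) -> int:
            return 0

    class WithoutLen:
        pass

    assert not WithLen()   # __len__ returning 0 makes the instance falsy
    assert WithoutLen()    # without __len__, instances are truthy by default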
@@ -630,76 +625,3 @@ def test_inconsistent_prepare_data_per_node(tmpdir):
     trainer.model = model
     trainer.datamodule = dm
     trainer._data_connector.prepare_data()
-
-
-DATALOADER = DataLoader(RandomDataset(1, 32))
-
-
-@pytest.mark.parametrize("method_name", ["train_dataloader", "val_dataloader", "test_dataloader", "predict_dataloader"])
-@pytest.mark.parametrize(
-    ["dataloader", "expected"],
-    [
-        [DATALOADER, 32],
-        [[DATALOADER, DATALOADER], 64],
-        [[[DATALOADER], [DATALOADER, DATALOADER]], 96],
-        [[{"foo": DATALOADER}, {"foo": DATALOADER, "bar": DATALOADER}], 96],
-        [{"foo": DATALOADER, "bar": DATALOADER}, 64],
-        [{"foo": {"foo": DATALOADER}, "bar": {"foo": DATALOADER, "bar": DATALOADER}}, 96],
-        [{"foo": [DATALOADER], "bar": [DATALOADER, DATALOADER]}, 96],
-        [CombinedLoader({"foo": DATALOADER, "bar": DATALOADER}), 64],
-    ],
-)
-def test_len_different_types(method_name, dataloader, expected):
-    dm = LightningDataModule()
-    setattr(dm, method_name, lambda: dataloader)
-    assert len(dm) == expected
-
-
-@pytest.mark.parametrize("method_name", ["train_dataloader", "val_dataloader", "test_dataloader", "predict_dataloader"])
-def test_len_dataloader_no_len(method_name):
-    class CustomNotImplementedErrorDataloader(DataLoader):
-        def __len__(self):
-            raise NotImplementedError
-
-    dataloader = CustomNotImplementedErrorDataloader(RandomDataset(1, 32))
-    dm = LightningDataModule()
-    setattr(dm, method_name, lambda: dataloader)
-    with pytest.warns(UserWarning, match=f"The number of batches for a dataloader in `{method_name}` is counted as 0"):
-        assert len(dm) == 0
-
-
-def test_len_all_dataloader_methods_implemented():
-    class BoringDataModule(LightningDataModule):
-        def __init__(self, dataloader):
-            super().__init__()
-            self.dataloader = dataloader
-
-        def train_dataloader(self):
-            return {"foo": self.dataloader, "bar": self.dataloader}
-
-        def val_dataloader(self):
-            return self.dataloader
-
-        def test_dataloader(self):
-            return [self.dataloader]
-
-        def predict_dataloader(self):
-            return [self.dataloader, self.dataloader]
-
-    dm = BoringDataModule(DATALOADER)
-
-    # 6 dataloaders each producing 32 batches: 6 * 32 = 192
-    assert len(dm) == 192
-
-
-def test_len_no_dataloader_methods_implemented():
-    dm = LightningDataModule()
-    with pytest.warns(UserWarning, match="Your datamodule does not have any valid dataloader"):
-        assert len(dm) == 0
-
-    dm.train_dataloader = None
-    dm.val_dataloader = None
-    dm.test_dataloader = None
-    dm.predict_dataloader = None
-    with pytest.warns(UserWarning, match="Your datamodule does not have any valid dataloader"):
-        assert len(dm) == 0
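With this revert applied, `LightningDataModule` no longer defines `__len__`, so callers that relied on `len(dm)` must count batches themselves (e.g. with a traversal like the `count_batches` sketch above) or handle the `TypeError`. A quick check of the post-revert behavior:

    import pytest

    from pytorch_lightning import LightningDataModule

    dm = LightningDataModule()
    with pytest.raises(TypeError):
        len(dm)  # __len__ was removed by this revert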