# File: lightning/tests/core/test_datamodules.py
#
# Listing metadata: 699 lines, 24 KiB, Python
#

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle
from argparse import ArgumentParser, Namespace
from dataclasses import dataclass
from typing import Any, Dict
from unittest import mock
from unittest.mock import call, PropertyMock
import pytest
import torch
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from pytorch_lightning import LightningDataModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities import AttributeDict
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from tests.helpers import BoringDataModule, BoringModel, RandomDataset
from tests.helpers.datamodules import ClassifDataModule
from tests.helpers.runif import RunIf
from tests.helpers.simple_models import ClassificationModel
from tests.helpers.utils import reset_seed
@mock.patch("pytorch_lightning.trainer.trainer.Trainer.node_rank", new_callable=PropertyMock)
@mock.patch("pytorch_lightning.trainer.trainer.Trainer.local_rank", new_callable=PropertyMock)
def test_can_prepare_data(local_rank, node_rank):
    """Check on which (node, local) ranks ``trainer.data_connector.prepare_data`` runs the datamodule's hook.

    ``local_rank`` and ``node_rank`` are ``PropertyMock`` patches on ``Trainer``; the test flips their return values
    and observes whether ``BoringDataModule.prepare_data`` ran by looking at ``random_full``.
    """
    datamodule = BoringDataModule()
    trainer = Trainer()
    trainer.datamodule = datamodule

    def forget_prepared_data():
        # Undo the effects of any earlier ``prepare_data`` call so the next one is observable.
        datamodule.random_full = None
        datamodule._has_prepared_data = False

    # Part 1: the datamodule's real ``prepare_data``, with ``prepare_data_per_node`` left as constructed.
    # local rank 0 -> prepares
    forget_prepared_data()
    local_rank.return_value = 0
    assert trainer.local_rank == 0
    trainer.data_connector.prepare_data()
    assert datamodule.random_full is not None

    # local rank 1 -> does not prepare
    forget_prepared_data()
    local_rank.return_value = 1
    assert trainer.local_rank == 1
    trainer.data_connector.prepare_data()
    assert datamodule.random_full is None

    # prepare_data_per_node = False: only node 0 / local 0 (global rank 0) prepares
    forget_prepared_data()
    datamodule.prepare_data_per_node = False
    node_rank.return_value = 0
    local_rank.return_value = 0
    trainer.data_connector.prepare_data()
    assert datamodule.random_full is not None

    # node 1 / local 0 -> does not prepare
    forget_prepared_data()
    node_rank.return_value = 1
    local_rank.return_value = 0
    trainer.data_connector.prepare_data()
    assert datamodule.random_full is None

    # node 0 / local 1 -> does not prepare (state intentionally not reset: nothing was prepared above)
    node_rank.return_value = 0
    local_rank.return_value = 1
    trainer.data_connector.prepare_data()
    assert datamodule.random_full is None

    # Part 2: with ``prepare_data`` mocked out, it must be skipped while ``_has_prepared_data`` is set
    # and called exactly once when it is cleared.
    datamodule.prepare_data_per_node = True
    local_rank.return_value = 0
    with mock.patch.object(trainer.datamodule, "prepare_data") as prepare_mock:
        datamodule._has_prepared_data = True
        trainer.data_connector.prepare_data()
        prepare_mock.assert_not_called()

        datamodule._has_prepared_data = False
        trainer.data_connector.prepare_data()
        prepare_mock.assert_called_once()
def test_hooks_no_recursion_error():
    """Creating many datamodules in a row must not raise ``RecursionError``.

    Hooks used to be appended in cascade every time a new data module was instantiated, which led to a recursion
    error. See https://github.com/PyTorchLightning/pytorch-lightning/issues/3652
    """

    class DummyDM(LightningDataModule):
        def setup(self, *args, **kwargs):
            pass

        def prepare_data(self, *args, **kwargs):
            pass

    # 1005 iterations: presumably chosen to exceed CPython's default recursion limit of 1000.
    for _ in range(1005):
        datamodule = DummyDM()
        datamodule.setup()
        datamodule.prepare_data()
def test_helper_boringdatamodule():
    """Smoke test: ``BoringDataModule`` supports ``prepare_data`` followed by a stage-less ``setup``."""
    datamodule = BoringDataModule()
    datamodule.prepare_data()
    datamodule.setup()
def test_helper_boringdatamodule_with_verbose_setup():
    """Smoke test: ``BoringDataModule.setup`` accepts explicit stage names."""
    datamodule = BoringDataModule()
    datamodule.prepare_data()
    for stage in ("fit", "test"):
        datamodule.setup(stage)
def test_data_hooks_called():
    """The ``has_*`` flags follow ``prepare_data`` and stage-less ``setup`` / ``teardown`` calls."""
    dm = BoringDataModule()
    stages = ("fit", "test", "validate", "predict")

    def check_flags(prepared, set_up, torn_down):
        # ``set_up`` / ``torn_down`` list the stages whose flag must be truthy; every other stage must be falsy.
        assert bool(dm.has_prepared_data) == prepared
        for stage in stages:
            assert bool(getattr(dm, f"has_setup_{stage}")) == (stage in set_up)
        for stage in stages:
            assert bool(getattr(dm, f"has_teardown_{stage}")) == (stage in torn_down)

    check_flags(prepared=False, set_up=(), torn_down=())

    dm.prepare_data()
    check_flags(prepared=True, set_up=(), torn_down=())

    # A stage-less call marks fit, test and validate, but not predict.
    dm.setup()
    check_flags(prepared=True, set_up=("fit", "test", "validate"), torn_down=())

    dm.teardown()
    check_flags(prepared=True, set_up=("fit", "test", "validate"), torn_down=("fit", "test", "validate"))
@pytest.mark.parametrize("use_kwarg", (False, True))
def test_data_hooks_called_verbose(use_kwarg):
    """Each stage-specific ``setup``/``teardown`` call sets exactly its own ``has_*`` flag, cumulatively."""
    dm = BoringDataModule()
    dm.prepare_data()

    for stage in ("fit", "test", "validate", "predict"):
        assert not getattr(dm, f"has_setup_{stage}")
    for stage in ("fit", "test", "validate", "predict"):
        assert not getattr(dm, f"has_teardown_{stage}")

    call_order = ("fit", "validate", "test", "predict")
    for hook_name in ("setup", "teardown"):
        for done, stage in enumerate(call_order):
            # The stage is passed by keyword or positionally, depending on the parametrization.
            if use_kwarg:
                getattr(dm, hook_name)(stage=stage)
            else:
                getattr(dm, hook_name)(stage)
            for position, other in enumerate(call_order):
                # Stages called so far (including this one) are flagged; the remaining ones are not yet.
                assert bool(getattr(dm, f"has_{hook_name}_{other}")) == (position <= done)
def test_dm_add_argparse_args(tmpdir):
    """``add_argparse_args`` registers a ``--data_dir`` option on the given parser."""
    data_dir = str(tmpdir)
    parser = BoringDataModule.add_argparse_args(ArgumentParser())
    parsed = parser.parse_args(["--data_dir", data_dir])
    assert parsed.data_dir == data_dir
def test_dm_init_from_argparse_args(tmpdir):
    """A datamodule built via ``from_argparse_args`` takes ``data_dir`` from the parsed namespace."""
    parser = BoringDataModule.add_argparse_args(ArgumentParser())
    args = parser.parse_args(["--data_dir", str(tmpdir)])

    dm = BoringDataModule.from_argparse_args(args)
    dm.prepare_data()
    dm.setup()

    assert dm.data_dir == args.data_dir
    assert args.data_dir == str(tmpdir)
def test_dm_pickle_after_init():
    """A freshly constructed datamodule must be picklable."""
    pickle.dumps(BoringDataModule())
def test_train_loop_only(tmpdir):
    """Fitting succeeds with every validation and test hook on the model set to ``None``."""
    reset_seed()

    dm = ClassifDataModule()
    model = ClassificationModel()

    # Null out validation_* then test_* hooks (step, step_end, epoch_end), presumably leaving only training active.
    for prefix in ("validation", "test"):
        for suffix in ("step", "step_end", "epoch_end"):
            setattr(model, f"{prefix}_{suffix}", None)

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, enable_model_summary=False)
    trainer.fit(model, datamodule=dm)

    assert trainer.state.finished, f"Training failed with {trainer.state}"
    assert trainer.callback_metrics["train_loss"] < 1.0
def test_train_val_loop_only(tmpdir):
    """Fitting with validation enabled and the test hooks disabled runs both the train and validation loops."""
    reset_seed()

    dm = ClassifDataModule()
    model = ClassificationModel()

    # Disable only the test hooks. The validation hooks must stay in place, otherwise this test would
    # exercise exactly the same configuration as ``test_train_loop_only``.
    model.test_step = None
    model.test_step_end = None
    model.test_epoch_end = None

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, enable_model_summary=False)

    # fit model
    trainer.fit(model, datamodule=dm)
    assert trainer.state.finished, f"Training failed with {trainer.state}"
    assert trainer.callback_metrics["train_loss"] < 1.0
    # ``trainer.validate`` on this model reports ``val_acc``, so seeing it here shows the val loop ran during fit.
    assert "val_acc" in trainer.callback_metrics
def test_dm_checkpoint_save(tmpdir):
    """Entries added by ``LightningDataModule.on_save_checkpoint`` end up in the saved checkpoint file."""

    class CustomBoringModel(BoringModel):
        def validation_step(self, batch, batch_idx):
            out = super().validation_step(batch, batch_idx)
            # Gives ``ModelCheckpoint(monitor="early_stop_on")`` a metric to track.
            self.log("early_stop_on", out["x"])
            return out

    class CustomBoringDataModule(BoringDataModule):
        def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
            # Store the class name under its own key so it can be found after ``torch.load``.
            checkpoint[type(self).__name__] = type(self).__name__

        def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
            self.checkpoint_state = checkpoint.get(type(self).__name__)

    reset_seed()
    dm = CustomBoringDataModule()
    model = CustomBoringModel()
    checkpoint_cb = ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on")
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=1,
        enable_model_summary=False,
        callbacks=[checkpoint_cb],
    )

    trainer.fit(model, dm)
    assert trainer.state.finished, f"Training failed with {trainer.state}"

    best_path = list(trainer.checkpoint_callback.best_k_models)[0]
    checkpoint = torch.load(best_path)
    dm_key = type(dm).__name__
    assert dm_key in checkpoint
    assert checkpoint[dm_key] == dm_key
def test_full_loop(tmpdir):
    """Fit, validate and test with one datamodule and check the reported accuracies."""
    reset_seed()

    dm = ClassifDataModule()
    model = ClassificationModel()
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, enable_model_summary=False, deterministic=True)

    trainer.fit(model, dm)
    assert trainer.state.finished, f"Training failed with {trainer.state}"
    assert dm.trainer is not None

    # (evaluation entry point, metric key it reports, lower bound the metric must exceed)
    evaluations = ((trainer.validate, "val_acc", 0.7), (trainer.test, "test_acc", 0.6))
    for evaluate, metric, threshold in evaluations:
        result = evaluate(model, dm)
        assert dm.trainer is not None
        assert result[0][metric] > threshold
@RunIf(min_gpus=1)
@mock.patch("pytorch_lightning.accelerators.accelerator.Accelerator.lightning_module", new_callable=PropertyMock)
def test_dm_apply_batch_transfer_handler(get_module_mock):
    """Datamodule batch-transfer hooks run in order (before-transfer, transfer, after-transfer) and their
    transformations are applied to the batch moved to the GPU."""
    expected_device = torch.device("cuda", 0)

    class CustomBatch:
        def __init__(self, data):
            self.samples = data[0]
            self.targets = data[1]

    class CurrentTestDM(LightningDataModule):
        # ``rank`` is bumped by every hook so the test can check the order in which the hooks ran.
        rank = 0
        transfer_batch_to_device_hook_rank = None
        on_before_batch_transfer_hook_rank = None
        on_after_batch_transfer_hook_rank = None

        def on_before_batch_transfer(self, batch, dataloader_idx):
            assert dataloader_idx == 0
            self.on_before_batch_transfer_hook_rank = self.rank
            self.rank += 1
            batch.samples += 1
            return batch

        def on_after_batch_transfer(self, batch, dataloader_idx):
            assert dataloader_idx == 0
            assert batch.samples.device == batch.targets.device == expected_device
            self.on_after_batch_transfer_hook_rank = self.rank
            self.rank += 1
            batch.targets *= 2
            return batch

        def transfer_batch_to_device(self, batch, device, dataloader_idx):
            assert dataloader_idx == 0
            self.transfer_batch_to_device_hook_rank = self.rank
            self.rank += 1
            batch.samples = batch.samples.to(device)
            batch.targets = batch.targets.to(device)
            return batch

    dm = CurrentTestDM()
    model = BoringModel()

    batch = CustomBatch((torch.zeros(5, 32), torch.ones(5, 1, dtype=torch.long)))

    trainer = Trainer(gpus=1)
    # running .fit() would require us to implement custom data loaders, we mock the model reference instead
    get_module_mock.return_value = model

    # The datamodule must really override the hook, otherwise this test would silently exercise the default
    # implementation. The hook is then routed to the model unconditionally.
    assert is_overridden("transfer_batch_to_device", dm)
    model.on_before_batch_transfer = dm.on_before_batch_transfer
    model.transfer_batch_to_device = dm.transfer_batch_to_device
    model.on_after_batch_transfer = dm.on_after_batch_transfer

    batch_gpu = trainer.accelerator.batch_to_device(batch, expected_device)

    assert dm.on_before_batch_transfer_hook_rank == 0
    assert dm.transfer_batch_to_device_hook_rank == 1
    assert dm.on_after_batch_transfer_hook_rank == 2
    assert batch_gpu.samples.device == batch_gpu.targets.device == expected_device
    # zeros + 1 from on_before_batch_transfer; ones * 2 from on_after_batch_transfer
    assert torch.allclose(batch_gpu.samples.cpu(), torch.ones(5, 32))
    assert torch.allclose(batch_gpu.targets.cpu(), torch.ones(5, 1, dtype=torch.long) * 2)
def test_dm_reload_dataloaders_every_n_epochs(tmpdir):
    """Test datamodule, where trainer argument reload_dataloaders_every_n_epochs is set to a non negative
    integer."""

    class CustomBoringDataModule(BoringDataModule):
        def __init__(self):
            super().__init__()
            # Epochs at which the trainer requested ``train_dataloader``.
            self._epochs_called_for = []

        def train_dataloader(self):
            # The train dataloader may be requested at most once per epoch.
            assert self.trainer.current_epoch not in self._epochs_called_for
            self._epochs_called_for.append(self.trainer.current_epoch)
            return super().train_dataloader()

    dm = CustomBoringDataModule()
    model = BoringModel()

    model.validation_step = None
    model.validation_step_end = None
    model.validation_epoch_end = None

    model.test_step = None
    model.test_step_end = None
    model.test_epoch_end = None

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=3, limit_train_batches=2, reload_dataloaders_every_n_epochs=2)
    trainer.fit(model, dm)

    # Epochs 0, 1, 2 with a reload every 2 epochs: built at epoch 0, rebuilt at epoch 2.
    # Without this check the test would also pass if no reload ever happened.
    assert dm._epochs_called_for == [0, 2]
class DummyDS(torch.utils.data.Dataset):
    """Map-style dataset of length 100 whose every item is ``1``."""

    def __len__(self):
        return 100

    def __getitem__(self, index):
        # The index is ignored: every item is identical.
        return 1
class DummyIDS(torch.utils.data.IterableDataset):
    """Iterable-style dataset that yields the single value ``1``."""

    def __iter__(self):
        yield from (1,)
@pytest.mark.parametrize("iterable", (False, True))
def test_dm_init_from_datasets_dataloaders(iterable):
    """``LightningDataModule.from_datasets`` creates one ``DataLoader`` per dataset with the expected arguments.

    Train loaders shuffle unless the dataset is iterable; val/test loaders never shuffle. Loader methods for stages
    that received no dataset raise ``NotImplementedError``.
    """
    ds = DummyIDS if iterable else DummyDS
    dataloader_path = "pytorch_lightning.core.datamodule.DataLoader"

    def loader_kwargs(batch_size, shuffle):
        # Keyword arguments ``from_datasets`` is expected to forward to ``DataLoader``.
        return dict(batch_size=batch_size, shuffle=shuffle, num_workers=0, pin_memory=True)

    # a single train dataset
    train_ds = ds()
    dm = LightningDataModule.from_datasets(train_ds, batch_size=4, num_workers=0)
    with mock.patch(dataloader_path) as dl_mock:
        dm.train_dataloader()
        dl_mock.assert_called_once_with(train_ds, **loader_kwargs(4, not iterable))
    for name in ("val_dataloader", "test_dataloader"):
        with pytest.raises(NotImplementedError):
            getattr(dm, name)()

    # a sequence of train datasets
    train_ds_sequence = [ds(), ds()]
    dm = LightningDataModule.from_datasets(train_ds_sequence, batch_size=4, num_workers=0)
    with mock.patch(dataloader_path) as dl_mock:
        dm.train_dataloader()
        dl_mock.assert_has_calls([call(d, **loader_kwargs(4, not iterable)) for d in train_ds_sequence])
    for name in ("val_dataloader", "test_dataloader"):
        with pytest.raises(NotImplementedError):
            getattr(dm, name)()

    # val and test datasets only
    valid_ds = ds()
    test_ds = ds()
    dm = LightningDataModule.from_datasets(val_dataset=valid_ds, test_dataset=test_ds, batch_size=2, num_workers=0)
    with mock.patch(dataloader_path) as dl_mock:
        dm.val_dataloader()
        dl_mock.assert_called_with(valid_ds, **loader_kwargs(2, False))
        dm.test_dataloader()
        dl_mock.assert_called_with(test_ds, **loader_kwargs(2, False))
    with pytest.raises(NotImplementedError):
        dm.train_dataloader()

    # sequences of val and test datasets, passed positionally
    valid_dss = [ds(), ds()]
    test_dss = [ds(), ds()]
    dm = LightningDataModule.from_datasets(train_ds, valid_dss, test_dss, batch_size=4, num_workers=0)
    with mock.patch(dataloader_path) as dl_mock:
        dm.val_dataloader()
        dm.test_dataloader()
        dl_mock.assert_has_calls([call(d, **loader_kwargs(4, False)) for d in (*valid_dss, *test_dss)])
# all args: ``save_hyperparameters()`` is called without arguments
class DataModuleWithHparams_0(LightningDataModule):
    """Test datamodule whose argument-less ``save_hyperparameters()`` call records its ``__init__`` arguments."""

    def __init__(self, arg0, arg1, kwarg0=None):
        super().__init__()
        # No arguments: the init arguments are collected implicitly. ``test_hyperparameters_saving`` expects
        # ``arg0``, ``arg1`` and ``kwarg0`` to all appear in ``hparams``.
        self.save_hyperparameters()
# single arg: only the first positional argument is passed to ``save_hyperparameters``
class DataModuleWithHparams_1(LightningDataModule):
    """Test datamodule that saves a single container argument as its hyperparameters."""

    def __init__(self, arg0, *args, **kwargs):
        super().__init__()
        # ``arg0`` is the container (Namespace, dict or OmegaConf in ``test_hyperparameters_saving``) that becomes
        # ``hparams``; ``*args`` / ``**kwargs`` are accepted but not used here.
        self.save_hyperparameters(arg0)
def test_hyperparameters_saving():
    """``save_hyperparameters`` stores either all init arguments or the single container it is given."""
    all_args_dm = DataModuleWithHparams_0(10, "foo", kwarg0="bar")
    assert all_args_dm.hparams == AttributeDict({"arg0": 10, "arg1": "foo", "kwarg0": "bar"})

    # Plain containers are converted to an ``AttributeDict``.
    expected = AttributeDict({"hello": "world"})
    for container in (Namespace(hello="world"), {"hello": "world"}):
        single_arg_dm = DataModuleWithHparams_1(container, "foo", kwarg0="bar")
        assert single_arg_dm.hparams == expected

    omegaconf_dm = DataModuleWithHparams_1(OmegaConf.create({"hello": "world"}), "foo", kwarg0="bar")
    assert omegaconf_dm.hparams == OmegaConf.create({"hello": "world"})
def test_define_as_dataclass():
    """Datamodules declared with ``@dataclass`` keep working and still receive ``LightningDataModule``'s internal
    state (checked via ``_has_prepared_data``), even when ``super().__init__`` is never called explicitly."""
    # makes sure that no functionality is broken when a datamodule is a dataclass
    # also tests all the dataclass features that can be enabled without breaking anything
    @dataclass(init=True, repr=True, eq=True, order=True, unsafe_hash=True, frozen=False)
    class BoringDataModule1(LightningDataModule):
        batch_size: int
        dims: int = 2

        def train_dataloader(self):
            # ``batch_size * 2`` samples at ``batch_size`` per batch -> 2 batches
            return DataLoader(torch.randn(self.batch_size * 2, 10), batch_size=self.batch_size)

    # asserts for the different dunder methods added by dataclass, when __init__ is implemented, i.e.
    # __repr__, __eq__, __lt__, __le__, etc.
    assert BoringDataModule1(batch_size=64).dims == 2
    # truthiness falls back to ``__len__`` here (no ``__bool__`` is defined), i.e. the 2 batches below
    assert BoringDataModule1(batch_size=32)
    assert len(BoringDataModule1(batch_size=32)) == 2
    assert hasattr(BoringDataModule1, "__repr__")
    assert BoringDataModule1(batch_size=32) == BoringDataModule1(batch_size=32)

    # asserts inherent calling of super().__init__ in case user doesn't make the call
    @dataclass
    class BoringDataModule2(LightningDataModule):
        batch_size: int

    # asserts for the different dunder methods added by dataclass, when super class is inherently initialized, i.e.
    # __init__, __repr__, __eq__, __lt__, __le__, etc.
    # ``is not None`` rather than truthiness: this datamodule has no dataloaders, so its length is 0
    assert BoringDataModule2(batch_size=32) is not None
    assert BoringDataModule2(batch_size=32).batch_size == 32
    assert len(BoringDataModule2(batch_size=32)) == 0
    assert hasattr(BoringDataModule2, "__repr__")
    assert BoringDataModule2(batch_size=32).prepare_data() is None
    assert BoringDataModule2(batch_size=32) == BoringDataModule2(batch_size=32)

    # checking for all the different multilevel inheritance scenarios, for init call on LightningDataModule
    @dataclass
    class BoringModuleBase1(LightningDataModule):
        num_features: int

    # a hand-written __init__ that deliberately never calls super().__init__()
    class BoringModuleBase2(LightningDataModule):
        def __init__(self, num_features: int):
            self.num_features = num_features

    @dataclass
    class BoringModuleDerived1(BoringModuleBase1):
        ...

    class BoringModuleDerived2(BoringModuleBase1):
        def __init__(self):
            ...

    @dataclass
    class BoringModuleDerived3(BoringModuleBase2):
        ...

    class BoringModuleDerived4(BoringModuleBase2):
        def __init__(self):
            ...

    # every variant must still carry the base class's internal state
    assert hasattr(BoringModuleDerived1(num_features=2), "_has_prepared_data")
    assert hasattr(BoringModuleDerived2(), "_has_prepared_data")
    assert hasattr(BoringModuleDerived3(), "_has_prepared_data")
    assert hasattr(BoringModuleDerived4(), "_has_prepared_data")
def test_inconsistent_prepare_data_per_node(tmpdir):
    """``prepare_data`` raises when the Trainer's ``prepare_data_per_node=False`` disagrees with the datamodule."""
    model = BoringModel()
    dm = BoringDataModule()
    trainer = Trainer(prepare_data_per_node=False)
    trainer.model = model
    trainer.datamodule = dm

    # Only the call under test sits inside ``pytest.raises``, so an error raised while building the objects above
    # cannot be mistaken for the expected one.
    with pytest.raises(MisconfigurationException, match="Inconsistent settings found for `prepare_data_per_node`."):
        trainer.data_connector.prepare_data()
# Shared loader for the ``len(datamodule)`` tests; with DataLoader's default batch size of 1 it yields 32 batches.
DATALOADER = DataLoader(dataset=RandomDataset(1, 32))
@pytest.mark.parametrize("method_name", ["train_dataloader", "val_dataloader", "test_dataloader", "predict_dataloader"])
@pytest.mark.parametrize(
    ["dataloader", "expected"],
    [
        [DATALOADER, 32],
        [[DATALOADER, DATALOADER], 64],
        [[[DATALOADER], [DATALOADER, DATALOADER]], 96],
        [[{"foo": DATALOADER}, {"foo": DATALOADER, "bar": DATALOADER}], 96],
        [{"foo": DATALOADER, "bar": DATALOADER}, 64],
        [{"foo": {"foo": DATALOADER}, "bar": {"foo": DATALOADER, "bar": DATALOADER}}, 96],
        [{"foo": [DATALOADER], "bar": [DATALOADER, DATALOADER]}, 96],
        [CombinedLoader({"foo": DATALOADER, "bar": DATALOADER}), 64],
    ],
)
def test_len_different_types(method_name, dataloader, expected):
    """``len(datamodule)`` sums the batches of every loader a hook returns, however the loaders are nested."""
    dm = LightningDataModule()

    def loader_hook():
        return dataloader

    setattr(dm, method_name, loader_hook)
    assert len(dm) == expected
@pytest.mark.parametrize("method_name", ["train_dataloader", "val_dataloader", "test_dataloader", "predict_dataloader"])
def test_len_dataloader_no_len(method_name):
    """A loader whose ``__len__`` raises ``NotImplementedError`` counts as 0 batches and emits a warning."""

    class CustomNotImplementedErrorDataloader(DataLoader):
        def __len__(self):
            raise NotImplementedError

    dm = LightningDataModule()
    unsized_loader = CustomNotImplementedErrorDataloader(RandomDataset(1, 32))

    def loader_hook():
        return unsized_loader

    setattr(dm, method_name, loader_hook)

    expected_warning = f"The number of batches for a dataloader in `{method_name}` is counted as 0"
    with pytest.warns(UserWarning, match=expected_warning):
        assert len(dm) == 0
def test_len_all_dataloader_methods_implemented():
    """``len(datamodule)`` adds up the batches returned by all four dataloader hooks."""

    # Named to avoid shadowing the imported ``BoringDataModule`` helper.
    class AllHooksDataModule(LightningDataModule):
        def __init__(self, dataloader):
            super().__init__()
            self.dataloader = dataloader

        def train_dataloader(self):
            return {"foo": self.dataloader, "bar": self.dataloader}

        def val_dataloader(self):
            return self.dataloader

        def test_dataloader(self):
            return [self.dataloader]

        def predict_dataloader(self):
            return [self.dataloader, self.dataloader]

    # train (2) + val (1) + test (1) + predict (2) = 6 loaders, each producing 32 batches
    assert len(AllHooksDataModule(DATALOADER)) == 6 * 32
def test_len_no_dataloader_methods_implemented():
    """A datamodule without any dataloader hooks has length 0 and warns about it."""
    datamodule = LightningDataModule()
    with pytest.warns(UserWarning, match="You datamodule does not have any valid dataloader"):
        n_batches = len(datamodule)
    assert n_batches == 0