# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle
from argparse import ArgumentParser, Namespace
from dataclasses import dataclass
from typing import Any, Dict
from unittest import mock
from unittest.mock import call, Mock, PropertyMock

import pytest
import torch

from pytorch_lightning import LightningDataModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, AttributeDict
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from tests.helpers import BoringDataModule, BoringModel
from tests.helpers.datamodules import ClassifDataModule
from tests.helpers.runif import RunIf
from tests.helpers.simple_models import ClassificationModel
from tests.helpers.utils import reset_seed

if _OMEGACONF_AVAILABLE:
    from omegaconf import OmegaConf


@mock.patch("pytorch_lightning.trainer.trainer.Trainer.node_rank", new_callable=PropertyMock)
@mock.patch("pytorch_lightning.trainer.trainer.Trainer.local_rank", new_callable=PropertyMock)
def test_can_prepare_data(local_rank, node_rank):
    dm = Mock(spec=LightningDataModule)
    dm.prepare_data_per_node = True
    trainer = Trainer()
    trainer.datamodule = dm

    # 1 no DM
    # prepare_data_per_node = True
    # local rank = 0 (True)
    dm.prepare_data.assert_not_called()
    local_rank.return_value = 0
    assert trainer.local_rank == 0

    trainer._data_connector.prepare_data()
    dm.prepare_data.assert_called_once()

    # local rank = 1 (False)
    dm.reset_mock()
    local_rank.return_value = 1
    assert trainer.local_rank == 1

    trainer._data_connector.prepare_data()
    dm.prepare_data.assert_not_called()

    # prepare_data_per_node = False (prepare across all nodes)
    # global rank = 0 (True)
    dm.reset_mock()
    dm.prepare_data_per_node = False
    node_rank.return_value = 0
    local_rank.return_value = 0

    trainer._data_connector.prepare_data()
    dm.prepare_data.assert_called_once()

    # global rank = 1 (False)
    dm.reset_mock()
    node_rank.return_value = 1
    local_rank.return_value = 0

    trainer._data_connector.prepare_data()
    dm.prepare_data.assert_not_called()

    node_rank.return_value = 0
    local_rank.return_value = 1

    trainer._data_connector.prepare_data()
    dm.prepare_data.assert_not_called()

    # 2 dm
    # prepare_data_per_node = True
    # local rank = 0 (True)
    dm.prepare_data_per_node = True
    local_rank.return_value = 0

    # is_overridden prepare data = True
    trainer._data_connector.prepare_data()
    dm.prepare_data.assert_called_once()


def test_hooks_no_recursion_error():
    # hooks were appended in cascade every time a new data module was instantiated, leading to a recursion error.
    # See https://github.com/PyTorchLightning/pytorch-lightning/issues/3652
    class DummyDM(LightningDataModule):
        def setup(self, *args, **kwargs):
            pass

        def prepare_data(self, *args, **kwargs):
            pass

    for i in range(1005):
        dm = DummyDM()
        dm.setup()
        dm.prepare_data()


def test_helper_boringdatamodule():
    dm = BoringDataModule()
    dm.prepare_data()
    dm.setup()


def test_helper_boringdatamodule_with_verbose_setup():
    dm = BoringDataModule()
    dm.prepare_data()
    dm.setup("fit")
    dm.setup("test")


def test_dm_add_argparse_args(tmpdir):
    parser = ArgumentParser()
    parser = BoringDataModule.add_argparse_args(parser)
    args = parser.parse_args(["--data_dir", str(tmpdir)])
    assert args.data_dir == str(tmpdir)


def test_dm_init_from_argparse_args(tmpdir):
    parser = ArgumentParser()
    parser = BoringDataModule.add_argparse_args(parser)
    args = parser.parse_args(["--data_dir", str(tmpdir)])
    dm = BoringDataModule.from_argparse_args(args)
    dm.prepare_data()
    dm.setup()
    assert dm.data_dir == args.data_dir == str(tmpdir)


def test_dm_pickle_after_init():
    dm = BoringDataModule()
    pickle.dumps(dm)


def test_train_loop_only(tmpdir):
    reset_seed()

    dm = ClassifDataModule()
    model = ClassificationModel()

    model.validation_step = None
    model.validation_step_end = None
    model.validation_epoch_end = None
    model.test_step = None
    model.test_step_end = None
    model.test_epoch_end = None

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, enable_model_summary=False)

    # fit model
    trainer.fit(model, datamodule=dm)
    assert trainer.state.finished, f"Training failed with {trainer.state}"
    assert trainer.callback_metrics["train_loss"] < 1.0


def test_train_val_loop_only(tmpdir):
    reset_seed()

    dm = ClassifDataModule()
    model = ClassificationModel()

    model.validation_step = None
    model.validation_step_end = None
    model.validation_epoch_end = None

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, enable_model_summary=False)

    # fit model
    trainer.fit(model, datamodule=dm)
    assert trainer.state.finished, f"Training failed with {trainer.state}"
    assert trainer.callback_metrics["train_loss"] < 1.0


def test_dm_checkpoint_save_and_load(tmpdir):
    class CustomBoringModel(BoringModel):
        def validation_step(self, batch, batch_idx):
            out = super().validation_step(batch, batch_idx)
            self.log("early_stop_on", out["x"])
            return out

    class CustomBoringDataModule(BoringDataModule):
        def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
            checkpoint[self.__class__.__name__] = self.__class__.__name__

        def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
            self.checkpoint_state = checkpoint.get(self.__class__.__name__)

    reset_seed()
    dm = CustomBoringDataModule()
    model = CustomBoringModel()

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=1,
        enable_model_summary=False,
        callbacks=[ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on")],
    )

    # fit model
    trainer.fit(model, datamodule=dm)
    assert trainer.state.finished, f"Training failed with {trainer.state}"
    checkpoint_path = list(trainer.checkpoint_callback.best_k_models.keys())[0]
    checkpoint = torch.load(checkpoint_path)
    assert dm.__class__.__name__ in checkpoint
    assert checkpoint[dm.__class__.__name__] == dm.__class__.__name__

    for trainer_fn in TrainerFn:
        trainer.state.fn = trainer_fn
        with mock.patch.object(dm, "on_load_checkpoint") as dm_mock:
            trainer._restore_modules_and_callbacks(checkpoint_path)
            dm_mock.assert_called_once()


def test_full_loop(tmpdir):
    reset_seed()

    dm = ClassifDataModule()
    model = ClassificationModel()

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, enable_model_summary=False, deterministic=True)

    # fit model
    trainer.fit(model, dm)
    assert trainer.state.finished, f"Training failed with {trainer.state}"
    assert dm.trainer is not None

    # validate
    result = trainer.validate(model, dm)
    assert dm.trainer is not None
    assert result[0]["val_acc"] > 0.7

    # test
    result = trainer.test(model, dm)
    assert dm.trainer is not None
    assert result[0]["test_acc"] > 0.6


@RunIf(min_gpus=1)
@mock.patch("pytorch_lightning.strategies.Strategy.lightning_module", new_callable=PropertyMock)
def test_dm_apply_batch_transfer_handler(get_module_mock):
    expected_device = torch.device("cuda", 0)

    class CustomBatch:
        def __init__(self, data):
            self.samples = data[0]
            self.targets = data[1]

    class CurrentTestDM(LightningDataModule):
        rank = 0
        transfer_batch_to_device_hook_rank = None
        on_before_batch_transfer_hook_rank = None
        on_after_batch_transfer_hook_rank = None

        def on_before_batch_transfer(self, batch, dataloader_idx):
            assert dataloader_idx == 0
            self.on_before_batch_transfer_hook_rank = self.rank
            self.rank += 1
            batch.samples += 1
            return batch

        def on_after_batch_transfer(self, batch, dataloader_idx):
            assert dataloader_idx == 0
            assert batch.samples.device == batch.targets.device == expected_device
            self.on_after_batch_transfer_hook_rank = self.rank
            self.rank += 1
            batch.targets *= 2
            return batch

        def transfer_batch_to_device(self, batch, device, dataloader_idx):
            assert dataloader_idx == 0
            self.transfer_batch_to_device_hook_rank = self.rank
            self.rank += 1
            batch.samples = batch.samples.to(device)
            batch.targets = batch.targets.to(device)
            return batch

    dm = CurrentTestDM()
    model = BoringModel()

    batch = CustomBatch((torch.zeros(5, 32), torch.ones(5, 1, dtype=torch.long)))

    trainer = Trainer(gpus=1)
    # running .fit() would require us to implement custom data loaders, we mock the model reference instead
    get_module_mock.return_value = model
    if is_overridden("transfer_batch_to_device", dm):
        model.transfer_batch_to_device = dm.transfer_batch_to_device

    model.on_before_batch_transfer = dm.on_before_batch_transfer
    model.transfer_batch_to_device = dm.transfer_batch_to_device
    model.on_after_batch_transfer = dm.on_after_batch_transfer

    batch_gpu = trainer.strategy.batch_to_device(batch, expected_device)

    assert dm.on_before_batch_transfer_hook_rank == 0
    assert dm.transfer_batch_to_device_hook_rank == 1
    assert dm.on_after_batch_transfer_hook_rank == 2
    assert batch_gpu.samples.device == batch_gpu.targets.device == expected_device
    assert torch.allclose(batch_gpu.samples.cpu(), torch.ones(5, 32))
    assert torch.allclose(batch_gpu.targets.cpu(), torch.ones(5, 1, dtype=torch.long) * 2)


def test_dm_reload_dataloaders_every_n_epochs(tmpdir):
    """Test datamodule, where trainer argument reload_dataloaders_every_n_epochs is set to a non-negative integer."""

    class CustomBoringDataModule(BoringDataModule):
        def __init__(self):
            super().__init__()
            self._epochs_called_for = []

        def train_dataloader(self):
            assert self.trainer.current_epoch not in self._epochs_called_for
            self._epochs_called_for.append(self.trainer.current_epoch)
            return super().train_dataloader()

    dm = CustomBoringDataModule()
    model = BoringModel()

    model.validation_step = None
    model.validation_step_end = None
    model.validation_epoch_end = None
    model.test_step = None
    model.test_step_end = None
    model.test_epoch_end = None

    trainer = Trainer(
        default_root_dir=tmpdir, max_epochs=3, limit_train_batches=2, reload_dataloaders_every_n_epochs=2
    )
    trainer.fit(model, dm)


class DummyDS(torch.utils.data.Dataset):
    def __getitem__(self, index):
        return 1

    def __len__(self):
        return 100


class DummyIDS(torch.utils.data.IterableDataset):
    def __iter__(self):
        yield 1


@pytest.mark.parametrize("iterable", (False, True))
def test_dm_init_from_datasets_dataloaders(iterable):
    ds = DummyIDS if iterable else DummyDS

    train_ds = ds()
    dm = LightningDataModule.from_datasets(train_ds, batch_size=4, num_workers=0)
    with mock.patch("pytorch_lightning.core.datamodule.DataLoader") as dl_mock:
        dm.train_dataloader()
        dl_mock.assert_called_once_with(train_ds, batch_size=4, shuffle=not iterable, num_workers=0, pin_memory=True)
    with pytest.raises(NotImplementedError):
        _ = dm.val_dataloader()
    with pytest.raises(NotImplementedError):
        _ = dm.test_dataloader()

    train_ds_sequence = [ds(), ds()]
    dm = LightningDataModule.from_datasets(train_ds_sequence, batch_size=4, num_workers=0)
    with mock.patch("pytorch_lightning.core.datamodule.DataLoader") as dl_mock:
        dm.train_dataloader()
        dl_mock.assert_has_calls(
            [
                call(train_ds_sequence[0], batch_size=4, shuffle=not iterable, num_workers=0, pin_memory=True),
                call(train_ds_sequence[1], batch_size=4, shuffle=not iterable, num_workers=0, pin_memory=True),
            ]
        )
    with pytest.raises(NotImplementedError):
        _ = dm.val_dataloader()
    with pytest.raises(NotImplementedError):
        _ = dm.test_dataloader()

    valid_ds = ds()
    test_ds = ds()
    dm = LightningDataModule.from_datasets(val_dataset=valid_ds, test_dataset=test_ds, batch_size=2, num_workers=0)
    with mock.patch("pytorch_lightning.core.datamodule.DataLoader") as dl_mock:
        dm.val_dataloader()
        dl_mock.assert_called_with(valid_ds, batch_size=2, shuffle=False, num_workers=0, pin_memory=True)
        dm.test_dataloader()
        dl_mock.assert_called_with(test_ds, batch_size=2, shuffle=False, num_workers=0, pin_memory=True)
    with pytest.raises(NotImplementedError):
        _ = dm.train_dataloader()

    valid_dss = [ds(), ds()]
    test_dss = [ds(), ds()]
    dm = LightningDataModule.from_datasets(train_ds, valid_dss, test_dss, batch_size=4, num_workers=0)
    with mock.patch("pytorch_lightning.core.datamodule.DataLoader") as dl_mock:
        dm.val_dataloader()
        dm.test_dataloader()
        dl_mock.assert_has_calls(
            [
                call(valid_dss[0], batch_size=4, shuffle=False, num_workers=0, pin_memory=True),
                call(valid_dss[1], batch_size=4, shuffle=False, num_workers=0, pin_memory=True),
                call(test_dss[0], batch_size=4, shuffle=False, num_workers=0, pin_memory=True),
                call(test_dss[1], batch_size=4, shuffle=False, num_workers=0, pin_memory=True),
            ]
        )


# all args
class DataModuleWithHparams_0(LightningDataModule):
    def __init__(self, arg0, arg1, kwarg0=None):
        super().__init__()
        self.save_hyperparameters()


# single arg
class DataModuleWithHparams_1(LightningDataModule):
    def __init__(self, arg0, *args, **kwargs):
        super().__init__()
        self.save_hyperparameters(arg0)


def test_hyperparameters_saving():
    data = DataModuleWithHparams_0(10, "foo", kwarg0="bar")
    assert data.hparams == AttributeDict({"arg0": 10, "arg1": "foo", "kwarg0": "bar"})

    data = DataModuleWithHparams_1(Namespace(**{"hello": "world"}), "foo", kwarg0="bar")
    assert data.hparams == AttributeDict({"hello": "world"})

    data = DataModuleWithHparams_1({"hello": "world"}, "foo", kwarg0="bar")
    assert data.hparams == AttributeDict({"hello": "world"})

    if _OMEGACONF_AVAILABLE:
        data = DataModuleWithHparams_1(OmegaConf.create({"hello": "world"}), "foo", kwarg0="bar")
        assert data.hparams == OmegaConf.create({"hello": "world"})


def test_define_as_dataclass():
    class BoringDataModule(LightningDataModule):
        def __init__(self, foo=None):
            super().__init__()

    # makes sure that no functionality is broken and the user can still manually make
    # a super().__init__ call with parameters
    # also tests all the dataclass features that can be enabled without breaking anything
    @dataclass(init=True, repr=True, eq=True, order=True, unsafe_hash=True, frozen=False)
    class BoringDataModule1(BoringDataModule):
        batch_size: int
        foo: int = 2

        def __post_init__(self):
            super().__init__(foo=self.foo)

    # asserts for the different dunder methods added by dataclass, when __init__ is implemented, i.e.
    # __repr__, __eq__, __lt__, __le__, etc.
    assert BoringDataModule1(batch_size=64).foo == 2
    assert BoringDataModule1(batch_size=32)
    assert hasattr(BoringDataModule1, "__repr__")
    assert BoringDataModule1(batch_size=32) == BoringDataModule1(batch_size=32)

    # asserts inherent calling of super().__init__ in case user doesn't make the call
    @dataclass
    class BoringDataModule2(LightningDataModule):
        batch_size: int

    # asserts for the different dunder methods added by dataclass, when super class is inherently initialized, i.e.
    # __init__, __repr__, __eq__, __lt__, __le__, etc.
    assert BoringDataModule2(batch_size=32)
    assert hasattr(BoringDataModule2, "__repr__")
    assert BoringDataModule2(batch_size=32).prepare_data() is None
    assert BoringDataModule2(batch_size=32) == BoringDataModule2(batch_size=32)


def test_inconsistent_prepare_data_per_node(tmpdir):
    with pytest.raises(MisconfigurationException, match="Inconsistent settings found for `prepare_data_per_node`."):
        model = BoringModel()
        dm = BoringDataModule()
        with pytest.deprecated_call(match="prepare_data_per_node` with the trainer flag is deprecated"):
            trainer = Trainer(prepare_data_per_node=False)
        trainer.model = model
        trainer.datamodule = dm
        trainer._data_connector.prepare_data()