# Copyright The Lightning AI team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os from unittest.mock import Mock, call, patch import lightning.pytorch import numpy import pytest import torch from lightning.fabric.utilities.data import _auto_add_worker_init_fn, has_iterable_dataset from lightning.pytorch import Callback, Trainer, seed_everything from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.demos.boring_classes import ( BoringModel, RandomDataset, RandomIterableDataset, RandomIterableDatasetWithLen, ) from lightning.pytorch.loggers import CSVLogger from lightning.pytorch.trainer.connectors.data_connector import _request_dataloader from lightning.pytorch.trainer.states import RunningStage from lightning.pytorch.utilities.combined_loader import CombinedLoader from lightning.pytorch.utilities.data import has_len_all_ranks from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning_utilities.test.warning import no_warning_call from torch.utils.data import RandomSampler from torch.utils.data.dataloader import DataLoader from torch.utils.data.dataset import Dataset, IterableDataset from torch.utils.data.distributed import DistributedSampler from torch.utils.data.sampler import SequentialSampler from tests_pytorch.helpers.dataloaders import CustomInfDataloader, CustomNotImplementedErrorDataloader from tests_pytorch.helpers.runif import RunIf class MultiValDataLoaderBoringModel(BoringModel): def val_dataloader(self): return [DataLoader(RandomDataset(32, 64)), DataLoader(RandomDataset(32, 64), batch_size=8)] def validation_step(self, batch, batch_idx, dataloader_idx): return super().validation_step(batch, batch_idx) class MultiTestDataLoaderBoringModel(BoringModel): def test_dataloader(self): return [DataLoader(RandomDataset(32, 64)), DataLoader(RandomDataset(32, 64), batch_size=8)] def test_step(self, batch, batch_idx, dataloader_idx): return super().test_step(batch, batch_idx) class MultiEvalDataLoaderModel(MultiValDataLoaderBoringModel, MultiTestDataLoaderBoringModel): pass def test_fit_train_loader_only(tmpdir): model = BoringModel() train_dataloader = model.train_dataloader() model.train_dataloader = None model.val_dataloader = None model.test_dataloader = None model.validation_step = None model.test_step = None trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir) trainer.fit(model, train_dataloaders=train_dataloader) def test_fit_val_loader_only(tmpdir): model = BoringModel() train_dataloader = model.train_dataloader() val_dataloader = model.val_dataloader() model.train_dataloader = None model.val_dataloader = None model.test_dataloader = None model.test_step = None trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir) trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader) @pytest.mark.parametrize("dataloader_options", [{"val_check_interval": 10000}]) def test_dataloader_config_errors_runtime(tmpdir, dataloader_options): model = BoringModel() trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, **dataloader_options) with pytest.raises(ValueError, match="less than or equal to the number of the training batches"): trainer.fit(model) @pytest.mark.parametrize( "dataloader_options", [ {"limit_train_batches": -0.1}, {"limit_train_batches": 1.2}, {"limit_val_batches": -0.1}, {"limit_val_batches": 1.2}, {"limit_test_batches": -0.1}, {"limit_test_batches": 1.2}, {"val_check_interval": -0.1}, {"val_check_interval": 1.2}, {"overfit_batches": -0.1}, {"overfit_batches": 1.2}, ], ) def test_dataloader_config_errors_init(tmpdir, dataloader_options): with pytest.raises(MisconfigurationException, match="passed invalid value"): Trainer(default_root_dir=tmpdir, max_epochs=1, **dataloader_options) def test_multiple_val_dataloader(tmpdir): """Verify multiple val_dataloader.""" model = MultiValDataLoaderBoringModel() trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.3, limit_train_batches=1.0) trainer.fit(model) # verify there are 2 val loaders assert len(trainer.val_dataloaders) == 2, "Multiple val_dataloaders not initiated properly" @pytest.mark.parametrize("ckpt_path", [None, "best", "specific"]) def test_multiple_eval_dataloader(tmpdir, ckpt_path): """Verify multiple evaluation dataloaders.""" model = MultiEvalDataLoaderModel() trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=10, limit_train_batches=100) trainer.fit(model) ckpt_path = trainer.checkpoint_callback.best_model_path if ckpt_path == "specific" else ckpt_path trainer.validate(ckpt_path=ckpt_path, verbose=False) # verify there are 2 loaders assert len(trainer.val_dataloaders) == 2 trainer.test(ckpt_path=ckpt_path, verbose=False) assert len(trainer.test_dataloaders) == 2 def test_train_dataloader_passed_to_fit(tmpdir): """Verify that train dataloader can be passed to fit.""" # only train passed to fit model = BoringModel() train_loader = model.train_dataloader() trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=2) fit_options = {"train_dataloaders": train_loader} trainer.fit(model, **fit_options) assert trainer.num_training_batches == 2 assert trainer.train_dataloader == train_loader @pytest.mark.parametrize("ckpt_path", [None, "best", "specific"]) @pytest.mark.parametrize("n", [1, 2]) def test_dataloaders_passed_to_fn(tmpdir, ckpt_path, n): """Verify that dataloaders can be passed.""" train_dataloaders = DataLoader(RandomDataset(32, 64)) if n == 1: model = BoringModel() eval_dataloaders = DataLoader(RandomDataset(32, 64)) else: model = MultiEvalDataLoaderModel() eval_dataloaders = [DataLoader(RandomDataset(32, 64)), DataLoader(RandomDataset(32, 64))] # multiple val dataloaders passed to fit trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.1, limit_train_batches=0.2) trainer.fit(model, train_dataloaders=train_dataloaders, val_dataloaders=eval_dataloaders) if n > 1: assert len(trainer.val_dataloaders) == n else: assert isinstance(trainer.val_dataloaders, DataLoader) if ckpt_path == "specific": ckpt_path = trainer.checkpoint_callback.best_model_path trainer.test(dataloaders=eval_dataloaders, ckpt_path=ckpt_path) if n > 1: assert len(trainer.test_dataloaders) == n else: assert isinstance(trainer.test_dataloaders, DataLoader) trainer.validate(dataloaders=eval_dataloaders, ckpt_path=ckpt_path) if n > 1: assert len(trainer.val_dataloaders) == n else: assert isinstance(trainer.val_dataloaders, DataLoader) class DummyModel(BoringModel): def training_step(self, batch, batch_idx): self.log("loss", self.global_step) return super().training_step(batch, batch_idx) def on_validation_epoch_end(self): self.log("val_log", self.current_epoch) class Counter(Callback): def __init__(self): super().__init__() self.train_epoch_count = 0 self.val_epoch_count = 0 self.test_epoch_count = 0 self.train_batches_seen = 0 self.val_batches_seen = 0 self.test_batches_seen = 0 def on_train_batch_start(self, *_): self.train_batches_seen += 1 def on_train_epoch_start(self, *_): self.train_epoch_count += 1 def on_validation_batch_start(self, *_): self.val_batches_seen += 1 def on_test_batch_start(self, *_): self.test_batches_seen += 1 def on_validation_epoch_start(self, *_): self.val_epoch_count += 1 def on_test_epoch_start(self, *_): self.test_epoch_count += 1 def test_inf_dataloaders_with_limit_percent_batches(tmpdir): """Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit in percent.""" epoch_cb = Counter() trainer = Trainer( default_root_dir=tmpdir, num_sanity_val_steps=0, max_epochs=1, callbacks=[epoch_cb], limit_train_batches=1.0, limit_val_batches=1.0, limit_test_batches=1.0, ) model = DummyModel() batch_size = 8 train_dl = DataLoader(dataset=RandomIterableDataset(32, 128), batch_size=batch_size) val_dl = DataLoader(dataset=RandomIterableDataset(32, 128), batch_size=batch_size) test_dl = DataLoader(dataset=RandomIterableDataset(32, 128), batch_size=batch_size) num_batches = 128 / batch_size for dl in (train_dl, val_dl, test_dl): if has_len_all_ranks(dl, trainer.strategy): assert len(dl) == num_batches else: assert sum(1 for _ in dl) == num_batches trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl) assert trainer.num_training_batches == float("inf") assert epoch_cb.train_epoch_count == 1 assert trainer.num_val_batches[0] == float("inf") assert epoch_cb.val_epoch_count == 1 trainer.test(model, dataloaders=test_dl) assert trainer.num_test_batches[0] == float("inf") assert epoch_cb.test_epoch_count == 1 @pytest.mark.parametrize( ("dataset", "limit_train_batches"), [ (RandomDataset(32, 128), 10), (RandomIterableDataset(32, 128), 10), (RandomIterableDatasetWithLen(32, 128), 10), ], ) def test_dataloaders_with_limit_train_batches(tmpdir, dataset, limit_train_batches): """Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit as number.""" epoch_cb = Counter() max_epochs = 2 trainer = Trainer( default_root_dir=tmpdir, num_sanity_val_steps=0, max_epochs=max_epochs, callbacks=[epoch_cb], limit_train_batches=limit_train_batches, ) model = DummyModel() batch_size = 8 train_dl = DataLoader(dataset=dataset, batch_size=batch_size) val_dl = DataLoader(dataset=dataset, batch_size=batch_size) trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl) assert trainer.num_training_batches == limit_train_batches assert epoch_cb.train_epoch_count == max_epochs assert epoch_cb.train_batches_seen == limit_train_batches * max_epochs @pytest.mark.parametrize( "dataset", [ RandomDataset(32, 128), RandomIterableDataset(32, 128), RandomIterableDatasetWithLen(32, 128), ], ) def test_dataloaders_with_limit_val_batches(tmpdir, dataset): """Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit as number.""" epoch_cb = Counter() callbacks = [epoch_cb] enable_checkpointing = False max_epochs = 2 limit_val_batches = 10 trainer = Trainer( default_root_dir=tmpdir, num_sanity_val_steps=0, max_epochs=max_epochs, callbacks=callbacks, limit_val_batches=limit_val_batches, enable_checkpointing=enable_checkpointing, ) model = DummyModel() batch_size = 8 train_dl = DataLoader(dataset=dataset, batch_size=batch_size) val_dl = DataLoader(dataset=dataset, batch_size=batch_size) trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl) assert trainer.num_val_batches[0] == limit_val_batches assert epoch_cb.val_epoch_count == max_epochs assert epoch_cb.val_batches_seen == limit_val_batches * max_epochs @pytest.mark.parametrize( "dataset", [ RandomDataset(32, 128), RandomIterableDataset(32, 128), RandomIterableDatasetWithLen(32, 128), ], ) def test_datasets_dataloaders_with_limit_num_batches(tmpdir, dataset): """Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit as number.""" epoch_cb = Counter() max_epochs = 2 limit_batches = 10 trainer = Trainer( default_root_dir=tmpdir, num_sanity_val_steps=0, max_epochs=max_epochs, callbacks=[epoch_cb], limit_train_batches=limit_batches, limit_val_batches=limit_batches, limit_test_batches=limit_batches, ) model = DummyModel() batch_size = 8 train_dl = DataLoader(dataset=dataset, batch_size=batch_size) val_dl = DataLoader(dataset=dataset, batch_size=batch_size) test_dl = DataLoader(dataset=dataset, batch_size=batch_size) trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl) assert trainer.num_training_batches == limit_batches assert trainer.num_val_batches[0] == limit_batches assert epoch_cb.train_epoch_count == max_epochs assert epoch_cb.train_batches_seen == limit_batches * max_epochs assert epoch_cb.val_epoch_count == max_epochs assert epoch_cb.val_batches_seen == limit_batches * max_epochs trainer.test(model, dataloaders=test_dl) assert trainer.num_test_batches[0] == limit_batches assert epoch_cb.test_epoch_count == 1 @pytest.mark.parametrize( ("limit_train_batches", "limit_val_batches", "limit_test_batches"), [(1.0, 1.0, 1.0), (0.2, 0.4, 0.4)], ) def test_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches): """Verify num_batches for train, val & test dataloaders passed with batch limit in percent.""" model = MultiEvalDataLoaderModel() # train, multiple val and multiple test passed with percent_check trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, limit_train_batches=limit_train_batches, limit_val_batches=limit_val_batches, limit_test_batches=limit_test_batches, ) trainer.fit(model) expected_train_batches = int(len(trainer.train_dataloader) * limit_train_batches) expected_val_batches = [int(len(dataloader) * limit_val_batches) for dataloader in trainer.val_dataloaders] assert trainer.num_training_batches == expected_train_batches assert trainer.num_val_batches == expected_val_batches trainer.test(model) expected_test_batches = [int(len(dataloader) * limit_test_batches) for dataloader in trainer.test_dataloaders] assert trainer.num_test_batches == expected_test_batches @pytest.mark.parametrize(("limit_train_batches", "limit_val_batches", "limit_test_batches"), [(1, 2, 3), (1, 2, 1e50)]) def test_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches): """Verify num_batches for train, val & test dataloaders passed with batch limit as number.""" model = MultiEvalDataLoaderModel() # train, multiple val and multiple test passed with percent_check trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, limit_train_batches=limit_train_batches, limit_val_batches=limit_val_batches, limit_test_batches=limit_test_batches, num_sanity_val_steps=0, ) with patch.object( trainer.fit_loop.epoch_loop.val_loop, "_evaluation_step", wraps=trainer.fit_loop.epoch_loop.val_loop._evaluation_step, ) as mocked: trainer.fit(model) assert trainer.num_training_batches == limit_train_batches assert trainer.num_val_batches == [limit_val_batches] * len(trainer.val_dataloaders) assert mocked.call_count == limit_val_batches * len(trainer.val_dataloaders) with patch.object( trainer.test_loop, "_evaluation_step", wraps=trainer.test_loop._evaluation_step, ) as mocked: trainer.test(model) test_dataloader_lengths = [len(x) for x in model.test_dataloader()] if limit_test_batches > 1e10: # when the limit is greater than the number of test batches it should be the num in loaders assert trainer.num_test_batches == test_dataloader_lengths assert mocked.call_count == sum(test_dataloader_lengths) else: assert trainer.num_test_batches == [limit_test_batches] * len(trainer.test_dataloaders) assert mocked.call_count == limit_test_batches * len(trainer.test_dataloaders) @pytest.mark.parametrize("fast_dev_run", [True, 1, 3, -1]) def test_dataloaders_with_fast_dev_run(tmpdir, fast_dev_run): """Verify num_batches for train, val & test dataloaders passed with fast_dev_run.""" model = MultiEvalDataLoaderModel() trainer_options = {"default_root_dir": tmpdir, "max_epochs": 2, "fast_dev_run": fast_dev_run} if fast_dev_run == -1: with pytest.raises(MisconfigurationException, match="should be >= 0"): Trainer(**trainer_options) else: trainer = Trainer(**trainer_options) # fast_dev_run is set to True when it is 1 if fast_dev_run == 1: fast_dev_run = True assert trainer.fast_dev_run is fast_dev_run if fast_dev_run is True: fast_dev_run = 1 assert trainer.limit_train_batches == fast_dev_run assert trainer.limit_val_batches == fast_dev_run assert trainer.limit_test_batches == fast_dev_run assert trainer.num_sanity_val_steps == 0 assert trainer.max_epochs == 1 trainer.fit(model) assert trainer.enable_validation assert trainer.num_training_batches == fast_dev_run assert trainer.num_val_batches == [fast_dev_run] * len(trainer.val_dataloaders) trainer.test(model) assert trainer.num_test_batches == [fast_dev_run] * len(trainer.test_dataloaders) @pytest.mark.parametrize("ckpt_path", [None, "best", "specific"]) def test_mixing_of_dataloader_options(tmpdir, ckpt_path): """Verify that dataloaders can be passed to fit.""" model = BoringModel() eval_dataloader = DataLoader(RandomDataset(32, 64)) trainer_options = { "default_root_dir": tmpdir, "max_epochs": 1, "limit_val_batches": 0.1, "limit_train_batches": 0.2, } # fit model trainer = Trainer(**trainer_options) trainer.fit(model, val_dataloaders=eval_dataloader) assert trainer.val_dataloaders == eval_dataloader ckpt_path = trainer.checkpoint_callback.best_model_path if ckpt_path == "specific" else ckpt_path trainer.test(dataloaders=eval_dataloader, ckpt_path=ckpt_path) assert trainer.test_dataloaders == eval_dataloader def test_warning_on_zero_len_dataloader(): """Test that a warning is raised if a zero-length dataloader is defined.""" model = BoringModel() trainer = Trainer() trainer.strategy.connect(model) train_dataloader = DataLoader(RandomDataset(32, 0)) val_dataloader = DataLoader(RandomDataset(32, 0)) trainer._data_connector.attach_data(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader) with pytest.warns(UserWarning, match="Total length of `CombinedLoader` across ranks is zero"): trainer.fit_loop.setup_data() assert trainer.num_training_batches == 0 trainer.state.fn = "validate" with pytest.warns(UserWarning, match="Total length of `DataLoader` across ranks is zero"): trainer.validate_loop.setup_data() assert trainer.num_val_batches == [0] @RunIf(skip_windows=True) @pytest.mark.parametrize("ckpt_path", [None, "best", "specific"]) @pytest.mark.parametrize("stage", ["train", "test", "val"]) @patch("lightning.fabric.utilities.data._num_cpus_available", return_value=4) def test_warning_with_few_workers(_, tmpdir, ckpt_path, stage): """Test that error is raised if dataloader with only a few workers is used.""" model = BoringModel() train_dl = model.train_dataloader() train_dl.num_workers = 0 val_dl = model.val_dataloader() val_dl.num_workers = 0 trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.1, limit_train_batches=0.2) with pytest.warns(UserWarning, match=f"The '{stage}_dataloader' does not have many workers"): if stage == "test": if ckpt_path in ("specific", "best"): trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl) ckpt_path = trainer.checkpoint_callback.best_model_path if ckpt_path == "specific" else ckpt_path trainer.test(model, dataloaders=train_dl, ckpt_path=ckpt_path) else: trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl) @RunIf(skip_windows=True) @pytest.mark.parametrize("ckpt_path", [None, "best", "specific"]) @pytest.mark.parametrize("stage", ["train", "test", "val"]) @patch("lightning.fabric.utilities.data._num_cpus_available", return_value=4) def test_warning_with_few_workers_multi_loader(_, tmpdir, ckpt_path, stage): """Test that a warning is emitted if the dataloader only has a few workers.""" class CustomModel(MultiEvalDataLoaderModel): def training_step(self, batch, batch_idx): return super().training_step(batch["a_b"][0], batch_idx) model = CustomModel() val_dl = DataLoader(RandomDataset(32, 64)) val_dl.num_workers = 0 train_dl = DataLoader(RandomDataset(32, 64)) train_dl.num_workers = 0 train_multi_dl = {"a_b": [train_dl, train_dl], "c_d_e": [train_dl, train_dl, train_dl]} val_multi_dl = [val_dl, val_dl] test_multi_dl = [train_dl, train_dl] trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.1, limit_train_batches=0.2) with pytest.warns( UserWarning, match=f"The '{stage}_dataloader' does not have many workers", ): if stage == "test": if ckpt_path in ("specific", "best"): trainer.fit(model, train_dataloaders=train_multi_dl, val_dataloaders=val_multi_dl) ckpt_path = trainer.checkpoint_callback.best_model_path if ckpt_path == "specific" else ckpt_path trainer.test(model, dataloaders=test_multi_dl, ckpt_path=ckpt_path) else: trainer.fit(model, train_dataloaders=train_multi_dl, val_dataloaders=val_multi_dl) class NumpyRandomDataset(Dataset): # this dataset uses numpy instead of torch to produce random numbers size = 16 def __getitem__(self, index): return numpy.random.randint(0, 100, 3) def __len__(self): return self.size def _user_worker_init_fn(_): pass def test_auto_add_worker_init_fn(): """Test Trainer adds a default worker_init_fn to the dataloader when seed_everything() is used.""" dataset = Mock() dataloader = DataLoader(dataset) # without pl.seed_everything() _auto_add_worker_init_fn(dataloader, 0) assert dataloader.worker_init_fn is None # with forcefully avoiding it seed_everything(0, workers=False) _auto_add_worker_init_fn(dataloader, 0) assert dataloader.worker_init_fn is None # when user already has a worker_init_fn user_function = _user_worker_init_fn dataloader.worker_init_fn = user_function _auto_add_worker_init_fn(dataloader, 0) assert dataloader.worker_init_fn is user_function dataloader.worker_init_fn = None # main use case seed_everything(0, workers=True) _auto_add_worker_init_fn(dataloader, 0) assert dataloader.worker_init_fn is not None class MultiProcessModel(BoringModel): def __init__(self): super().__init__() self.batches_seen = [] def training_step(self, batch, batch_idx): self.batches_seen.append(batch) def on_train_epoch_end(self): world_size = 2 num_samples = NumpyRandomDataset.size all_batches = torch.cat(self.batches_seen) all_batches = self.all_gather(all_batches) assert all_batches.shape[0] == world_size all_batches = all_batches.view(-1, 3) assert len(torch.unique(all_batches, dim=0)) == num_samples @RunIf(min_cuda_gpus=2) def test_auto_add_worker_init_fn_distributed(tmpdir, monkeypatch): """Test that the lightning worker_init_fn takes care of dataloaders in multi-gpu/multi-node training.""" dataset = NumpyRandomDataset() num_workers = 2 batch_size = 2 dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers) seed_everything(0, workers=True) trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, accelerator="gpu", devices=2, strategy="ddp_spawn") model = MultiProcessModel() model.val_dataloader = None trainer.fit(model, train_dataloaders=dataloader) def test_warning_with_small_dataloader_and_logging_interval(tmpdir): """Test that a warning message is shown if the dataloader length is too short for the chosen logging interval.""" model = BoringModel() dataloader = DataLoader(RandomDataset(32, length=10)) model.train_dataloader = lambda: dataloader with pytest.warns(UserWarning, match=r"The number of training batches \(10\) is smaller than the logging interval"): trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, log_every_n_steps=11, logger=CSVLogger(tmpdir)) trainer.fit(model) with pytest.warns(UserWarning, match=r"The number of training batches \(1\) is smaller than the logging interval"): trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, log_every_n_steps=2, limit_train_batches=1, logger=CSVLogger(".") ) trainer.fit(model) with no_warning_call(UserWarning, match="The number of training batches"): trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, log_every_n_steps=2) trainer.fit(model) def test_warning_with_iterable_dataset_and_len(tmpdir): """Tests that a warning message is shown when an IterableDataset defines `__len__`.""" model = BoringModel() original_dataset = model.train_dataloader().dataset class IterableWithoutLen(IterableDataset): def __iter__(self): return iter(original_dataset) class IterableWithLen(IterableWithoutLen): def __len__(self): return len(original_dataset) # with __len__ defined trainer = Trainer(default_root_dir=tmpdir, max_steps=3) dataloader = DataLoader(IterableWithLen(), batch_size=16) assert has_len_all_ranks(dataloader, trainer.strategy) assert has_iterable_dataset(dataloader) with pytest.warns(UserWarning, match="Your `IterableDataset` has `__len__` defined."): trainer.validate(model, dataloaders=[dataloader]) with pytest.warns(UserWarning, match="Your `IterableDataset` has `__len__` defined."): trainer.fit(model, train_dataloaders=dataloader, val_dataloaders=[dataloader]) with pytest.warns(UserWarning, match="Your `IterableDataset` has `__len__` defined."): trainer.test(model, dataloaders=[dataloader]) with pytest.warns(UserWarning, match="Your `IterableDataset` has `__len__` defined."): trainer.predict(model, dataloaders=[dataloader]) # without __len__ defined trainer = Trainer(default_root_dir=tmpdir, max_steps=3) dataloader = DataLoader(IterableWithoutLen(), batch_size=16) assert not has_len_all_ranks(dataloader, trainer.strategy) assert has_iterable_dataset(dataloader) trainer.validate(model, dataloaders=dataloader) trainer.fit(model, train_dataloaders=dataloader, val_dataloaders=[dataloader]) trainer.test(model, dataloaders=dataloader) trainer.predict(model, dataloaders=dataloader) @pytest.mark.parametrize("yield_at_all", [False, True]) def test_iterable_dataset_stop_iteration_at_epoch_beginning(yield_at_all): """Test that the training loop skips execution if the iterator is empty from the start.""" class TestDataset(IterableDataset): def __init__(self, gen): self.gen = gen def __iter__(self): return iter(self.gen()) class TestModel(BoringModel): def gen(self): # produce data in epoch 0, no data otherwise if yield_at_all and self.current_epoch == 0: yield torch.rand(32) yield torch.rand(32) yield torch.rand(32) model = TestModel() train_dataloader = DataLoader(TestDataset(model.gen), batch_size=2) trainer = Trainer( default_root_dir=os.getcwd(), max_epochs=2, enable_model_summary=False, ) trainer.fit(model, train_dataloaders=train_dataloader) assert trainer.global_step == 2 * yield_at_all # even though the generator might not yield any data, the fit_loop still advances so the # current epoch gets increased assert trainer.current_epoch == 2 class DistribSamplerCallback(Callback): def __init__(self, expected_seeds=(0, 0, 0)): self.expected_seed = expected_seeds def on_train_start(self, trainer, pl_module): train_sampler = trainer.train_dataloader.sampler assert isinstance(train_sampler, DistributedSampler) assert train_sampler.shuffle assert train_sampler.seed == self.expected_seed[0] def on_validation_start(self, trainer, pl_module): val_sampler = trainer.val_dataloaders.sampler assert isinstance(val_sampler, DistributedSampler) assert not val_sampler.shuffle assert val_sampler.seed == self.expected_seed[1] def on_test_start(self, trainer, pl_module): test_sampler = trainer.test_dataloaders.sampler assert isinstance(test_sampler, DistributedSampler) assert not test_sampler.shuffle assert test_sampler.seed == self.expected_seed[2] @RunIf(min_cuda_gpus=2, skip_windows=True) def test_dataloader_distributed_sampler(tmpdir): """Test DistributedSampler and it's arguments for DDP backend.""" seed_everything(123) model = BoringModel() trainer = Trainer( accelerator="gpu", devices=[0, 1], num_nodes=1, strategy="ddp_spawn", default_root_dir=tmpdir, max_steps=1, callbacks=[DistribSamplerCallback(expected_seeds=(123, 123, 123))], ) trainer.fit(model) trainer.test(model) class TestModelUniqueDDPSampling(BoringModel): def __init__(self): super().__init__() self.seen_samples = [] def training_step(self, batch): self.seen_samples.extend(batch.tolist()) def on_train_end(self): seen_samples = self.all_gather(self.seen_samples) # The samples should be unique across all processes assert set(torch.cat(seen_samples).view(-1).tolist()) == set(range(32)) @RunIf(standalone=True) def test_distributed_sampler_without_global_seed(tmpdir): """Test that the samples are non-overlapping in DDP when shuffling is enabled and no global seed is set.""" # This test must run without a global seed set (e.g. through `seed_everything`), to ensure that each process # starts with a different initial state. assert "PL_GLOBAL_SEED" not in os.environ train_dataloader = DataLoader(range(32), shuffle=True, batch_size=4) trainer = Trainer( default_root_dir=tmpdir, num_sanity_val_steps=False, logger=False, enable_progress_bar=False, accelerator="cpu", devices=2, strategy="ddp", max_epochs=1, ) trainer.fit(TestModelUniqueDDPSampling(), train_dataloader) class ModelWithDataLoaderDistributedSampler(BoringModel): def train_dataloader(self): dataloader = super().train_dataloader() dist_sampler = DistributedSampler(dataloader.dataset, shuffle=True, seed=11) return DataLoader(dataloader.dataset, batch_size=32, drop_last=False, sampler=dist_sampler, shuffle=False) @RunIf(min_cuda_gpus=2, skip_windows=True) def test_dataloader_distributed_sampler_already_attached(tmpdir): """Test DistributedSampler and it's arguments for DDP backend when DistSampler already included on dataloader.""" seed_everything(123) model = ModelWithDataLoaderDistributedSampler() trainer = Trainer( accelerator="gpu", devices=[0, 1], num_nodes=1, strategy="ddp_spawn", default_root_dir=tmpdir, max_steps=100, callbacks=[DistribSamplerCallback(expected_seeds=(11, 123, 0))], use_distributed_sampler=True, ) trainer.fit(model) assert trainer.state.finished, "DDP Training failed" @pytest.mark.parametrize( ("mode", "num_training_batches"), [("min_size", 16), ("max_size_cycle", 64), ("max_size", 64), ("sequential", 64 + 16 * 4)], ) def test_fit_multiple_train_loaders(tmpdir, mode, num_training_batches): """Integration test for multiple train iterables.""" class CustomBoringModel(BoringModel): def train_dataloader(self): loaders_a_b = [DataLoader(RandomDataset(32, 64)), DataLoader(RandomDataset(32, 16))] loaders_c_d_e = [ DataLoader(RandomDataset(32, 16)), DataLoader(RandomDataset(32, 16)), DataLoader(RandomDataset(32, 16)), ] # Dict[str, List[DataLoader]] loaders = {"a_b": loaders_a_b, "c_d_e": loaders_c_d_e} return CombinedLoader(loaders, mode) def training_step(self, batch, batch_idx): assert len(batch) == 2 assert len(batch["a_b"]) == 2 assert len(batch["c_d_e"]) == 3 return super().training_step(batch["a_b"][0], batch_idx) model = CustomBoringModel() trainer = Trainer(max_epochs=1, default_root_dir=tmpdir) if mode == "sequential": with pytest.raises(ValueError, match="FitLoop` does not support"): trainer.fit(model) else: trainer.fit(model) # verify the num_training_batches according to the mode assert num_training_batches == trainer.num_training_batches @pytest.mark.parametrize("check_interval", [50, 1.0]) @pytest.mark.parametrize("dataloader_wrapper", [CustomNotImplementedErrorDataloader, CustomInfDataloader]) def test_train_dataloader_not_implemented_error(tmpdir, check_interval, dataloader_wrapper): """Test not_implemented_error train data loader (e.g. IterableDataset)""" class CustomBoringModel(BoringModel): def train_dataloader(self): return dataloader_wrapper(DataLoader(RandomDataset(32, 64))) def val_dataloader(self): return dataloader_wrapper(DataLoader(RandomDataset(32, 64))) model = CustomBoringModel() trainer = Trainer(default_root_dir=tmpdir, max_steps=5, max_epochs=1, val_check_interval=check_interval) trainer.fit(model) # verify training completed @pytest.mark.parametrize( "stage", [RunningStage.TRAINING, RunningStage.VALIDATING, RunningStage.TESTING, RunningStage.PREDICTING] ) @pytest.mark.parametrize("dataloader_wrapper", [CustomNotImplementedErrorDataloader, CustomInfDataloader]) def test_inf_dataloader_raise_error_with_partial_batch_limits(tmpdir, stage, dataloader_wrapper): """Test limit_batch error with inf dataloader (e.g. IterableDataset)""" model = BoringModel() setattr( model, f"{stage.dataloader_prefix}_dataloader", lambda: dataloader_wrapper(DataLoader(RandomDataset(32, 64))) ) trainer_kwargs = {"default_root_dir": tmpdir, "max_epochs": 1, f"limit_{stage.dataloader_prefix}_batches": 0.5} trainer = Trainer(**trainer_kwargs) trainer_fn = "fit" if stage == RunningStage.TRAINING else stage.value with pytest.raises(MisconfigurationException, match=r"IterableDataset`.*limit_.*_batches\)`.*`1.0` or an int"): getattr(trainer, trainer_fn)(model) def test_dataloaders_load_only_once(tmpdir): model = BoringModel() tracker = Mock() model.train_dataloader = Mock(wraps=model.train_dataloader) model.val_dataloader = Mock(wraps=model.val_dataloader) model.test_dataloader = Mock(wraps=model.test_dataloader) tracker.attach_mock(model.train_dataloader, "train_dataloader") tracker.attach_mock(model.val_dataloader, "val_dataloader") tracker.attach_mock(model.test_dataloader, "test_dataloader") trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=0.3, limit_val_batches=0.3, max_epochs=3) trainer.fit(model) model.train_dataloader.assert_called_once() model.val_dataloader.assert_called_once() model.test_dataloader.assert_not_called() assert tracker.mock_calls == [call.val_dataloader(), call.train_dataloader()] def test_dataloaders_load_only_once_no_sanity_check(tmpdir): model = BoringModel() # logger file to get meta trainer = Trainer( default_root_dir=tmpdir, limit_train_batches=0.3, limit_val_batches=0.3, num_sanity_val_steps=0, max_epochs=3 ) tracker = Mock() model.train_dataloader = Mock(wraps=model.train_dataloader) model.val_dataloader = Mock(wraps=model.val_dataloader) model.test_dataloader = Mock(wraps=model.test_dataloader) tracker.attach_mock(model.train_dataloader, "train_dataloader") tracker.attach_mock(model.val_dataloader, "val_dataloader") tracker.attach_mock(model.test_dataloader, "test_dataloader") trainer.fit(model) # verify the sequence expected_sequence = [call.train_dataloader(), call.val_dataloader()] assert tracker.mock_calls == expected_sequence @pytest.mark.parametrize( ( "num_sanity_val_steps", "check_val_every_n_epoch", "reload_dataloaders_every_n_epochs", "train_reload_epochs_expect", "val_reload_epochs_expect", "val_step_epochs_expect", ), [ # general case where sanity check reloads the dataloaders for validation on current_epoch=0 (0, 1, 1, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), (1, 1, 1, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [1, 2, 3, 4, 5, 6, 7, 8, 9], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), # case where check_val_every_n_epoch < reload_dataloaders_every_n_epochs so expected val_reload_epoch # and val_step_epoch will be different (0, 1, 2, [0, 2, 4, 6, 8], [0, 2, 4, 6, 8], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), (1, 1, 2, [0, 2, 4, 6, 8], [2, 4, 6, 8], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), (0, 3, 4, [0, 4, 8], [2, 8], [2, 5, 8]), (1, 3, 4, [0, 4, 8], [2, 8], [2, 5, 8]), # case where check_val_every_n_epoch > reload_dataloaders_every_n_epochs so expected val_reload_epoch # and val_step_epoch will be same (0, 2, 1, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [1, 3, 5, 7, 9], [1, 3, 5, 7, 9]), (1, 2, 1, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [1, 3, 5, 7, 9], [1, 3, 5, 7, 9]), (0, 3, 2, [0, 2, 4, 6, 8], [2, 5, 8], [2, 5, 8]), (1, 3, 2, [0, 2, 4, 6, 8], [2, 5, 8], [2, 5, 8]), (0, 5, 2, [0, 2, 4, 6, 8], [4, 9], [4, 9]), (1, 5, 2, [0, 2, 4, 6, 8], [4, 9], [4, 9]), # case where check_val_every_n_epoch = reload_dataloaders_every_n_epochs so expected val_reload_epoch # and val_step_epoch will be same (0, 2, 2, [0, 2, 4, 6, 8], [1, 3, 5, 7, 9], [1, 3, 5, 7, 9]), (1, 2, 2, [0, 2, 4, 6, 8], [1, 3, 5, 7, 9], [1, 3, 5, 7, 9]), ], ) def test_dataloaders_load_every_n_epochs_infrequent_val( tmpdir, num_sanity_val_steps, check_val_every_n_epoch, reload_dataloaders_every_n_epochs, train_reload_epochs_expect, val_reload_epochs_expect, val_step_epochs_expect, ): """Test dataloader reload behavior when infrequently checking validation set (via check_val_every_n_epoch)""" sanity_val_check_epochs, train_reload_epochs, val_reload_epochs = [], [], [] sanity_val_step_epochs, val_step_epochs = [], [] class TestModel(BoringModel): def train_dataloader(self): train_reload_epochs.append(self.current_epoch) return super().train_dataloader() def val_dataloader(self): if self.trainer.sanity_checking: sanity_val_check_epochs.append(self.current_epoch) else: val_reload_epochs.append(self.current_epoch) return super().val_dataloader() def validation_step(self, *args, **kwargs): if self.trainer.sanity_checking: sanity_val_step_epochs.append(self.current_epoch) else: val_step_epochs.append(self.current_epoch) return super().validation_step(*args, **kwargs) model = TestModel() trainer = Trainer( default_root_dir=tmpdir, limit_train_batches=1, limit_val_batches=1, check_val_every_n_epoch=check_val_every_n_epoch, reload_dataloaders_every_n_epochs=reload_dataloaders_every_n_epochs, max_epochs=10, num_sanity_val_steps=num_sanity_val_steps, ) trainer.fit(model) # Verify epoch of reloads sanity_val_check_epochs_expect = [0] if num_sanity_val_steps else [] assert sanity_val_check_epochs == sanity_val_step_epochs == sanity_val_check_epochs_expect assert train_reload_epochs == train_reload_epochs_expect assert val_reload_epochs == val_reload_epochs_expect assert val_step_epochs == val_step_epochs_expect def test_dataloaders_load_every_n_epochs_frequent_val(tmpdir): """Test dataloader reload behavior when frequently checking validation set (via val_check_interval)""" train_reload_epochs, val_reload_epochs, val_check_epochs = [], [], [] class TestModel(BoringModel): def train_dataloader(self): train_reload_epochs.append(self.current_epoch) return super().train_dataloader() def val_dataloader(self): val_reload_epochs.append(self.current_epoch) return super().val_dataloader() def on_validation_epoch_end(self): val_check_epochs.append(self.current_epoch) model = TestModel() trainer = Trainer( default_root_dir=tmpdir, limit_train_batches=0.3, limit_val_batches=0.3, val_check_interval=0.3, reload_dataloaders_every_n_epochs=1, max_epochs=3, ) model.test_dataloader = Mock(wraps=model.test_dataloader) trainer.fit(model) trainer.test(model) # Verify epoch of reloads assert train_reload_epochs == [0, 1, 2] assert val_reload_epochs == [0, 1, 2] model.test_dataloader.assert_called_once() # Verify validation happens 3 times per epoch + 1 for sanity check assert val_check_epochs == [0, 0, 0, 0, 1, 1, 1, 2, 2, 2] @pytest.mark.parametrize("n", ["test", -1]) def test_dataloaders_load_every_n_epochs_exception(tmpdir, n): with pytest.raises(MisconfigurationException, match="should be an int >"): Trainer(default_root_dir=tmpdir, reload_dataloaders_every_n_epochs=n) def test_dataloaders_load_every_epoch_no_sanity_check(tmpdir): class TestModel(BoringModel): def validation_step(self, batch, batch_idx): self.log("dummy_val", 5.0) return super().validation_step(batch, batch_idx) model = TestModel() # This callback tests that the evaluation metrics are available by the time we run checkpointing checkpoint_callback = ModelCheckpoint(monitor="dummy_val", save_top_k=1) # logger file to get meta trainer = Trainer( default_root_dir=tmpdir, limit_train_batches=0.3, limit_val_batches=0.3, num_sanity_val_steps=0, reload_dataloaders_every_n_epochs=1, max_epochs=3, callbacks=[checkpoint_callback], ) tracker = Mock() model.train_dataloader = Mock(wraps=model.train_dataloader) model.val_dataloader = Mock(wraps=model.val_dataloader) model.test_dataloader = Mock(wraps=model.test_dataloader) tracker.attach_mock(model.train_dataloader, "train_dataloader") tracker.attach_mock(model.val_dataloader, "val_dataloader") tracker.attach_mock(model.test_dataloader, "test_dataloader") trainer.fit(model) trainer.test(model) expected_calls = [ call.train_dataloader(), call.val_dataloader(), call.train_dataloader(), call.val_dataloader(), call.train_dataloader(), call.val_dataloader(), call.test_dataloader(), ] assert tracker.mock_calls == expected_calls @pytest.mark.parametrize("sanity_check", [False, True]) def test_dataloaders_load_only_once_passed_loaders(tmp_path, monkeypatch, sanity_check): model = BoringModel() train_dataloader = model.train_dataloader() val_dataloader = model.val_dataloader() test_dataloader = model.test_dataloader() # delete dataloader methods on the model model.train_dataloader = None model.val_dataloader = None model.test_dataloader = None stages = [] original_request_dataloader = _request_dataloader def side_effect_request_dataloader(ds): stages.append(trainer.state.stage) return original_request_dataloader(ds) request_dataloader_mock = Mock(wraps=side_effect_request_dataloader) monkeypatch.setattr(lightning.pytorch.loops.fit_loop, "_request_dataloader", request_dataloader_mock) monkeypatch.setattr(lightning.pytorch.loops.evaluation_loop, "_request_dataloader", request_dataloader_mock) trainer = Trainer( default_root_dir=tmp_path, limit_train_batches=0.3, limit_val_batches=0.3, max_epochs=3, num_sanity_val_steps=1 if sanity_check else 0, ) trainer.fit(model, train_dataloader, val_dataloader) assert request_dataloader_mock.call_count == 2 request_dataloader_mock.reset_mock() trainer.test(model, dataloaders=test_dataloader) assert request_dataloader_mock.call_count == 1 expected = ["sanity_check", "train", "test"] if sanity_check else ["train", "validate", "test"] assert stages == expected def test_dataloaders_reset_and_attach(tmpdir): """Test that repeated calls to Trainer.{fit,validate,test,predict} properly reset dataloaders before attaching the new one.""" # the assertions compare the datasets and not dataloaders since we patch and replace the samplers dataloader_0 = DataLoader(dataset=RandomDataset(32, 64)) dataloader_1 = DataLoader(dataset=RandomDataset(32, 64)) dataloader_2 = DataLoader(dataset=RandomDataset(32, 64)) dataloader_3 = DataLoader(dataset=RandomDataset(32, 64)) model = BoringModel() trainer = Trainer( default_root_dir=tmpdir, max_steps=1, limit_val_batches=1, limit_test_batches=1, limit_predict_batches=1, ) # 1st fit trainer.fit(model, train_dataloaders=dataloader_0, val_dataloaders=dataloader_1) assert trainer.train_dataloader.dataset is dataloader_0.dataset assert trainer.val_dataloaders.dataset is dataloader_1.dataset # 2nd fit trainer.fit_loop.epoch_loop.max_steps += 1 trainer.fit(model, train_dataloaders=dataloader_2, val_dataloaders=dataloader_3) assert trainer.train_dataloader.dataset is dataloader_2.dataset assert trainer.val_dataloaders.dataset is dataloader_3.dataset # 1st validate trainer.validate(model, dataloaders=dataloader_0) assert trainer.val_dataloaders.dataset is dataloader_0.dataset # 2nd validate trainer.validate(model, dataloaders=dataloader_1) assert trainer.val_dataloaders.dataset is dataloader_1.dataset # 1st test trainer.test(model, dataloaders=dataloader_0) assert trainer.test_dataloaders.dataset is dataloader_0.dataset # 2nd test trainer.test(model, dataloaders=dataloader_1) assert trainer.test_dataloaders.dataset is dataloader_1.dataset # 1st predict trainer.predict(model, dataloaders=dataloader_0) assert trainer.predict_dataloaders.dataset is dataloader_0.dataset # 2nd predict trainer.predict(model, dataloaders=dataloader_1) assert trainer.predict_dataloaders.dataset is dataloader_1.dataset @pytest.mark.parametrize("mode", ["min_size", "max_size_cycle"]) def test_correct_dataloader_idx_in_hooks(tmpdir, mode): """Check the correct dataloader_idx inside hooks.""" class CustomBoringModel(BoringModel): def __init__(self): super().__init__() self.val_call_count = 0 self.test_call_count = 0 def assert_dataloader_idx_hook(self, dataloader_idx): if self.trainer.training: assert dataloader_idx == 0 elif self.trainer.validating: assert dataloader_idx == (0 if self.val_call_count <= 5 else 1) elif self.trainer.testing: assert dataloader_idx == (0 if self.test_call_count <= 5 else 1) def transfer_batch_to_device(self, batch, device, dataloader_idx): self.assert_dataloader_idx_hook(dataloader_idx) return super().transfer_batch_to_device(batch, device, dataloader_idx) def on_before_batch_transfer(self, batch, dataloader_idx): # incrementing here since this is the first hook called at each step if self.trainer.validating: self.val_call_count += 1 elif self.trainer.testing: self.test_call_count += 1 self.assert_dataloader_idx_hook(dataloader_idx) return super().on_before_batch_transfer(batch, dataloader_idx) def on_after_batch_transfer(self, batch, dataloader_idx): self.assert_dataloader_idx_hook(dataloader_idx) return super().on_after_batch_transfer(batch, dataloader_idx) def training_step(self, batch, batch_idx): return super().training_step(batch["a"], batch_idx) def validation_step(self, batch, batch_idx, dataloader_idx): self.assert_dataloader_idx_hook(dataloader_idx) out = super().validation_step(batch, batch_idx) loss = out.pop("x") out[f"val_loss_{dataloader_idx}"] = loss return out def test_step(self, batch, batch_idx, dataloader_idx): self.assert_dataloader_idx_hook(dataloader_idx) out = super().test_step(batch, batch_idx) loss = out.pop("y") out[f"test_loss_{dataloader_idx}"] = loss return out def predict(self, batch, batch_idx, dataloader_idx): self.assert_dataloader_idx_hook(dataloader_idx) return super().predict(batch, batch_idx, dataloader_idx) def train_dataloader(self): return CombinedLoader( {"a": DataLoader(RandomDataset(32, 64)), "b": DataLoader(RandomDataset(32, 64))}, mode=mode ) def val_dataloader(self): return [DataLoader(RandomDataset(32, 64)), DataLoader(RandomDataset(32, 64))] def test_dataloader(self): return [DataLoader(RandomDataset(32, 64)), DataLoader(RandomDataset(32, 64))] def predict_dataloader(self): return [DataLoader(RandomDataset(32, 64)), DataLoader(RandomDataset(32, 64))] model = CustomBoringModel() trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=5) trainer.fit(model) trainer.test(model) preds = trainer.predict(model) assert len(preds) == 2 assert all(len(x) == 5 for x in preds) def test_request_dataloader(tmpdir): """This test asserts dataloader can be wrapped.""" class DataLoaderWrapper: def __init__(self, loader): self.loader = loader self._iter = iter(self.loader) def __iter__(self): self._iter = iter(self.loader) return self._iter def __next__(self): return next(self._iter) class TestModel(BoringModel): def __init__(self): super().__init__() self.on_train_batch_start_called = False self.on_val_batch_start_called = False def train_dataloader(self): loader = super().train_dataloader() return DataLoaderWrapper(loader) def on_train_batch_start(self, batch, batch_idx: int) -> None: assert isinstance(self.trainer.train_dataloader, DataLoaderWrapper) self.on_train_batch_start_called = True def val_dataloader(self): loader = super().val_dataloader() return DataLoaderWrapper(loader) def on_validation_batch_start(self, *_): assert isinstance(self.trainer.val_dataloaders, DataLoaderWrapper) self.on_val_batch_start_called = True trainer = Trainer( default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, limit_test_batches=2, max_epochs=1 ) model = TestModel() trainer.fit(model) trainer.test(model) assert model.on_train_batch_start_called assert model.on_val_batch_start_called @pytest.mark.parametrize("num_loaders", [1, 2]) def test_multiple_dataloaders_with_random_sampler_overfit_batches(num_loaders, tmpdir): class TestModel(BoringModel): def training_step(self, batch, batch_idx): assert all(isinstance(dl.sampler, SequentialSampler) for dl in self.trainer.train_dataloader) return super().training_step(batch[0], batch_idx) def _create_dataloader(self): ds = RandomDataset(32, 64) return DataLoader(ds, sampler=RandomSampler(ds)) def train_dataloader(self): return [self._create_dataloader() for _ in range(num_loaders)] validation_step = None trainer = Trainer(default_root_dir=tmpdir, overfit_batches=1.0, max_epochs=1) trainer.fit(TestModel())