diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 84f73827e3..a4ca9b3025 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -27,6 +27,7 @@ from pytorch_lightning.utilities.data import has_iterable_dataset, has_len from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.debugging import InternalDebugger from pytorch_lightning.utilities.model_utils import is_overridden +from copy import deepcopy try: @@ -233,14 +234,17 @@ class TrainerDataLoadingMixin(ABC): Returns: Tuple (num_batches, dataloaders) """ - # use the training loader as val and test when overfitting + # always get the loaders first so we can count how many there are loader_name = f'{mode}_dataloader' - if self.overfit_batches > 0: - loader_name = 'train_dataloader' - - # load loaders dataloaders = self.request_dataloader(getattr(model, loader_name)) + # when overfitting use the training loader as val and test + # duplicate it the numb of times needed to match the train loaders + if self.overfit_batches > 0: + num_loaders = len(dataloaders) + train_dataloader = self.request_dataloader(getattr(model, 'train_dataloader')) + dataloaders = [deepcopy(train_dataloader) for _ in range(num_loaders)] + if not isinstance(dataloaders, list): dataloaders = [dataloaders] diff --git a/tests/trainer/flags/test_overfit_batches.py b/tests/trainer/flags/test_overfit_batches.py new file mode 100644 index 0000000000..f9997292ec --- /dev/null +++ b/tests/trainer/flags/test_overfit_batches.py @@ -0,0 +1,35 @@ +import torch +from tests.base.boring_model import BoringModel, RandomDataset +from pytorch_lightning import Trainer + + +def test_overfit_multiple_val_loaders(tmpdir): + """ + Tests that only training_step can be used + """ + class TestModel(BoringModel): + + def validation_step(self, batch, batch_idx, dataloader_idx): + output = self.layer(batch[0]) + loss = self.loss(batch, output) + return {"x": loss} + + def validation_epoch_end(self, outputs) -> None: + pass + + def val_dataloader(self): + dl1 = torch.utils.data.DataLoader(RandomDataset(32, 64)) + dl2 = torch.utils.data.DataLoader(RandomDataset(32, 64)) + return [dl1, dl2] + + model = TestModel() + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=2, + overfit_batches=1, + row_log_interval=1, + weights_summary=None, + ) + + trainer.fit(model)