From 6044cf900317ec9542fb1745976c9a96cc70b396 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Wed, 7 Oct 2020 13:46:27 -0400 Subject: [PATCH] Fixes #3945 (#3947) --- pytorch_lightning/trainer/data_loading.py | 6 +++--- tests/trainer/flags/test_overfit_batches.py | 19 +++++++++++++++++++ 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 69b2cf1353..5f9a3b8d54 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -235,6 +235,9 @@ class TrainerDataLoadingMixin(ABC): loader_name = f'{mode}_dataloader' dataloaders = self.request_dataloader(getattr(model, loader_name)) + if not isinstance(dataloaders, list): + dataloaders = [dataloaders] + # when overfitting use the training loader as val and test # duplicate it the numb of times needed to match the train loaders if self.overfit_batches > 0: @@ -242,9 +245,6 @@ class TrainerDataLoadingMixin(ABC): train_dataloader = self.request_dataloader(getattr(model, 'train_dataloader')) dataloaders = [deepcopy(train_dataloader) for _ in range(num_loaders)] - if not isinstance(dataloaders, list): - dataloaders = [dataloaders] - self.dev_debugger.track_load_dataloader_call(loader_name, dataloaders=dataloaders) for loader_i in range(len(dataloaders)): diff --git a/tests/trainer/flags/test_overfit_batches.py b/tests/trainer/flags/test_overfit_batches.py index f9997292ec..cf172f469f 100644 --- a/tests/trainer/flags/test_overfit_batches.py +++ b/tests/trainer/flags/test_overfit_batches.py @@ -1,4 +1,5 @@ import torch +import pytest from tests.base.boring_model import BoringModel, RandomDataset from pytorch_lightning import Trainer @@ -33,3 +34,21 @@ def test_overfit_multiple_val_loaders(tmpdir): ) trainer.fit(model) + + +@pytest.mark.parametrize('overfit', [1, 2, 0.1, 0.25, 1.0]) +def test_overfit_basic(tmpdir, overfit): + """ + Tests that only training_step can be used + """ + + model = BoringModel() + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + overfit_batches=overfit, + weights_summary=None, + ) + + trainer.fit(model)