This commit is contained in:
William Falcon 2020-10-07 13:46:27 -04:00 committed by GitHub
parent 27f536b2ce
commit 6044cf9003
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 22 additions and 3 deletions

View File

@ -235,6 +235,9 @@ class TrainerDataLoadingMixin(ABC):
loader_name = f'{mode}_dataloader'
dataloaders = self.request_dataloader(getattr(model, loader_name))
if not isinstance(dataloaders, list):
dataloaders = [dataloaders]
# when overfitting use the training loader as val and test
# duplicate it the numb of times needed to match the train loaders
if self.overfit_batches > 0:
@ -242,9 +245,6 @@ class TrainerDataLoadingMixin(ABC):
train_dataloader = self.request_dataloader(getattr(model, 'train_dataloader'))
dataloaders = [deepcopy(train_dataloader) for _ in range(num_loaders)]
if not isinstance(dataloaders, list):
dataloaders = [dataloaders]
self.dev_debugger.track_load_dataloader_call(loader_name, dataloaders=dataloaders)
for loader_i in range(len(dataloaders)):

View File

@ -1,4 +1,5 @@
import torch
import pytest
from tests.base.boring_model import BoringModel, RandomDataset
from pytorch_lightning import Trainer
@ -33,3 +34,21 @@ def test_overfit_multiple_val_loaders(tmpdir):
)
trainer.fit(model)
@pytest.mark.parametrize('overfit', [1, 2, 0.1, 0.25, 1.0])
def test_overfit_basic(tmpdir, overfit):
"""
Tests that only training_step can be used
"""
model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
overfit_batches=overfit,
weights_summary=None,
)
trainer.fit(model)