parent
27f536b2ce
commit
6044cf9003
|
@ -235,6 +235,9 @@ class TrainerDataLoadingMixin(ABC):
|
|||
loader_name = f'{mode}_dataloader'
|
||||
dataloaders = self.request_dataloader(getattr(model, loader_name))
|
||||
|
||||
if not isinstance(dataloaders, list):
|
||||
dataloaders = [dataloaders]
|
||||
|
||||
# when overfitting use the training loader as val and test
|
||||
# duplicate it the numb of times needed to match the train loaders
|
||||
if self.overfit_batches > 0:
|
||||
|
@ -242,9 +245,6 @@ class TrainerDataLoadingMixin(ABC):
|
|||
train_dataloader = self.request_dataloader(getattr(model, 'train_dataloader'))
|
||||
dataloaders = [deepcopy(train_dataloader) for _ in range(num_loaders)]
|
||||
|
||||
if not isinstance(dataloaders, list):
|
||||
dataloaders = [dataloaders]
|
||||
|
||||
self.dev_debugger.track_load_dataloader_call(loader_name, dataloaders=dataloaders)
|
||||
|
||||
for loader_i in range(len(dataloaders)):
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import torch
|
||||
import pytest
|
||||
from tests.base.boring_model import BoringModel, RandomDataset
|
||||
from pytorch_lightning import Trainer
|
||||
|
||||
|
@ -33,3 +34,21 @@ def test_overfit_multiple_val_loaders(tmpdir):
|
|||
)
|
||||
|
||||
trainer.fit(model)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('overfit', [1, 2, 0.1, 0.25, 1.0])
|
||||
def test_overfit_basic(tmpdir, overfit):
|
||||
"""
|
||||
Tests that only training_step can be used
|
||||
"""
|
||||
|
||||
model = BoringModel()
|
||||
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
max_epochs=1,
|
||||
overfit_batches=overfit,
|
||||
weights_summary=None,
|
||||
)
|
||||
|
||||
trainer.fit(model)
|
||||
|
|
Loading…
Reference in New Issue