William Falcon 2020-10-04 23:25:02 -04:00 committed by GitHub
parent ab5e9496d0
commit d787208e76
2 changed files with 44 additions and 5 deletions


@@ -27,6 +27,7 @@ from pytorch_lightning.utilities.data import has_iterable_dataset, has_len
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.model_utils import is_overridden
from copy import deepcopy
try:
@@ -233,14 +234,17 @@ class TrainerDataLoadingMixin(ABC):
        Returns:
            Tuple (num_batches, dataloaders)
        """
        # use the training loader as val and test when overfitting
        # always get the loaders first so we can count how many there are
        loader_name = f'{mode}_dataloader'
        if self.overfit_batches > 0:
            loader_name = 'train_dataloader'

        # load loaders
        dataloaders = self.request_dataloader(getattr(model, loader_name))

        # when overfitting use the training loader as val and test
        # duplicate it the number of times needed to match the train loaders
        if self.overfit_batches > 0:
            num_loaders = len(dataloaders)
            train_dataloader = self.request_dataloader(getattr(model, 'train_dataloader'))
            dataloaders = [deepcopy(train_dataloader) for _ in range(num_loaders)]

        if not isinstance(dataloaders, list):
            dataloaders = [dataloaders]
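
To make the duplication step above concrete, here is a minimal standalone sketch; it is not the Lightning API, and the two-loader count, the TensorDataset and the batch size are illustrative assumptions. It shows one train loader being deep-copied once per evaluation loader so each copy can be iterated independently.

# Standalone sketch of the duplication step (illustrative, not Lightning code)
from copy import deepcopy

import torch
from torch.utils.data import DataLoader, TensorDataset

# one train loader, and (hypothetically) a model that defined two val loaders
train_dataloader = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)
num_loaders = 2

# each val/test slot gets an independent deep copy of the train loader,
# so exhausting one copy does not affect the others
dataloaders = [deepcopy(train_dataloader) for _ in range(num_loaders)]

assert len(dataloaders) == num_loaders
assert all(len(dl.dataset) == len(train_dataloader.dataset) for dl in dataloaders)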


@@ -0,0 +1,35 @@
import torch
from tests.base.boring_model import BoringModel, RandomDataset
from pytorch_lightning import Trainer


def test_overfit_multiple_val_loaders(tmpdir):
    """
    Tests that overfit_batches works with multiple val dataloaders
    """
    class TestModel(BoringModel):

        def validation_step(self, batch, batch_idx, dataloader_idx):
            output = self.layer(batch[0])
            loss = self.loss(batch, output)
            return {"x": loss}

        def validation_epoch_end(self, outputs) -> None:
            pass

        def val_dataloader(self):
            dl1 = torch.utils.data.DataLoader(RandomDataset(32, 64))
            dl2 = torch.utils.data.DataLoader(RandomDataset(32, 64))
            return [dl1, dl2]

    model = TestModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        overfit_batches=1,
        row_log_interval=1,
        weights_summary=None,
    )
    trainer.fit(model)
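
A possible follow-up check, not part of this commit and only a sketch under the assumption that RandomDataset (from tests.base.boring_model) holds fixed random data: after the fix, each duplicated val loader should yield the same batches as the train loader, which can be verified by mimicking the duplication by hand.

# Illustrative standalone check of the expected overfit behaviour (not in the commit)
from copy import deepcopy

import torch
from torch.utils.data import DataLoader

from tests.base.boring_model import RandomDataset

train_dl = DataLoader(RandomDataset(32, 64))

# mimic what the trainer now does for two val loaders when overfit_batches > 0
val_dls = [deepcopy(train_dl) for _ in range(2)]

first_train_batch = next(iter(train_dl))
for dl in val_dls:
    # each duplicated loader should start with the same batch as the train loader
    assert torch.equal(next(iter(dl)), first_train_batch)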