parent
ab5e9496d0
commit
d787208e76
|
@ -27,6 +27,7 @@ from pytorch_lightning.utilities.data import has_iterable_dataset, has_len
|
|||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.debugging import InternalDebugger
|
||||
from pytorch_lightning.utilities.model_utils import is_overridden
|
||||
from copy import deepcopy
|
||||
|
||||
|
||||
try:
|
||||
|
@ -233,14 +234,17 @@ class TrainerDataLoadingMixin(ABC):
|
|||
Returns:
|
||||
Tuple (num_batches, dataloaders)
|
||||
"""
|
||||
# use the training loader as val and test when overfitting
|
||||
# always get the loaders first so we can count how many there are
|
||||
loader_name = f'{mode}_dataloader'
|
||||
if self.overfit_batches > 0:
|
||||
loader_name = 'train_dataloader'
|
||||
|
||||
# load loaders
|
||||
dataloaders = self.request_dataloader(getattr(model, loader_name))
|
||||
|
||||
# when overfitting use the training loader as val and test
|
||||
# duplicate it the numb of times needed to match the train loaders
|
||||
if self.overfit_batches > 0:
|
||||
num_loaders = len(dataloaders)
|
||||
train_dataloader = self.request_dataloader(getattr(model, 'train_dataloader'))
|
||||
dataloaders = [deepcopy(train_dataloader) for _ in range(num_loaders)]
|
||||
|
||||
if not isinstance(dataloaders, list):
|
||||
dataloaders = [dataloaders]
|
||||
|
||||
|
|
|
@ -0,0 +1,35 @@
|
|||
import torch
|
||||
from tests.base.boring_model import BoringModel, RandomDataset
|
||||
from pytorch_lightning import Trainer
|
||||
|
||||
|
||||
def test_overfit_multiple_val_loaders(tmpdir):
|
||||
"""
|
||||
Tests that only training_step can be used
|
||||
"""
|
||||
class TestModel(BoringModel):
|
||||
|
||||
def validation_step(self, batch, batch_idx, dataloader_idx):
|
||||
output = self.layer(batch[0])
|
||||
loss = self.loss(batch, output)
|
||||
return {"x": loss}
|
||||
|
||||
def validation_epoch_end(self, outputs) -> None:
|
||||
pass
|
||||
|
||||
def val_dataloader(self):
|
||||
dl1 = torch.utils.data.DataLoader(RandomDataset(32, 64))
|
||||
dl2 = torch.utils.data.DataLoader(RandomDataset(32, 64))
|
||||
return [dl1, dl2]
|
||||
|
||||
model = TestModel()
|
||||
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
max_epochs=2,
|
||||
overfit_batches=1,
|
||||
row_log_interval=1,
|
||||
weights_summary=None,
|
||||
)
|
||||
|
||||
trainer.fit(model)
|
Loading…
Reference in New Issue