diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py
index 5f2cb3a894..449102fba2 100644
--- a/pytorch_lightning/trainer/data_loading.py
+++ b/pytorch_lightning/trainer/data_loading.py
@@ -29,6 +29,7 @@ from pytorch_lightning.utilities.debugging import InternalDebugger
 from pytorch_lightning.utilities.model_utils import is_overridden
 from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils
 from copy import deepcopy
+from typing import Iterable
 
 TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
 try:
@@ -336,7 +337,19 @@ class TrainerDataLoadingMixin(ABC):
             The dataloader
         """
         dataloader = dataloader_fx()
+        dataloader = self._flatten_dl_only(dataloader)
 
         if self.accelerator_backend is not None:
             self.accelerator_backend.barrier('get_dataloaders')
         return dataloader
+
+    def _flatten_dl_only(self, dataloaders):
+        # handles user error when they return:
+        # return dl1, dl2 vs return (dl1, dl2)
+        if isinstance(dataloaders, tuple):
+            all_dls = [isinstance(x, Iterable) for x in dataloaders]
+            all_dls = all(all_dls)
+            if all_dls:
+                dataloaders = list(dataloaders)
+
+        return dataloaders
diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py
index 8d65561ad0..17630953a8 100644
--- a/pytorch_lightning/trainer/evaluation_loop.py
+++ b/pytorch_lightning/trainer/evaluation_loop.py
@@ -144,14 +144,22 @@ class EvaluationLoop(object):
         # make dataloader_idx arg in validation_step optional
         args = [batch, batch_idx]
 
-        multiple_val_loaders = (not test_mode and len(self.trainer.val_dataloaders) > 1)
-        multiple_test_loaders = (test_mode and len(self.trainer.test_dataloaders) > 1)
+        multiple_val_loaders = (not test_mode and self._get_num_dataloaders(self.trainer.val_dataloaders) > 1)
+        multiple_test_loaders = (test_mode and self._get_num_dataloaders(self.trainer.test_dataloaders) > 1)
 
         if multiple_test_loaders or multiple_val_loaders:
             args.append(dataloader_idx)
 
         return args
 
+    def _get_num_dataloaders(self, dataloaders):
+        # case where user does:
+        # return dl1, dl2
+        length = len(dataloaders)
+        if len(dataloaders) > 0 and isinstance(dataloaders[0], (list, tuple)):
+            length = len(dataloaders[0])
+        return length
+
     def evaluation_step(self, test_mode, batch, batch_idx, dataloader_idx):
         # configure args
         args = self.build_args(test_mode, batch, batch_idx, dataloader_idx)
diff --git a/tests/trainer/dynamic_args/__init__.py b/tests/trainer/dynamic_args/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/trainer/dynamic_args/test_multiple_eval_dataloaders.py b/tests/trainer/dynamic_args/test_multiple_eval_dataloaders.py
new file mode 100644
index 0000000000..385c18a50b
--- /dev/null
+++ b/tests/trainer/dynamic_args/test_multiple_eval_dataloaders.py
@@ -0,0 +1,154 @@
+from pytorch_lightning import Trainer
+from tests.base.boring_model import BoringModel
+import torch
+from torch.utils.data import Dataset
+
+
+class RandomDatasetA(Dataset):
+    def __init__(self, size, length):
+        self.len = length
+        self.data = torch.randn(length, size)
+
+    def __getitem__(self, index):
+        return torch.zeros(1)
+
+    def __len__(self):
+        return self.len
+
+
+class RandomDatasetB(Dataset):
+    def __init__(self, size, length):
+        self.len = length
+        self.data = torch.randn(length, size)
+
+    def __getitem__(self, index):
+        return torch.ones(1)
+
+    def __len__(self):
+        return self.len
+
+
+def test_multiple_eval_dataloaders_tuple(tmpdir):
+    class TestModel(BoringModel):
+
+        def validation_step(self, batch, batch_idx, dataloader_idx):
+            if dataloader_idx == 0:
+                assert batch.sum() == 0
+            elif dataloader_idx == 1:
+                assert batch.sum() == 11
+            else:
+                raise Exception('should only have two dataloaders')
+
+        def training_epoch_end(self, outputs) -> None:
+            # with a single optimizer, outputs has one entry per training batch
+            assert len(outputs) == 2
+
+        def val_dataloader(self):
+            dl1 = torch.utils.data.DataLoader(RandomDatasetA(32, 64), batch_size=11)
+            dl2 = torch.utils.data.DataLoader(RandomDatasetB(32, 64), batch_size=11)
+            return dl1, dl2
+
+    model = TestModel()
+    model.validation_epoch_end = None
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        max_epochs=1,
+        row_log_interval=1,
+        weights_summary=None,
+    )
+
+    trainer.fit(model)
+
+
+def test_multiple_eval_dataloaders_list(tmpdir):
+    class TestModel(BoringModel):
+
+        def validation_step(self, batch, batch_idx, dataloader_idx):
+            if dataloader_idx == 0:
+                assert batch.sum() == 0
+            elif dataloader_idx == 1:
+                assert batch.sum() == 11
+            else:
+                raise Exception('should only have two dataloaders')
+
+        def val_dataloader(self):
+            dl1 = torch.utils.data.DataLoader(RandomDatasetA(32, 64), batch_size=11)
+            dl2 = torch.utils.data.DataLoader(RandomDatasetB(32, 64), batch_size=11)
+            return [dl1, dl2]
+
+    model = TestModel()
+    model.validation_epoch_end = None
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        max_epochs=1,
+        row_log_interval=1,
+        weights_summary=None,
+    )
+
+    trainer.fit(model)
+
+
+def test_multiple_optimizers_multiple_dataloaders(tmpdir):
+    """
+    Tests dynamic arg injection with multiple optimizers and multiple eval dataloaders
+    """
+    class TestModel(BoringModel):
+        def on_train_epoch_start(self) -> None:
+            self.opt_0_seen = False
+            self.opt_1_seen = False
+
+        def training_step(self, batch, batch_idx, optimizer_idx):
+            if optimizer_idx == 0:
+                self.opt_0_seen = True
+            elif optimizer_idx == 1:
+                self.opt_1_seen = True
+            else:
+                raise Exception('should only have two optimizers')
+
+            self.training_step_called = True
+            loss = self.step(batch[0])
+            return loss
+
+        def training_epoch_end(self, outputs) -> None:
+            # outputs should be an array with an entry per optimizer
+            assert len(outputs) == 2
+
+        def validation_step(self, batch, batch_idx, dataloader_idx):
+            if dataloader_idx == 0:
+                assert batch.sum() == 0
+            elif dataloader_idx == 1:
+                assert batch.sum() == 11
+            else:
+                raise Exception('should only have two dataloaders')
+
+        def val_dataloader(self):
+            dl1 = torch.utils.data.DataLoader(RandomDatasetA(32, 64), batch_size=11)
+            dl2 = torch.utils.data.DataLoader(RandomDatasetB(32, 64), batch_size=11)
+            return dl1, dl2
+
+        def configure_optimizers(self):
+            optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
+            optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
+            return optimizer, optimizer_2
+
+    model = TestModel()
+    model.validation_epoch_end = None
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        max_epochs=1,
+        row_log_interval=1,
+        weights_summary=None,
+    )
+
+    trainer.fit(model)
+    assert model.opt_0_seen
+    assert model.opt_1_seen
diff --git a/tests/trainer/dynamic_args/test_multiple_optimizers.py b/tests/trainer/dynamic_args/test_multiple_optimizers.py
new file mode 100644
index 0000000000..eb95f2a43a
--- /dev/null
+++ b/tests/trainer/dynamic_args/test_multiple_optimizers.py
@@ -0,0 +1,50 @@
+from pytorch_lightning import Trainer
+from tests.base.boring_model import BoringModel
+import torch
+
+
+def test_multiple_optimizers(tmpdir):
+    """
+    Tests that only training_step can be used
+    """
+    class TestModel(BoringModel):
+        def on_train_epoch_start(self) -> None:
+            self.opt_0_seen = False
+            self.opt_1_seen = False
+
+        def training_step(self, batch, batch_idx, optimizer_idx):
+            if optimizer_idx == 0:
+                self.opt_0_seen = True
+            elif optimizer_idx == 1:
+                self.opt_1_seen = True
+            else:
+                raise Exception('should only have two optimizers')
+
+            self.training_step_called = True
+            loss = self.step(batch[0])
+            return loss
+
+        def training_epoch_end(self, outputs) -> None:
+            # outputs should be an array with an entry per optimizer
+            assert len(outputs) == 2
+
+        def configure_optimizers(self):
+            optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
+            optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
+            return optimizer, optimizer_2
+
+    model = TestModel()
+    model.val_dataloader = None
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        max_epochs=1,
+        row_log_interval=1,
+        weights_summary=None,
+    )
+
+    trainer.fit(model)
+    assert model.opt_0_seen
+    assert model.opt_1_seen
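
The following is a minimal standalone sketch, not part of the patch, of the behaviour the new `_flatten_dl_only` helper enables: a bare tuple returned from `val_dataloader` (i.e. `return dl1, dl2`) is normalised to a list of dataloaders, while a single dataloader or an explicit list passes through unchanged; `_get_num_dataloaders` in the evaluation loop then counts that collection so `dataloader_idx` is still injected into `validation_step`. The function name `flatten_dl_only` and the toy datasets below are assumptions for illustration; only the body mirrors the patched method.

# a minimal sketch, not part of the patch: mirrors _flatten_dl_only from the
# data_loading.py hunk above so the tuple-vs-list handling can be run standalone
from typing import Iterable

import torch
from torch.utils.data import DataLoader, TensorDataset


def flatten_dl_only(dataloaders):
    # a tuple whose elements are all iterables is treated as "multiple dataloaders"
    # and converted to a list; anything else is returned untouched
    if isinstance(dataloaders, tuple):
        all_dls = all(isinstance(x, Iterable) for x in dataloaders)
        if all_dls:
            dataloaders = list(dataloaders)
    return dataloaders


dl1 = DataLoader(TensorDataset(torch.zeros(8, 2)), batch_size=4)
dl2 = DataLoader(TensorDataset(torch.ones(8, 2)), batch_size=4)

# `return dl1, dl2` from val_dataloader produces a tuple; it becomes a list of two loaders
assert flatten_dl_only((dl1, dl2)) == [dl1, dl2]

# a single dataloader or an explicit list is passed through unchanged
assert flatten_dl_only(dl1) is dl1
assert flatten_dl_only([dl1, dl2]) == [dl1, dl2]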