tests for multiple optimizers and dataloader combinations (#3937)
* added tests for multiple optimizers and dataloaders
This commit is contained in:
parent 05cb6fcc58
commit 575e01be82
@@ -29,6 +29,7 @@ from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.model_utils import is_overridden
from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils
from copy import deepcopy
from typing import Iterable

TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
try:
@@ -336,7 +337,19 @@ class TrainerDataLoadingMixin(ABC):
            The dataloader
        """
        dataloader = dataloader_fx()
        dataloader = self._flatten_dl_only(dataloader)

        if self.accelerator_backend is not None:
            self.accelerator_backend.barrier('get_dataloaders')
        return dataloader

    def _flatten_dl_only(self, dataloaders):
        # handles user error when the loaders are returned as a bare tuple
        # (`return dl1, dl2`); a tuple of iterables is normalized to a list
        if isinstance(dataloaders, tuple):
            all_dls = [isinstance(x, Iterable) for x in dataloaders]
            all_dls = all(all_dls)
            if all_dls:
                dataloaders = list(dataloaders)

        return dataloaders
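For context, a minimal standalone sketch of what the flattening step does with the shapes a `*_dataloader` hook might return (illustrative only, not part of the diff; the plain lists below stand in for real `DataLoader` objects):

from typing import Iterable

def flatten_dl_only(dataloaders):
    # mirrors the helper above: a bare tuple of iterables becomes a list,
    # everything else passes through unchanged
    if isinstance(dataloaders, tuple) and all(isinstance(x, Iterable) for x in dataloaders):
        return list(dataloaders)
    return dataloaders

dl1, dl2 = [1, 2, 3], [4, 5, 6]
assert flatten_dl_only((dl1, dl2)) == [dl1, dl2]  # "return dl1, dl2" is normalized
assert flatten_dl_only([dl1, dl2]) == [dl1, dl2]  # an explicit list is left alone
assert flatten_dl_only(dl1) is dl1                # a single loader is left alone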
@@ -144,14 +144,22 @@ class EvaluationLoop(object):
        # make the dataloader_idx arg in validation_step optional
        args = [batch, batch_idx]

-       multiple_val_loaders = (not test_mode and len(self.trainer.val_dataloaders) > 1)
-       multiple_test_loaders = (test_mode and len(self.trainer.test_dataloaders) > 1)
+       multiple_val_loaders = (not test_mode and self._get_num_dataloaders(self.trainer.val_dataloaders) > 1)
+       multiple_test_loaders = (test_mode and self._get_num_dataloaders(self.trainer.test_dataloaders) > 1)

        if multiple_test_loaders or multiple_val_loaders:
            args.append(dataloader_idx)

        return args

    def _get_num_dataloaders(self, dataloaders):
        # case where the user does `return dl1, dl2` and the loaders arrive
        # wrapped inside a single list/tuple entry
        length = len(dataloaders)
        if len(dataloaders) > 0 and isinstance(dataloaders[0], (list, tuple)):
            length = len(dataloaders[0])
        return length

    def evaluation_step(self, test_mode, batch, batch_idx, dataloader_idx):
        # configure args
        args = self.build_args(test_mode, batch, batch_idx, dataloader_idx)
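A standalone sketch of the counting helper's two cases (again illustrative and not part of the diff; how the nested shape reaches the trainer internally is outside this snippet):

def get_num_dataloaders(dataloaders):
    # mirrors _get_num_dataloaders above: normally the count is just
    # len(dataloaders); if the first entry is itself a list/tuple
    # (the "return dl1, dl2" case), count that inner container instead
    length = len(dataloaders)
    if len(dataloaders) > 0 and isinstance(dataloaders[0], (list, tuple)):
        length = len(dataloaders[0])
    return length

dl1, dl2 = object(), object()
assert get_num_dataloaders([dl1, dl2]) == 2    # plain list of two loaders
assert get_num_dataloaders([(dl1, dl2)]) == 2  # two loaders nested in one entry
assert get_num_dataloaders([dl1]) == 1         # single loader: dataloader_idx stays optional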
@@ -0,0 +1,154 @@
from pytorch_lightning import Trainer
from tests.base.boring_model import BoringModel
import torch
from torch.utils.data import Dataset

class RandomDatasetA(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return torch.zeros(1)

    def __len__(self):
        return self.len

class RandomDatasetB(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return torch.ones(1)

    def __len__(self):
        return self.len

def test_multiple_eval_dataloaders_tuple(tmpdir):
    class TestModel(BoringModel):

        def validation_step(self, batch, batch_idx, dataloader_idx):
            if dataloader_idx == 0:
                assert batch.sum() == 0
            elif dataloader_idx == 1:
                assert batch.sum() == 11
            else:
                raise Exception('should only have two dataloaders')

        def training_epoch_end(self, outputs) -> None:
            # outputs should be an array with an entry per optimizer
            assert len(outputs) == 2

        def val_dataloader(self):
            dl1 = torch.utils.data.DataLoader(RandomDatasetA(32, 64), batch_size=11)
            dl2 = torch.utils.data.DataLoader(RandomDatasetB(32, 64), batch_size=11)
            return [dl1, dl2]

    model = TestModel()
    model.validation_epoch_end = None

    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=2,
        limit_val_batches=2,
        max_epochs=1,
        row_log_interval=1,
        weights_summary=None,
    )

    trainer.fit(model)

def test_multiple_eval_dataloaders_list(tmpdir):
    class TestModel(BoringModel):

        def validation_step(self, batch, batch_idx, dataloader_idx):
            if dataloader_idx == 0:
                assert batch.sum() == 0
            elif dataloader_idx == 1:
                assert batch.sum() == 11
            else:
                raise Exception('should only have two dataloaders')

        def val_dataloader(self):
            dl1 = torch.utils.data.DataLoader(RandomDatasetA(32, 64), batch_size=11)
            dl2 = torch.utils.data.DataLoader(RandomDatasetB(32, 64), batch_size=11)
            return dl1, dl2

    model = TestModel()
    model.validation_epoch_end = None

    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=2,
        limit_val_batches=2,
        max_epochs=1,
        row_log_interval=1,
        weights_summary=None,
    )

    trainer.fit(model)

def test_multiple_optimizers_multiple_dataloaders(tmpdir):
    """
    Tests training with multiple optimizers combined with multiple val dataloaders
    """
    class TestModel(BoringModel):
        def on_train_epoch_start(self) -> None:
            self.opt_0_seen = False
            self.opt_1_seen = False

        def training_step(self, batch, batch_idx, optimizer_idx):
            if optimizer_idx == 0:
                self.opt_0_seen = True
            elif optimizer_idx == 1:
                self.opt_1_seen = True
            else:
                raise Exception('should only have two optimizers')

            self.training_step_called = True
            loss = self.step(batch[0])
            return loss

        def training_epoch_end(self, outputs) -> None:
            # outputs should be an array with an entry per optimizer
            assert len(outputs) == 2

        def validation_step(self, batch, batch_idx, dataloader_idx):
            if dataloader_idx == 0:
                assert batch.sum() == 0
            elif dataloader_idx == 1:
                assert batch.sum() == 11
            else:
                raise Exception('should only have two dataloaders')

        def val_dataloader(self):
            dl1 = torch.utils.data.DataLoader(RandomDatasetA(32, 64), batch_size=11)
            dl2 = torch.utils.data.DataLoader(RandomDatasetB(32, 64), batch_size=11)
            return dl1, dl2

        def configure_optimizers(self):
            optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
            optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
            return optimizer, optimizer_2

    model = TestModel()
    model.validation_epoch_end = None

    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=2,
        limit_val_batches=2,
        max_epochs=1,
        row_log_interval=1,
        weights_summary=None,
    )

    trainer.fit(model)
    assert model.opt_0_seen
    assert model.opt_1_seen

@@ -0,0 +1,50 @@
from pytorch_lightning import Trainer
from tests.base.boring_model import BoringModel
import torch

def test_multiple_optimizers(tmpdir):
    """
    Tests training with multiple optimizers returned from configure_optimizers
    """
    class TestModel(BoringModel):
        def on_train_epoch_start(self) -> None:
            self.opt_0_seen = False
            self.opt_1_seen = False

        def training_step(self, batch, batch_idx, optimizer_idx):
            if optimizer_idx == 0:
                self.opt_0_seen = True
            elif optimizer_idx == 1:
                self.opt_1_seen = True
            else:
                raise Exception('should only have two optimizers')

            self.training_step_called = True
            loss = self.step(batch[0])
            return loss

        def training_epoch_end(self, outputs) -> None:
            # outputs should be an array with an entry per optimizer
            assert len(outputs) == 2

        def configure_optimizers(self):
            optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
            optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
            return optimizer, optimizer_2

    model = TestModel()
    model.val_dataloader = None

    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=2,
        limit_val_batches=2,
        max_epochs=1,
        row_log_interval=1,
        weights_summary=None,
    )

    trainer.fit(model)
    assert model.opt_0_seen
    assert model.opt_1_seen
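For intuition on the `len(outputs) == 2` assertion in `training_epoch_end`: per the tests' own comment, with two optimizers the epoch-level outputs arrive grouped with one entry per optimizer. A rough standalone simulation of that grouping (a sketch of the assumed call pattern, not Lightning internals):

def run_epoch(training_step, batches, num_optimizers):
    # assumed pattern: training_step runs once per (batch, optimizer) pair and
    # the results are grouped per optimizer before training_epoch_end sees them
    outputs = [[] for _ in range(num_optimizers)]
    for batch_idx, batch in enumerate(batches):
        for optimizer_idx in range(num_optimizers):
            outputs[optimizer_idx].append(training_step(batch, batch_idx, optimizer_idx))
    return outputs

outputs = run_epoch(lambda batch, batch_idx, optimizer_idx: {'loss': batch}, batches=[0.1, 0.2], num_optimizers=2)
assert len(outputs) == 2     # one entry per optimizer, as the tests assert
assert len(outputs[0]) == 2  # one training_step result per batch within each entry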