From 2a2f303ae91a4a17b1cd8127e5b811e38cb2d978 Mon Sep 17 00:00:00 2001
From: Jirka Borovec
Date: Tue, 5 May 2020 18:31:15 +0200
Subject: [PATCH] Tests: refactor trainer dataloaders (#1690)

* refactor default model

* drop redundant seeds

* refactor dataloaders tests

* fix multiple

* fix conf

* flake8

* Apply suggestions from code review

Co-authored-by: William Falcon
Co-authored-by: William Falcon
---
 pytorch_lightning/trainer/data_loading.py  |  20 +-
 tests/base/eval_model_test_dataloaders.py  |   8 +
 tests/base/eval_model_test_epoch_ends.py   |  36 +++
 tests/base/eval_model_test_steps.py        |   4 +
 tests/base/eval_model_train_dataloaders.py |  11 +
 tests/base/eval_model_utils.py             |  22 ++
 tests/base/eval_model_valid_dataloaders.py |   5 +
 tests/trainer/test_dataloaders.py          | 309 +++++++--------------
 8 files changed, 195 insertions(+), 220 deletions(-)

diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py
index b3e15024c4..00b37dfa86 100644
--- a/pytorch_lightning/trainer/data_loading.py
+++ b/pytorch_lightning/trainer/data_loading.py
@@ -46,8 +46,8 @@ def _has_len(dataloader: DataLoader) -> bool:
     try:
         # try getting the length
         if len(dataloader) == 0:
-            raise ValueError('Dataloader returned 0 length. Please make sure'
-                             ' that your Dataloader atleast returns 1 batch')
+            raise ValueError('`Dataloader` returned 0 length.'
+                             ' Please make sure that your Dataloader at least returns 1 batch')
         return True
     except TypeError:
         return False
@@ -186,10 +186,10 @@ class TrainerDataLoadingMixin(ABC):
                 self.val_check_batch = float('inf')
             else:
                 raise MisconfigurationException(
-                    'When using an infinite DataLoader (e.g. with an IterableDataset or when '
-                    'DataLoader does not implement `__len__`) for `train_dataloader`, '
-                    '`Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies '
-                    'checking validation every k training batches.')
+                    'When using an infinite DataLoader (e.g. with an IterableDataset'
+                    ' or when DataLoader does not implement `__len__`) for `train_dataloader`,'
+                    ' `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies'
+                    ' checking validation every k training batches.')
         else:
             self._percent_range_check('val_check_interval')
 
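(Aside: a minimal, self-contained sketch of the length probe these error messages belong to — not the patched module itself, just an illustration of why loaders backed by an `IterableDataset` count as "infinite": calling `len()` on such a DataLoader raises `TypeError`.)

    # illustrative sketch only; mirrors the _has_len check in the hunk above
    from torch.utils.data import DataLoader, IterableDataset

    class Stream(IterableDataset):
        def __iter__(self):
            while True:
                yield 0

    def has_len(dataloader: DataLoader) -> bool:
        try:
            # DataLoader.__len__ raises TypeError for IterableDataset loaders
            return len(dataloader) > 0
        except TypeError:
            return False

    assert has_len(DataLoader(range(10)))     # finite, map-style dataset
    assert not has_len(DataLoader(Stream()))  # "infinite" stream
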
@@ -240,9 +240,9 @@ class TrainerDataLoadingMixin(ABC):
                 num_batches = int(num_batches * percent_check)
             elif percent_check not in (0.0, 1.0):
                 raise MisconfigurationException(
-                    'When using an infinite DataLoader (e.g. with an IterableDataset or when '
-                    f'DataLoader does not implement `__len__`) for `{mode}_dataloader`, '
-                    f'`Trainer({mode}_percent_check)` must be `0.0` or `1.0`.')
+                    'When using an infinite DataLoader (e.g. with an IterableDataset'
+                    f' or when DataLoader does not implement `__len__`) for `{mode}_dataloader`,'
+                    f' `Trainer({mode}_percent_check)` must be `0.0` or `1.0`.')
 
         return num_batches, dataloaders
 
     def reset_val_dataloader(self, model: LightningModule) -> None:
@@ -252,7 +252,7 @@ class TrainerDataLoadingMixin(ABC):
             model: The current `LightningModule`
         """
         if self.is_overriden('validation_step'):
-            self.num_val_batches, self.val_dataloaders =\
+            self.num_val_batches, self.val_dataloaders = \
                 self._reset_eval_dataloader(model, 'val')
 
     def reset_test_dataloader(self, model) -> None:
diff --git a/tests/base/eval_model_test_dataloaders.py b/tests/base/eval_model_test_dataloaders.py
index 158b398545..fdab56994a 100644
--- a/tests/base/eval_model_test_dataloaders.py
+++ b/tests/base/eval_model_test_dataloaders.py
@@ -1,5 +1,7 @@
 from abc import ABC, abstractmethod
 
+from tests.base.eval_model_utils import CustomInfDataloader
+
 
 class TestDataloaderVariations(ABC):
 
@@ -10,5 +12,11 @@ class TestDataloaderVariations(ABC):
     def test_dataloader(self):
         return self.dataloader(train=False)
 
+    def test_dataloader__infinite(self):
+        return CustomInfDataloader(self.dataloader(train=False))
+
     def test_dataloader__empty(self):
         return None
+
+    def test_dataloader__multiple(self):
+        return [self.dataloader(train=False), self.dataloader(train=False)]
diff --git a/tests/base/eval_model_test_epoch_ends.py b/tests/base/eval_model_test_epoch_ends.py
index 5279e6a9fc..fa3c3f7f4a 100644
--- a/tests/base/eval_model_test_epoch_ends.py
+++ b/tests/base/eval_model_test_epoch_ends.py
@@ -37,3 +37,39 @@ class TestEpochEndVariations(ABC):
         metrics_dict = {'test_loss': test_loss_mean.item(), 'test_acc': test_acc_mean.item()}
         result = {'progress_bar': metrics_dict, 'log': metrics_dict}
         return result
+
+    def test_epoch_end__multiple_dataloaders(self, outputs):
+        """
+        Called at the end of test to aggregate outputs
+        :param outputs: list of individual outputs of each test step
+        :return:
+        """
+        # if returned a scalar from test_step, outputs is a list of tensor scalars
+        # we return just the average in this case (if we want)
+        # return torch.stack(outputs).mean()
+        test_loss_mean = 0
+        test_acc_mean = 0
+        i = 0
+        for dl_output in outputs:
+            for output in dl_output:
+                test_loss = output['test_loss']
+
+                # reduce manually when using dp
+                if self.trainer.use_dp:
+                    test_loss = torch.mean(test_loss)
+                test_loss_mean += test_loss
+
+                # reduce manually when using dp
+                test_acc = output['test_acc']
+                if self.trainer.use_dp:
+                    test_acc = torch.mean(test_acc)
+
+                test_acc_mean += test_acc
+                i += 1
+
+        test_loss_mean /= i
+        test_acc_mean /= i
+
+        tqdm_dict = {'test_loss': test_loss_mean.item(), 'test_acc': test_acc_mean.item()}
+        result = {'progress_bar': tqdm_dict}
+        return result
diff --git a/tests/base/eval_model_test_steps.py b/tests/base/eval_model_test_steps.py
index b4c80cff06..bf57c2815b 100644
--- a/tests/base/eval_model_test_steps.py
+++ b/tests/base/eval_model_test_steps.py
@@ -8,6 +8,7 @@ class TestStepVariations(ABC):
     """
     Houses all variations of test steps
     """
+
     def test_step(self, batch, batch_idx, *args, **kwargs):
         """
         Default, baseline test_step
@@ -87,3 +88,6 @@ class TestStepVariations(ABC):
             f'test_acc_{dataloader_idx}': test_acc,
         })
         return output
+
+    def test_step__empty(self, batch, batch_idx, *args, **kwargs):
+        return {}
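(Aside: the helper modules above all follow one pattern — the canonical hook keeps its standard name and each variation gets a double-underscore suffix; a test opts in by rebinding the hook, which works because attribute lookup on an instance yields a bound method. A toy sketch, with illustrative names that are not from the patch:)

    class Template:
        def dataloader(self, train: bool):
            return [1, 2, 3]  # stand-in for a real DataLoader

        def test_dataloader(self):
            return self.dataloader(train=False)

        def test_dataloader__multiple(self):
            return [self.dataloader(train=False), self.dataloader(train=False)]

    model = Template()
    # rebind the canonical hook to a variation; it stays bound to `model`
    model.test_dataloader = model.test_dataloader__multiple
    assert len(model.test_dataloader()) == 2
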
diff --git a/tests/base/eval_model_train_dataloaders.py b/tests/base/eval_model_train_dataloaders.py
index 3d547a8363..ded46de3d6 100644
--- a/tests/base/eval_model_train_dataloaders.py
+++ b/tests/base/eval_model_train_dataloaders.py
@@ -1,5 +1,7 @@
 from abc import ABC, abstractmethod
 
+from tests.base.eval_model_utils import CustomInfDataloader
+
 
 class TrainDataloaderVariations(ABC):
 
@@ -9,3 +11,12 @@ class TrainDataloaderVariations(ABC):
 
     def train_dataloader(self):
         return self.dataloader(train=True)
+
+    def train_dataloader__infinite(self):
+        return CustomInfDataloader(self.dataloader(train=True))
+
+    def train_dataloader__zero_length(self):
+        dataloader = self.dataloader(train=True)
+        dataloader.dataset.data = dataloader.dataset.data[:0]
+        dataloader.dataset.targets = dataloader.dataset.targets[:0]
+        return dataloader
diff --git a/tests/base/eval_model_utils.py b/tests/base/eval_model_utils.py
index e1a40f95b8..d3eed3cb8d 100644
--- a/tests/base/eval_model_utils.py
+++ b/tests/base/eval_model_utils.py
@@ -26,3 +26,25 @@ class ModelTemplateUtils:
         else:  # if it is 2level deep -> per dataloader and per batch
             val = sum(out[name] for out in output) / len(output)
         return val
+
+
+class CustomInfDataloader:
+
+    def __init__(self, dataloader):
+        self.dataloader = dataloader
+        self.iter = iter(dataloader)
+        self.count = 0
+
+    def __iter__(self):
+        self.count = 0
+        return self
+
+    def __next__(self):
+        if self.count >= 50:
+            raise StopIteration
+        self.count = self.count + 1
+        try:
+            return next(self.iter)
+        except StopIteration:
+            self.iter = iter(self.dataloader)
+            return next(self.iter)
diff --git a/tests/base/eval_model_valid_dataloaders.py b/tests/base/eval_model_valid_dataloaders.py
index 72b5afccee..2b760e1308 100644
--- a/tests/base/eval_model_valid_dataloaders.py
+++ b/tests/base/eval_model_valid_dataloaders.py
@@ -1,5 +1,7 @@
 from abc import ABC, abstractmethod
 
+from tests.base.eval_model_utils import CustomInfDataloader
+
 
 class ValDataloaderVariations(ABC):
 
@@ -13,3 +15,6 @@ class ValDataloaderVariations(ABC):
 
     def val_dataloader__multiple(self):
         return [self.dataloader(train=False), self.dataloader(train=False)]
+
+    def val_dataloader__infinite(self):
+        return CustomInfDataloader(self.dataloader(train=False))
diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py
index d0a6dd869a..c249b834ef 100644
--- a/tests/trainer/test_dataloaders.py
+++ b/tests/trainer/test_dataloaders.py
@@ -8,22 +8,7 @@ from torch.utils.data.dataset import Subset
 import tests.base.utils as tutils
 from pytorch_lightning import Trainer
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.base import (
-    TestModelBase,
-    LightningTestModel,
-    LightEmptyTestStep,
-    LightValidationMultipleDataloadersMixin,
-    LightTestMultipleDataloadersMixin,
-    LightTestFitSingleTestDataloadersMixin,
-    LightTestFitMultipleTestDataloadersMixin,
-    LightValStepFitMultipleDataloadersMixin,
-    LightValStepFitSingleDataloaderMixin,
-    LightTrainDataloader,
-    LightInfTrainDataloader,
-    LightInfValDataloader,
-    LightInfTestDataloader,
-    LightZeroLenDataloader
-)
+from tests.base import EvalModelTemplate
 
 
 @pytest.mark.parametrize("dataloader_options", [
@@ -34,14 +19,7 @@ from tests.base import (
 ])
 def test_dataloader_config_errors(tmpdir, dataloader_options):
 
-    class CurrentTestModel(
-        LightTrainDataloader,
-        TestModelBase,
-    ):
-        pass
-
-    hparams = tutils.get_default_hparams()
-    model = CurrentTestModel(hparams)
+    model = EvalModelTemplate(tutils.get_default_hparams())
 
     # fit model
     trainer = Trainer(
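(Aside: the `CustomInfDataloader` wrapper added above recycles a finite loader forever but caps each iteration pass at 50 batches so tests terminate, and it deliberately defines no `__len__`, which is exactly what the trainer's length probe keys on. A hedged usage sketch, not part of the patch:)

    # illustrative only; assumes CustomInfDataloader from tests/base/eval_model_utils.py
    from torch.utils.data import DataLoader

    loader = CustomInfDataloader(DataLoader(range(8)))
    assert sum(1 for _ in loader) == 50   # recycles the 8-item loader, stops at 50

    try:
        len(loader)                       # no __len__ -> treated as infinite
    except TypeError:
        pass
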
@@ -57,15 +35,9 @@ def test_multiple_val_dataloader(tmpdir):
     """Verify multiple val_dataloader."""
 
-    class CurrentTestModel(
-        LightTrainDataloader,
-        LightValidationMultipleDataloadersMixin,
-        TestModelBase,
-    ):
-        pass
-
-    hparams = tutils.get_default_hparams()
-    model = CurrentTestModel(hparams)
+    model = EvalModelTemplate(tutils.get_default_hparams())
+    model.val_dataloader = model.val_dataloader__multiple
+    model.validation_step = model.validation_step__multiple_dataloaders
 
     # fit model
     trainer = Trainer(
@@ -91,16 +63,9 @@ def test_multiple_test_dataloader(tmpdir):
     """Verify multiple test_dataloader."""
 
-    class CurrentTestModel(
-        LightTrainDataloader,
-        LightTestMultipleDataloadersMixin,
-        LightEmptyTestStep,
-        TestModelBase,
-    ):
-        pass
-
-    hparams = tutils.get_default_hparams()
-    model = CurrentTestModel(hparams)
+    model = EvalModelTemplate(tutils.get_default_hparams())
+    model.test_dataloader = model.test_dataloader__multiple
+    model.test_step = model.test_step__multiple_dataloaders
 
     # fit model
     trainer = Trainer(
@@ -127,20 +92,16 @@ def test_train_dataloader_passed_to_fit(tmpdir):
     """Verify that train dataloader can be passed to fit """
 
-    class CurrentTestModel(LightTrainDataloader, TestModelBase):
-        pass
-
-    hparams = tutils.get_default_hparams()
-    # only train passed to fit
-    model = CurrentTestModel(hparams)
+    model = EvalModelTemplate(tutils.get_default_hparams())
     trainer = Trainer(
         default_root_dir=tmpdir,
         max_epochs=1,
         val_percent_check=0.1,
         train_percent_check=0.2
     )
-    result = trainer.fit(model, train_dataloader=model._dataloader(train=True))
+    fit_options = dict(train_dataloader=model.dataloader(train=True))
+    result = trainer.fit(model, **fit_options)
 
     assert result == 1
@@ -148,26 +109,18 @@ def test_train_val_dataloaders_passed_to_fit(tmpdir):
     """ Verify that train & val dataloader can be passed to fit """
 
-    class CurrentTestModel(
-        LightTrainDataloader,
-        LightValStepFitSingleDataloaderMixin,
-        TestModelBase,
-    ):
-        pass
-
-    hparams = tutils.get_default_hparams()
-    # train, val passed to fit
-    model = CurrentTestModel(hparams)
+    model = EvalModelTemplate(tutils.get_default_hparams())
     trainer = Trainer(
         default_root_dir=tmpdir,
         max_epochs=1,
         val_percent_check=0.1,
         train_percent_check=0.2
     )
-    result = trainer.fit(model,
-                         train_dataloader=model._dataloader(train=True),
-                         val_dataloaders=model._dataloader(train=False))
+    fit_options = dict(train_dataloader=model.dataloader(train=True),
+                       val_dataloaders=model.dataloader(train=False))
+
+    result = trainer.fit(model, **fit_options)
     assert result == 1
     assert len(trainer.val_dataloaders) == 1, \
         f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}'
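(Aside: the rewritten tests converge on one calling convention — gather the loaders into `fit_options` / `test_options` dicts and splat them into `fit()` / `test()`, so single- and multi-loader cases differ only in the dict values. A condensed sketch of that shape, assuming the template helpers used throughout this patch:)

    model = EvalModelTemplate(tutils.get_default_hparams())

    fit_options = dict(train_dataloader=model.dataloader(train=True),
                       val_dataloaders=[model.dataloader(train=False),
                                        model.dataloader(train=False)])
    test_options = dict(test_dataloaders=model.dataloader(train=False))

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
    assert trainer.fit(model, **fit_options) == 1
    trainer.test(**test_options)
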
@@ -176,31 +129,21 @@ def test_all_dataloaders_passed_to_fit(tmpdir):
     """Verify train, val & test dataloader(s) can be passed to fit and test method"""
 
-    class CurrentTestModel(
-        LightTrainDataloader,
-        LightValStepFitSingleDataloaderMixin,
-        LightTestFitSingleTestDataloadersMixin,
-        LightEmptyTestStep,
-        TestModelBase,
-    ):
-        pass
-
-    hparams = tutils.get_default_hparams()
+    model = EvalModelTemplate(tutils.get_default_hparams())
 
     # train, val and test passed to fit
-    model = CurrentTestModel(hparams)
     trainer = Trainer(
         default_root_dir=tmpdir,
         max_epochs=1,
         val_percent_check=0.1,
         train_percent_check=0.2
     )
+    fit_options = dict(train_dataloader=model.dataloader(train=True),
+                       val_dataloaders=model.dataloader(train=False))
+    test_options = dict(test_dataloaders=model.dataloader(train=False))
 
-    result = trainer.fit(model,
-                         train_dataloader=model._dataloader(train=True),
-                         val_dataloaders=model._dataloader(train=False))
-
-    trainer.test(test_dataloaders=model._dataloader(train=False))
+    result = trainer.fit(model, **fit_options)
+    trainer.test(**test_options)
 
     assert result == 1
     assert len(trainer.val_dataloaders) == 1, \
@@ -212,32 +155,25 @@ def test_multiple_dataloaders_passed_to_fit(tmpdir):
     """Verify that multiple val & test dataloaders can be passed to fit."""
 
-    class CurrentTestModel(
-        LightningTestModel,
-        LightValStepFitMultipleDataloadersMixin,
-        LightTestFitMultipleTestDataloadersMixin,
-    ):
-        pass
-
-    hparams = tutils.get_default_hparams()
+    model = EvalModelTemplate(tutils.get_default_hparams())
+    model.validation_step = model.validation_step__multiple_dataloaders
+    model.test_step = model.test_step__multiple_dataloaders
 
     # train, multiple val and multiple test passed to fit
-    model = CurrentTestModel(hparams)
     trainer = Trainer(
         default_root_dir=tmpdir,
         max_epochs=1,
         val_percent_check=0.1,
         train_percent_check=0.2
     )
+    fit_options = dict(train_dataloader=model.dataloader(train=True),
+                       val_dataloaders=[model.dataloader(train=False),
+                                        model.dataloader(train=False)])
+    test_options = dict(test_dataloaders=[model.dataloader(train=False),
+                                          model.dataloader(train=False)])
 
-    results = trainer.fit(
-        model,
-        train_dataloader=model._dataloader(train=True),
-        val_dataloaders=[model._dataloader(train=False), model._dataloader(train=False)],
-    )
-    assert results
-
-    trainer.test(test_dataloaders=[model._dataloader(train=False), model._dataloader(train=False)])
+    trainer.fit(model, **fit_options)
+    trainer.test(**test_options)
 
     assert len(trainer.val_dataloaders) == 2, \
         f'Multiple `val_dataloaders` not initiated properly, got {trainer.val_dataloaders}'
@@ -248,16 +184,7 @@ def test_mixing_of_dataloader_options(tmpdir):
     """Verify that dataloaders can be passed to fit"""
 
-    class CurrentTestModel(
-        LightTrainDataloader,
-        LightValStepFitSingleDataloaderMixin,
-        LightTestFitSingleTestDataloadersMixin,
-        TestModelBase,
-    ):
-        pass
-
-    hparams = tutils.get_default_hparams()
-    model = CurrentTestModel(hparams)
+    model = EvalModelTemplate(tutils.get_default_hparams())
 
     trainer_options = dict(
         default_root_dir=tmpdir,
@@ -268,17 +195,14 @@ def test_mixing_of_dataloader_options(tmpdir):
 
     # fit model
     trainer = Trainer(**trainer_options)
-    fit_options = dict(val_dataloaders=model._dataloader(train=False))
-    results = trainer.fit(model, **fit_options)
+    results = trainer.fit(model, val_dataloaders=model.dataloader(train=False))
     assert results
 
     # fit model
     trainer = Trainer(**trainer_options)
-    fit_options = dict(val_dataloaders=model._dataloader(train=False))
-    test_options = dict(test_dataloaders=model._dataloader(train=False))
-
-    _ = trainer.fit(model, **fit_options)
-    trainer.test(**test_options)
+    results = trainer.fit(model, val_dataloaders=model.dataloader(train=False))
+    assert results
+    trainer.test(test_dataloaders=model.dataloader(train=False))
 
     assert len(trainer.val_dataloaders) == 1, \
         f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}'
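(Aside: the monolithic inf-loader tests are split below into dedicated `*_error` tests plus parametrized happy-path tests. The error half shares one pytest shape; a sketch that mirrors, rather than replaces, the tests that follow — `assert_rejects_infinite_loader` is a hypothetical helper, not from the patch:)

    import pytest

    def assert_rejects_infinite_loader(trainer, model):
        # hypothetical helper: the trainer raises when a setting needs len(dataloader)
        with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
            trainer.fit(model)
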
@@ -286,72 +210,68 @@ def test_mixing_of_dataloader_options(tmpdir):
         f'`test_dataloaders` not initiated properly, got {trainer.test_dataloaders}'
 
 
-def test_inf_train_dataloader(tmpdir):
+def test_train_inf_dataloader_error(tmpdir):
+    """Test inf train data loader (e.g. IterableDataset)"""
+    model = EvalModelTemplate(tutils.get_default_hparams())
+    model.train_dataloader = model.train_dataloader__infinite
+
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, val_check_interval=0.5)
+
+    with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
+        trainer.fit(model)
+
+
+def test_val_inf_dataloader_error(tmpdir):
+    """Test inf val data loader (e.g. IterableDataset)"""
+    model = EvalModelTemplate(tutils.get_default_hparams())
+    model.val_dataloader = model.val_dataloader__infinite
+
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, val_percent_check=0.5)
+
+    with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
+        trainer.fit(model)
+
+
+def test_test_inf_dataloader_error(tmpdir):
+    """Test inf test data loader (e.g. IterableDataset)"""
+    model = EvalModelTemplate(tutils.get_default_hparams())
+    model.test_dataloader = model.test_dataloader__infinite
+
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, test_percent_check=0.5)
+
+    with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
+        trainer.test(model)
+
+
+@pytest.mark.parametrize('check_interval', [50, 1.0])
+def test_inf_train_dataloader(tmpdir, check_interval):
     """Test inf train data loader (e.g. IterableDataset)"""
-    class CurrentTestModel(
-        LightInfTrainDataloader,
-        LightningTestModel
-    ):
-        pass
-
-    hparams = tutils.get_default_hparams()
-    model = CurrentTestModel(hparams)
-
-    # fit model
-    with pytest.raises(MisconfigurationException):
-        trainer = Trainer(
-            default_root_dir=tmpdir,
-            max_epochs=1,
-            val_check_interval=0.5
-        )
-        trainer.fit(model)
+    model = EvalModelTemplate(tutils.get_default_hparams())
+    model.train_dataloader = model.train_dataloader__infinite
 
     trainer = Trainer(
         default_root_dir=tmpdir,
         max_epochs=1,
-        val_check_interval=50
+        val_check_interval=check_interval,
     )
     result = trainer.fit(model)
-
-    # verify training completed
-    assert result == 1
-
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_epochs=1
-    )
-    result = trainer.fit(model)
-
     # verify training completed
     assert result == 1
 
 
-def test_inf_val_dataloader(tmpdir):
+@pytest.mark.parametrize('check_interval', [1.0])
+def test_inf_val_dataloader(tmpdir, check_interval):
     """Test inf val data loader (e.g. IterableDataset)"""
-    class CurrentTestModel(
-        LightInfValDataloader,
-        LightningTestModel
-    ):
-        pass
-
-    hparams = tutils.get_default_hparams()
-    model = CurrentTestModel(hparams)
-
-    # fit model
-    with pytest.raises(MisconfigurationException):
-        trainer = Trainer(
-            default_root_dir=tmpdir,
-            max_epochs=1,
-            val_percent_check=0.5
-        )
-        trainer.fit(model)
+    model = EvalModelTemplate(tutils.get_default_hparams())
+    model.val_dataloader = model.val_dataloader__infinite
 
     # logger file to get meta
     trainer = Trainer(
         default_root_dir=tmpdir,
-        max_epochs=1
+        max_epochs=1,
+        val_check_interval=check_interval,
    )
     result = trainer.fit(model)
 
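(Aside: the happy-path tests above and below lean on `pytest.mark.parametrize` instead of duplicated function bodies — each listed value is collected as its own test case. A toy illustration, not from the patch:)

    import pytest

    # runs twice: with 50 (check every 50 batches) and 1.0 (once per epoch)
    @pytest.mark.parametrize('check_interval', [50, 1.0])
    def test_check_interval_is_valid(check_interval):
        assert isinstance(check_interval, int) or 0.0 < check_interval <= 1.0
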
@@ -359,35 +279,20 @@ def test_inf_val_dataloader(tmpdir):
 
     # verify training completed
     assert result == 1
 
 
-def test_inf_test_dataloader(tmpdir):
+@pytest.mark.parametrize('check_interval', [50, 1.0])
+def test_inf_test_dataloader(tmpdir, check_interval):
+    """Test inf test data loader (e.g. IterableDataset)"""
-    class CurrentTestModel(
-        LightInfTestDataloader,
-        LightningTestModel,
-        LightTestFitSingleTestDataloadersMixin
-    ):
-        pass
-
-    hparams = tutils.get_default_hparams()
-    model = CurrentTestModel(hparams)
-
-    # fit model
-    with pytest.raises(MisconfigurationException):
-        trainer = Trainer(
-            default_root_dir=tmpdir,
-            max_epochs=1,
-            test_percent_check=0.5
-        )
-        trainer.test(model)
+    model = EvalModelTemplate(tutils.get_default_hparams())
+    model.test_dataloader = model.test_dataloader__infinite
 
     # logger file to get meta
     trainer = Trainer(
         default_root_dir=tmpdir,
-        max_epochs=1
+        max_epochs=1,
+        val_check_interval=check_interval,
     )
     result = trainer.fit(model)
-    trainer.test(model)
 
     # verify training completed
     assert result == 1
@@ -396,14 +301,8 @@ def test_error_on_zero_len_dataloader(tmpdir):
     """ Test that error is raised if a zero-length dataloader is defined """
 
-    class CurrentTestModel(
-        LightZeroLenDataloader,
-        LightningTestModel
-    ):
-        pass
-
-    hparams = tutils.get_default_hparams()
-    model = CurrentTestModel(hparams)
+    model = EvalModelTemplate(tutils.get_default_hparams())
+    model.train_dataloader = model.train_dataloader__zero_length
 
     # fit model
     with pytest.raises(ValueError):
@@ -419,29 +318,22 @@ def test_warning_with_few_workers(tmpdir):
     """ Test that error is raised if dataloader with only a few workers is used """
 
-    class CurrentTestModel(
-        LightTrainDataloader,
-        LightValStepFitSingleDataloaderMixin,
-        LightTestFitSingleTestDataloadersMixin,
-        LightEmptyTestStep,
-        TestModelBase,
-    ):
-        pass
+    model = EvalModelTemplate(tutils.get_default_hparams())
 
-    hparams = tutils.get_default_hparams()
-    model = CurrentTestModel(hparams)
-
-    fit_options = dict(train_dataloader=model._dataloader(train=True),
-                       val_dataloaders=model._dataloader(train=False))
-    test_options = dict(test_dataloaders=model._dataloader(train=False))
-
-    trainer = Trainer(
+    # logger file to get meta
+    trainer_options = dict(
         default_root_dir=tmpdir,
         max_epochs=1,
         val_percent_check=0.1,
         train_percent_check=0.2
     )
 
+    fit_options = dict(train_dataloader=model.dataloader(train=True),
+                       val_dataloaders=model.dataloader(train=False))
+    test_options = dict(test_dataloaders=model.dataloader(train=False))
+
+    trainer = Trainer(**trainer_options)
+
     # fit model
     with pytest.warns(UserWarning, match='train'):
         trainer.fit(model, **fit_options)
@@ -491,10 +383,7 @@ def test_batch_size_smaller_than_num_gpus():
     num_gpus = 3
     batch_size = 3
 
-    class CurrentTestModel(
-        LightTrainDataloader,
-        TestModelBase,
-    ):
+    class CurrentTestModel(EvalModelTemplate):
 
         def __init__(self, *args, **kwargs):
             super().__init__(*args, **kwargs)
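(Aside: the final hunk shows the other half of the refactor — when rebinding a method is not enough, subclass `EvalModelTemplate` and override a single hook instead of composing one-off mixins. A hedged sketch of that shape; the class name here is illustrative, not from the patch:)

    # illustrative only; assumes EvalModelTemplate from tests/base
    class GpuPaddingModel(EvalModelTemplate):

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)

        def train_dataloader(self):
            # override just the hook under test; everything else is inherited
            return self.dataloader(train=True)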