From d856989120b078581f3f694fd7a1c036703f67a9 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Thu, 27 Feb 2020 02:31:40 +0100 Subject: [PATCH] split trainer tests (#956) * split trainer tests * Apply suggestions from code review * format string * add CI timeout --- .github/workflows/ci-testing.yml | 2 + tests/models/utils.py | 4 +- tests/{test_logging.py => test_loggers.py} | 0 tests/trainer/__init__.py | 0 tests/trainer/test_dataloaders.py | 324 +++++++++++++++++++++ tests/{ => trainer}/test_trainer.py | 316 +------------------- 6 files changed, 330 insertions(+), 316 deletions(-) rename tests/{test_logging.py => test_loggers.py} (100%) create mode 100644 tests/trainer/__init__.py create mode 100644 tests/trainer/test_dataloaders.py rename tests/{ => trainer}/test_trainer.py (70%) diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index 0ab91b0793..a081a8a060 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -14,6 +14,8 @@ jobs: python-version: [3.6, 3.7] requires: ['minimal', 'latest'] + # https://stackoverflow.com/a/59076067/4521646 + timeout-minutes: 20 steps: - uses: actions/checkout@v1 - name: Set up Python ${{ matrix.python-version }} diff --git a/tests/models/utils.py b/tests/models/utils.py index 30a11a7aa3..834e556783 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -188,13 +188,13 @@ def run_prediction(dataloader, trained_model, dp=False, min_acc=0.50): acc = torch.tensor(acc) acc = acc.item() - assert acc >= min_acc, f'this model is expected to get > {min_acc} in test set (it got {acc})' + assert acc >= min_acc, f"This model is expected to get > {min_acc} in test set (it got {acc})" def assert_ok_model_acc(trainer, key='test_acc', thr=0.4): # this model should get 0.80+ acc acc = trainer.training_tqdm_dict[key] - assert acc > thr, f'Model failed to get expected {thr} accuracy. {key} = {acc}' + assert acc > thr, f"Model failed to get expected {thr} accuracy. {key} = {acc}" def can_run_gpu_test(): diff --git a/tests/test_logging.py b/tests/test_loggers.py similarity index 100% rename from tests/test_logging.py rename to tests/test_loggers.py diff --git a/tests/trainer/__init__.py b/tests/trainer/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py new file mode 100644 index 0000000000..95801da253 --- /dev/null +++ b/tests/trainer/test_dataloaders.py @@ -0,0 +1,324 @@ +import pytest + +import tests.models.utils as tutils +from pytorch_lightning import Trainer +from tests.models import ( + TestModelBase, + LightningTestModel, + LightEmptyTestStep, + LightValidationMultipleDataloadersMixin, + LightTestMultipleDataloadersMixin, + LightTestFitSingleTestDataloadersMixin, + LightTestFitMultipleTestDataloadersMixin, + LightValStepFitMultipleDataloadersMixin, + LightValStepFitSingleDataloaderMixin, + LightTrainDataloader, +) +from pytorch_lightning.utilities.debugging import MisconfigurationException + + +def test_multiple_val_dataloader(tmpdir): + """Verify multiple val_dataloader.""" + tutils.reset_seed() + + class CurrentTestModel( + LightTrainDataloader, + LightValidationMultipleDataloadersMixin, + TestModelBase, + ): + pass + + hparams = tutils.get_hparams() + model = CurrentTestModel(hparams) + + # logger file to get meta + trainer_options = dict( + default_save_path=tmpdir, + max_epochs=1, + val_percent_check=0.1, + train_percent_check=1.0, + ) + + # fit model + trainer = Trainer(**trainer_options) + result = trainer.fit(model) + + # verify training completed + assert result == 1 + + # verify there are 2 val loaders + assert len(trainer.val_dataloaders) == 2, \ + 'Multiple val_dataloaders not initiated properly' + + # make sure predictions are good for each val set + for dataloader in trainer.val_dataloaders: + tutils.run_prediction(dataloader, trainer.model) + + +def test_multiple_test_dataloader(tmpdir): + """Verify multiple test_dataloader.""" + tutils.reset_seed() + + class CurrentTestModel( + LightTrainDataloader, + LightTestMultipleDataloadersMixin, + LightEmptyTestStep, + TestModelBase, + ): + pass + + hparams = tutils.get_hparams() + model = CurrentTestModel(hparams) + + # logger file to get meta + trainer_options = dict( + default_save_path=tmpdir, + max_epochs=1, + val_percent_check=0.1, + train_percent_check=0.2 + ) + + # fit model + trainer = Trainer(**trainer_options) + trainer.fit(model) + trainer.test() + + # verify there are 2 val loaders + assert len(trainer.test_dataloaders) == 2, \ + 'Multiple test_dataloaders not initiated properly' + + # make sure predictions are good for each test set + for dataloader in trainer.test_dataloaders: + tutils.run_prediction(dataloader, trainer.model) + + # run the test method + trainer.test() + + +def test_train_dataloaders_passed_to_fit(tmpdir): + """ Verify that train dataloader can be passed to fit """ + tutils.reset_seed() + + class CurrentTestModel(LightTrainDataloader, TestModelBase): + pass + + hparams = tutils.get_hparams() + + # logger file to get meta + trainer_options = dict( + default_save_path=tmpdir, + max_epochs=1, + val_percent_check=0.1, + train_percent_check=0.2 + ) + + # only train passed to fit + model = CurrentTestModel(hparams) + trainer = Trainer(**trainer_options) + fit_options = dict(train_dataloader=model._dataloader(train=True)) + results = trainer.fit(model, **fit_options) + + +def test_train_val_dataloaders_passed_to_fit(tmpdir): + """ Verify that train & val dataloader can be passed to fit """ + tutils.reset_seed() + + class CurrentTestModel( + LightTrainDataloader, + LightValStepFitSingleDataloaderMixin, + TestModelBase, + ): + pass + + hparams = tutils.get_hparams() + + # logger file to get meta + trainer_options = dict( + default_save_path=tmpdir, + max_epochs=1, + val_percent_check=0.1, + train_percent_check=0.2 + ) + + # train, val passed to fit + model = CurrentTestModel(hparams) + trainer = Trainer(**trainer_options) + fit_options = dict(train_dataloader=model._dataloader(train=True), + val_dataloaders=model._dataloader(train=False)) + + results = trainer.fit(model, **fit_options) + assert len(trainer.val_dataloaders) == 1, \ + f"`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}" + + +def test_all_dataloaders_passed_to_fit(tmpdir): + """ Verify train, val & test dataloader can be passed to fit """ + tutils.reset_seed() + + class CurrentTestModel( + LightTrainDataloader, + LightValStepFitSingleDataloaderMixin, + LightTestFitSingleTestDataloadersMixin, + LightEmptyTestStep, + TestModelBase, + ): + pass + + hparams = tutils.get_hparams() + + # logger file to get meta + trainer_options = dict( + default_save_path=tmpdir, + max_epochs=1, + val_percent_check=0.1, + train_percent_check=0.2 + ) + + # train, val and test passed to fit + model = CurrentTestModel(hparams) + trainer = Trainer(**trainer_options) + fit_options = dict(train_dataloader=model._dataloader(train=True), + val_dataloaders=model._dataloader(train=False), + test_dataloaders=model._dataloader(train=False)) + + results = trainer.fit(model, **fit_options) + + trainer.test() + + assert len(trainer.val_dataloaders) == 1, \ + f"val_dataloaders` not initiated properly, got {trainer.val_dataloaders}" + assert len(trainer.test_dataloaders) == 1, \ + f"test_dataloaders` not initiated properly, got {trainer.test_dataloaders}" + + +def test_multiple_dataloaders_passed_to_fit(tmpdir): + """Verify that multiple val & test dataloaders can be passed to fit.""" + tutils.reset_seed() + + class CurrentTestModel( + LightningTestModel, + LightValStepFitMultipleDataloadersMixin, + LightTestFitMultipleTestDataloadersMixin, + ): + pass + + hparams = tutils.get_hparams() + + # logger file to get meta + trainer_options = dict( + default_save_path=tmpdir, + max_epochs=1, + val_percent_check=0.1, + train_percent_check=0.2 + ) + + # train, multiple val and multiple test passed to fit + model = CurrentTestModel(hparams) + trainer = Trainer(**trainer_options) + fit_options = dict(train_dataloader=model._dataloader(train=True), + val_dataloaders=[model._dataloader(train=False), + model._dataloader(train=False)], + test_dataloaders=[model._dataloader(train=False), + model._dataloader(train=False)]) + results = trainer.fit(model, **fit_options) + trainer.test() + + assert len(trainer.val_dataloaders) == 2, \ + f"Multiple `val_dataloaders` not initiated properly, got {trainer.val_dataloaders}" + assert len(trainer.test_dataloaders) == 2, \ + f"Multiple `test_dataloaders` not initiated properly, got {trainer.test_dataloaders}" + + +def test_mixing_of_dataloader_options(tmpdir): + """Verify that dataloaders can be passed to fit""" + tutils.reset_seed() + + class CurrentTestModel( + LightTrainDataloader, + LightValStepFitSingleDataloaderMixin, + LightTestFitSingleTestDataloadersMixin, + TestModelBase, + ): + pass + + hparams = tutils.get_hparams() + model = CurrentTestModel(hparams) + + # logger file to get meta + trainer_options = dict( + default_save_path=tmpdir, + max_epochs=1, + val_percent_check=0.1, + train_percent_check=0.2 + ) + + # fit model + trainer = Trainer(**trainer_options) + fit_options = dict(val_dataloaders=model._dataloader(train=False)) + results = trainer.fit(model, **fit_options) + + # fit model + trainer = Trainer(**trainer_options) + fit_options = dict(val_dataloaders=model._dataloader(train=False), + test_dataloaders=model._dataloader(train=False)) + _ = trainer.fit(model, **fit_options) + trainer.test() + + assert len(trainer.val_dataloaders) == 1, \ + f"`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}" + assert len(trainer.test_dataloaders) == 1, \ + f"test_dataloaders` not initiated properly, got {trainer.test_dataloaders}" + + +def test_inf_train_dataloader(tmpdir): + """Test inf train data loader (e.g. IterableDataset)""" + tutils.reset_seed() + + class CurrentTestModel(LightningTestModel): + def train_dataloader(self): + dataloader = self._dataloader(train=True) + + class CustomInfDataLoader: + def __init__(self, dataloader): + self.dataloader = dataloader + self.iter = iter(dataloader) + self.count = 0 + + def __iter__(self): + self.count = 0 + return self + + def __next__(self): + if self.count >= 5: + raise StopIteration + self.count = self.count + 1 + try: + return next(self.iter) + except StopIteration: + self.iter = iter(self.dataloader) + return next(self.iter) + + return CustomInfDataLoader(dataloader) + + hparams = tutils.get_hparams() + model = CurrentTestModel(hparams) + + # fit model + with pytest.raises(MisconfigurationException): + trainer = Trainer( + default_save_path=tmpdir, + max_epochs=1, + val_check_interval=0.5 + ) + trainer.fit(model) + + # logger file to get meta + trainer = Trainer( + default_save_path=tmpdir, + max_epochs=1, + val_check_interval=50, + ) + result = trainer.fit(model) + + # verify training completed + assert result == 1 diff --git a/tests/test_trainer.py b/tests/trainer/test_trainer.py similarity index 70% rename from tests/test_trainer.py rename to tests/trainer/test_trainer.py index 7a8881907a..b2f4a99740 100644 --- a/tests/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -16,11 +16,6 @@ from tests.models import ( LightEmptyTestStep, LightValidationStepMixin, LightValidationMultipleDataloadersMixin, - LightTestMultipleDataloadersMixin, - LightTestFitSingleTestDataloadersMixin, - LightTestFitMultipleTestDataloadersMixin, - LightValStepFitMultipleDataloadersMixin, - LightValStepFitSingleDataloaderMixin, LightTrainDataloader, LightTestDataloader, LightValidationMixin, @@ -258,7 +253,7 @@ def test_model_checkpoint_options(tmp_path): # verify correct naming for i in range(0, len(losses)): - assert f'_ckpt_epoch_{i}.ckpt' in file_lists + assert f"_ckpt_epoch_{i}.ckpt" in file_lists save_dir = tmp_path / "2" save_dir.mkdir() @@ -307,7 +302,7 @@ def test_model_checkpoint_options(tmp_path): # make sure other files don't get deleted checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=2, verbose=1) - open(f'{save_dir}/other_file.ckpt', 'a').close() + open(f"{save_dir}/other_file.ckpt", 'a').close() checkpoint_callback.save_function = mock_save_function trainer = Trainer() @@ -380,98 +375,6 @@ def test_model_freeze_unfreeze(): model.unfreeze() -def test_inf_train_dataloader(tmpdir): - """Test inf train data loader (e.g. IterableDataset)""" - tutils.reset_seed() - - class CurrentTestModel(LightningTestModel): - def train_dataloader(self): - dataloader = self._dataloader(train=True) - - class CustomInfDataLoader: - def __init__(self, dataloader): - self.dataloader = dataloader - self.iter = iter(dataloader) - self.count = 0 - - def __iter__(self): - self.count = 0 - return self - - def __next__(self): - if self.count >= 5: - raise StopIteration - self.count = self.count + 1 - try: - return next(self.iter) - except StopIteration: - self.iter = iter(self.dataloader) - return next(self.iter) - - return CustomInfDataLoader(dataloader) - - hparams = tutils.get_hparams() - model = CurrentTestModel(hparams) - - # fit model - with pytest.raises(MisconfigurationException): - trainer = Trainer( - default_save_path=tmpdir, - max_epochs=1, - val_check_interval=0.5 - ) - trainer.fit(model) - - # logger file to get meta - trainer = Trainer( - default_save_path=tmpdir, - max_epochs=1, - val_check_interval=50, - ) - result = trainer.fit(model) - - # verify training completed - assert result == 1 - - -def test_multiple_val_dataloader(tmpdir): - """Verify multiple val_dataloader.""" - tutils.reset_seed() - - class CurrentTestModel( - LightTrainDataloader, - LightValidationMultipleDataloadersMixin, - TestModelBase, - ): - pass - - hparams = tutils.get_hparams() - model = CurrentTestModel(hparams) - - # logger file to get meta - trainer_options = dict( - default_save_path=tmpdir, - max_epochs=1, - val_percent_check=0.1, - train_percent_check=1.0, - ) - - # fit model - trainer = Trainer(**trainer_options) - result = trainer.fit(model) - - # verify training completed - assert result == 1 - - # verify there are 2 val loaders - assert len(trainer.val_dataloaders) == 2, \ - 'Multiple val_dataloaders not initiated properly' - - # make sure predictions are good for each val set - for dataloader in trainer.val_dataloaders: - tutils.run_prediction(dataloader, trainer.model) - - def test_resume_from_checkpoint_epoch_restored(tmpdir): """Verify resuming from checkpoint runs the right number of epochs""" import types @@ -540,221 +443,6 @@ def test_resume_from_checkpoint_epoch_restored(tmpdir): assert state['global_step'] + next_model.num_batches_seen == training_batches * 4 -def test_multiple_test_dataloader(tmpdir): - """Verify multiple test_dataloader.""" - tutils.reset_seed() - - class CurrentTestModel( - LightTrainDataloader, - LightTestMultipleDataloadersMixin, - LightEmptyTestStep, - TestModelBase, - ): - pass - - hparams = tutils.get_hparams() - model = CurrentTestModel(hparams) - - # logger file to get meta - trainer_options = dict( - default_save_path=tmpdir, - max_epochs=1, - val_percent_check=0.1, - train_percent_check=0.2 - ) - - # fit model - trainer = Trainer(**trainer_options) - trainer.fit(model) - trainer.test() - - # verify there are 2 val loaders - assert len(trainer.test_dataloaders) == 2, \ - 'Multiple test_dataloaders not initiated properly' - - # make sure predictions are good for each test set - for dataloader in trainer.test_dataloaders: - tutils.run_prediction(dataloader, trainer.model) - - # run the test method - trainer.test() - - -def test_train_dataloaders_passed_to_fit(tmpdir): - """ Verify that train dataloader can be passed to fit """ - tutils.reset_seed() - - class CurrentTestModel(LightTrainDataloader, TestModelBase): - pass - - hparams = tutils.get_hparams() - - # logger file to get meta - trainer_options = dict( - default_save_path=tmpdir, - max_epochs=1, - val_percent_check=0.1, - train_percent_check=0.2 - ) - - # only train passed to fit - model = CurrentTestModel(hparams) - trainer = Trainer(**trainer_options) - fit_options = dict(train_dataloader=model._dataloader(train=True)) - results = trainer.fit(model, **fit_options) - - -def test_train_val_dataloaders_passed_to_fit(tmpdir): - """ Verify that train & val dataloader can be passed to fit """ - tutils.reset_seed() - - class CurrentTestModel( - LightTrainDataloader, - LightValStepFitSingleDataloaderMixin, - TestModelBase, - ): - pass - - hparams = tutils.get_hparams() - - # logger file to get meta - trainer_options = dict( - default_save_path=tmpdir, - max_epochs=1, - val_percent_check=0.1, - train_percent_check=0.2 - ) - - # train, val passed to fit - model = CurrentTestModel(hparams) - trainer = Trainer(**trainer_options) - fit_options = dict(train_dataloader=model._dataloader(train=True), - val_dataloaders=model._dataloader(train=False)) - - results = trainer.fit(model, **fit_options) - assert len(trainer.val_dataloaders) == 1, \ - f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}' - - -def test_all_dataloaders_passed_to_fit(tmpdir): - """ Verify train, val & test dataloader can be passed to fit """ - tutils.reset_seed() - - class CurrentTestModel( - LightTrainDataloader, - LightValStepFitSingleDataloaderMixin, - LightTestFitSingleTestDataloadersMixin, - LightEmptyTestStep, - TestModelBase, - ): - pass - - hparams = tutils.get_hparams() - - # logger file to get meta - trainer_options = dict( - default_save_path=tmpdir, - max_epochs=1, - val_percent_check=0.1, - train_percent_check=0.2 - ) - - # train, val and test passed to fit - model = CurrentTestModel(hparams) - trainer = Trainer(**trainer_options) - fit_options = dict(train_dataloader=model._dataloader(train=True), - val_dataloaders=model._dataloader(train=False), - test_dataloaders=model._dataloader(train=False)) - - results = trainer.fit(model, **fit_options) - - trainer.test() - - assert len(trainer.val_dataloaders) == 1, \ - f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}' - assert len(trainer.test_dataloaders) == 1, \ - f'`test_dataloaders` not initiated properly, got {trainer.test_dataloaders}' - - -def test_multiple_dataloaders_passed_to_fit(tmpdir): - """ Verify that multiple val & test dataloaders can be passed to fit """ - tutils.reset_seed() - - class CurrentTestModel( - LightningTestModel, - LightValStepFitMultipleDataloadersMixin, - LightTestFitMultipleTestDataloadersMixin, - ): - pass - - hparams = tutils.get_hparams() - - # logger file to get meta - trainer_options = dict( - default_save_path=tmpdir, - max_epochs=1, - val_percent_check=0.1, - train_percent_check=0.2 - ) - - # train, multiple val and multiple test passed to fit - model = CurrentTestModel(hparams) - trainer = Trainer(**trainer_options) - fit_options = dict(train_dataloader=model._dataloader(train=True), - val_dataloaders=[model._dataloader(train=False), - model._dataloader(train=False)], - test_dataloaders=[model._dataloader(train=False), - model._dataloader(train=False)]) - results = trainer.fit(model, **fit_options) - trainer.test() - - assert len(trainer.val_dataloaders) == 2, \ - f'Multiple `val_dataloaders` not initiated properly, got {trainer.val_dataloaders}' - assert len(trainer.test_dataloaders) == 2, \ - f'Multiple `test_dataloaders` not initiated properly, got {trainer.test_dataloaders}' - - -def test_mixing_of_dataloader_options(tmpdir): - """Verify that dataloaders can be passed to fit""" - tutils.reset_seed() - - class CurrentTestModel( - LightTrainDataloader, - LightValStepFitSingleDataloaderMixin, - LightTestFitSingleTestDataloadersMixin, - TestModelBase, - ): - pass - - hparams = tutils.get_hparams() - model = CurrentTestModel(hparams) - - # logger file to get meta - trainer_options = dict( - default_save_path=tmpdir, - max_epochs=1, - val_percent_check=0.1, - train_percent_check=0.2 - ) - - # fit model - trainer = Trainer(**trainer_options) - fit_options = dict(val_dataloaders=model._dataloader(train=False)) - results = trainer.fit(model, **fit_options) - - # fit model - trainer = Trainer(**trainer_options) - fit_options = dict(val_dataloaders=model._dataloader(train=False), - test_dataloaders=model._dataloader(train=False)) - _ = trainer.fit(model, **fit_options) - trainer.test() - - assert len(trainer.val_dataloaders) == 1, \ - f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}' - assert len(trainer.test_dataloaders) == 1, \ - f'`test_dataloaders` not initiated properly, got {trainer.test_dataloaders}' - - def _init_steps_model(): """private method for initializing a model with 5% train epochs""" tutils.reset_seed()