import pytest

import tests.base.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.utilities.debugging import MisconfigurationException
from tests.base import (
    TestModelBase,
    LightningTestModel,
    LightEmptyTestStep,
    LightValidationMultipleDataloadersMixin,
    LightTestMultipleDataloadersMixin,
    LightTestFitSingleTestDataloadersMixin,
    LightTestFitMultipleTestDataloadersMixin,
    LightValStepFitMultipleDataloadersMixin,
    LightValStepFitSingleDataloaderMixin,
    LightTrainDataloader,
    LightInfTrainDataloader,
    LightInfValDataloader,
    LightInfTestDataloader,
)


def test_dataloader_config_errors(tmpdir):
    tutils.reset_seed()

    class CurrentTestModel(
        LightTrainDataloader,
        TestModelBase,
    ):
        pass

    hparams = tutils.get_default_hparams()
    model = CurrentTestModel(hparams)

    # percent check < 0
    # logger file to get meta
    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        train_percent_check=-0.1,
    )

    # fit model
    trainer = Trainer(**trainer_options)
    with pytest.raises(ValueError):
        trainer.fit(model)

    # percent check > 1
    # logger file to get meta
    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        train_percent_check=1.1,
    )

    # fit model
    trainer = Trainer(**trainer_options)
    with pytest.raises(ValueError):
        trainer.fit(model)

    # int val_check_interval > num batches
    # logger file to get meta
    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        val_check_interval=10000,
    )

    # fit model
    trainer = Trainer(**trainer_options)
    with pytest.raises(ValueError):
        trainer.fit(model)

    # float val_check_interval > 1
    # logger file to get meta
    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        val_check_interval=1.1,
    )

    # fit model
    trainer = Trainer(**trainer_options)
    with pytest.raises(ValueError):
        trainer.fit(model)


def test_multiple_val_dataloader(tmpdir):
    """Verify multiple val_dataloader."""
    tutils.reset_seed()

    class CurrentTestModel(
        LightTrainDataloader,
        LightValidationMultipleDataloadersMixin,
        TestModelBase,
    ):
        pass

    hparams = tutils.get_default_hparams()
    model = CurrentTestModel(hparams)

    # logger file to get meta
    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=1.0,
    )

    # fit model
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)

    # verify training completed
    assert result == 1

    # verify there are 2 val loaders
    assert len(trainer.val_dataloaders) == 2, \
        'Multiple val_dataloaders not initiated properly'

    # make sure predictions are good for each val set
    for dataloader in trainer.val_dataloaders:
        tutils.run_prediction(dataloader, trainer.model)


def test_multiple_test_dataloader(tmpdir):
    """Verify multiple test_dataloader."""
    tutils.reset_seed()

    class CurrentTestModel(
        LightTrainDataloader,
        LightTestMultipleDataloadersMixin,
        LightEmptyTestStep,
        TestModelBase,
    ):
        pass

    hparams = tutils.get_default_hparams()
    model = CurrentTestModel(hparams)

    # logger file to get meta
    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.2,
    )

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.fit(model)
    trainer.test()

    # verify there are 2 test loaders
    assert len(trainer.test_dataloaders) == 2, \
        'Multiple test_dataloaders not initiated properly'

    # make sure predictions are good for each test set
    for dataloader in trainer.test_dataloaders:
        tutils.run_prediction(dataloader, trainer.model)

    # run the test method
    trainer.test()


def test_train_dataloaders_passed_to_fit(tmpdir):
    """Verify that train dataloader can be passed to fit."""
    tutils.reset_seed()

    class CurrentTestModel(LightTrainDataloader, TestModelBase):
        pass

    hparams = tutils.get_default_hparams()

    # logger file to get meta
    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.2,
    )

    # only train passed to fit
    model = CurrentTestModel(hparams)
    trainer = Trainer(**trainer_options)
    fit_options = dict(train_dataloader=model._dataloader(train=True))
    result = trainer.fit(model, **fit_options)

    assert result == 1


def test_train_val_dataloaders_passed_to_fit(tmpdir):
    """Verify that train & val dataloader can be passed to fit."""
    tutils.reset_seed()

    class CurrentTestModel(
        LightTrainDataloader,
        LightValStepFitSingleDataloaderMixin,
        TestModelBase,
    ):
        pass

    hparams = tutils.get_default_hparams()

    # logger file to get meta
    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.2,
    )

    # train, val passed to fit
    model = CurrentTestModel(hparams)
    trainer = Trainer(**trainer_options)
    fit_options = dict(train_dataloader=model._dataloader(train=True),
                       val_dataloaders=model._dataloader(train=False))
    result = trainer.fit(model, **fit_options)

    assert result == 1
    assert len(trainer.val_dataloaders) == 1, \
        f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}'


def test_all_dataloaders_passed_to_fit(tmpdir):
    """Verify train, val & test dataloader can be passed to fit."""
    tutils.reset_seed()

    class CurrentTestModel(
        LightTrainDataloader,
        LightValStepFitSingleDataloaderMixin,
        LightTestFitSingleTestDataloadersMixin,
        LightEmptyTestStep,
        TestModelBase,
    ):
        pass

    hparams = tutils.get_default_hparams()

    # logger file to get meta
    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.2,
    )

    # train, val and test passed to fit
    model = CurrentTestModel(hparams)
    trainer = Trainer(**trainer_options)
    fit_options = dict(train_dataloader=model._dataloader(train=True),
                       val_dataloaders=model._dataloader(train=False),
                       test_dataloaders=model._dataloader(train=False))
    result = trainer.fit(model, **fit_options)
    trainer.test()

    assert result == 1
    assert len(trainer.val_dataloaders) == 1, \
        f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}'
    assert len(trainer.test_dataloaders) == 1, \
        f'`test_dataloaders` not initiated properly, got {trainer.test_dataloaders}'


def test_multiple_dataloaders_passed_to_fit(tmpdir):
    """Verify that multiple val & test dataloaders can be passed to fit."""
    tutils.reset_seed()

    class CurrentTestModel(
        LightningTestModel,
        LightValStepFitMultipleDataloadersMixin,
        LightTestFitMultipleTestDataloadersMixin,
    ):
        pass

    hparams = tutils.get_default_hparams()

    # logger file to get meta
    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.2,
    )

    # train, multiple val and multiple test passed to fit
    model = CurrentTestModel(hparams)
    trainer = Trainer(**trainer_options)
    fit_options = dict(train_dataloader=model._dataloader(train=True),
                       val_dataloaders=[model._dataloader(train=False),
                                        model._dataloader(train=False)],
                       test_dataloaders=[model._dataloader(train=False),
                                         model._dataloader(train=False)])
    trainer.fit(model, **fit_options)
    trainer.test()

    assert len(trainer.val_dataloaders) == 2, \
        f'Multiple `val_dataloaders` not initiated properly, got {trainer.val_dataloaders}'
    assert len(trainer.test_dataloaders) == 2, \
        f'Multiple `test_dataloaders` not initiated properly, got {trainer.test_dataloaders}'


def test_mixing_of_dataloader_options(tmpdir):
    """Verify that dataloaders can be passed to fit."""
    tutils.reset_seed()

    class CurrentTestModel(
        LightTrainDataloader,
        LightValStepFitSingleDataloaderMixin,
        LightTestFitSingleTestDataloadersMixin,
        TestModelBase,
    ):
        pass

    hparams = tutils.get_default_hparams()
    model = CurrentTestModel(hparams)

    # logger file to get meta
    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.2,
    )

    # fit model
    trainer = Trainer(**trainer_options)
    fit_options = dict(val_dataloaders=model._dataloader(train=False))
    _ = trainer.fit(model, **fit_options)

    # fit model
    trainer = Trainer(**trainer_options)
    fit_options = dict(val_dataloaders=model._dataloader(train=False),
                       test_dataloaders=model._dataloader(train=False))
    _ = trainer.fit(model, **fit_options)
    trainer.test()

    assert len(trainer.val_dataloaders) == 1, \
        f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}'
    assert len(trainer.test_dataloaders) == 1, \
        f'`test_dataloaders` not initiated properly, got {trainer.test_dataloaders}'


def test_inf_train_dataloader(tmpdir):
    """Test inf train data loader (e.g. IterableDataset)."""
    tutils.reset_seed()

    class CurrentTestModel(
        LightInfTrainDataloader,
        LightningTestModel,
    ):
        pass

    hparams = tutils.get_default_hparams()
    model = CurrentTestModel(hparams)

    # fit model
    with pytest.raises(MisconfigurationException):
        trainer = Trainer(
            default_save_path=tmpdir,
            max_epochs=1,
            val_check_interval=0.5,
        )
        trainer.fit(model)

    # logger file to get meta
    trainer = Trainer(
        default_save_path=tmpdir,
        max_epochs=1,
        val_check_interval=50,
    )
    result = trainer.fit(model)

    # verify training completed
    assert result == 1


def test_inf_val_dataloader(tmpdir):
    """Test inf val data loader (e.g. IterableDataset)."""
    tutils.reset_seed()

    class CurrentTestModel(
        LightInfValDataloader,
        LightningTestModel,
    ):
        pass

    hparams = tutils.get_default_hparams()
    model = CurrentTestModel(hparams)

    # fit model
    with pytest.raises(MisconfigurationException):
        trainer = Trainer(
            default_save_path=tmpdir,
            max_epochs=1,
            val_percent_check=0.5,
        )
        trainer.fit(model)

    # logger file to get meta
    trainer = Trainer(
        default_save_path=tmpdir,
        max_epochs=1,
    )
    result = trainer.fit(model)

    # verify training completed
    assert result == 1


def test_inf_test_dataloader(tmpdir):
    """Test inf test data loader (e.g. IterableDataset)."""
    tutils.reset_seed()

    class CurrentTestModel(
        LightInfTestDataloader,
        LightningTestModel,
        LightTestFitSingleTestDataloadersMixin,
    ):
        pass

    hparams = tutils.get_default_hparams()
    model = CurrentTestModel(hparams)

    # fit model
    with pytest.raises(MisconfigurationException):
        trainer = Trainer(
            default_save_path=tmpdir,
            max_epochs=1,
            test_percent_check=0.5,
        )
        trainer.test(model)

    # logger file to get meta
    trainer = Trainer(
        default_save_path=tmpdir,
        max_epochs=1,
    )
    result = trainer.fit(model)
    trainer.test(model)

    # verify training completed
    assert result == 1