From 1077159834199a9fc06d6c4f21f551a180c3e75a Mon Sep 17 00:00:00 2001
From: Jirka Borovec
Date: Mon, 4 May 2020 17:38:08 +0200
Subject: [PATCH] Tests: refactor models (#1691)

* refactor default model

* drop redundant seeds

* drop redundant seeds

* refactor models tests

* refactor models tests

* imports

* fix conf

* Apply suggestions from code review
---
 pytorch_lightning/core/lightning.py |  4 +-
 tests/callbacks/test_callbacks.py   |  3 +-
 tests/models/test_amp.py            |  6 +-
 tests/models/test_cpu.py            | 26 ++-------
 tests/models/test_gpu.py            |  6 +-
 tests/models/test_hooks.py          | 16 +-----
 tests/models/test_horovod.py        |  5 +-
 tests/models/test_restore.py        | 87 ++++++++++++++++-------------
 8 files changed, 67 insertions(+), 86 deletions(-)

diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 2f1de6412f..a534929434 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -1535,8 +1535,8 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
                 hparams = Namespace(**ckpt_hparams) if is_namespace else ckpt_hparams
             else:
                 rank_zero_warn(
-                    f"Checkpoint does not contain hyperparameters but {cls.__name__}'s __init__ "
-                    f"contains argument 'hparams'. Will pass in an empty Namespace instead."
+                    f"Checkpoint does not contain hyperparameters but {cls.__name__}'s __init__"
+                    " contains argument 'hparams'. Will pass in an empty Namespace instead."
                     " Did you forget to store your model hyperparameters in self.hparams?"
                 )
                 hparams = Namespace()
diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py
index 2bbcfaea1f..b1fb71dc8e 100644
--- a/tests/callbacks/test_callbacks.py
+++ b/tests/callbacks/test_callbacks.py
@@ -249,8 +249,7 @@ def test_pickling(tmpdir):
 
 @pytest.mark.parametrize('save_top_k', [-1, 0, 1, 2])
 def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):
-    """ Test that None in checkpoint callback is valid and that chkp_path is
-    set correctly """
+    """ Test that None in checkpoint callback is valid and that chkp_path is set correctly """
     tutils.reset_seed()
 
     class CurrentTestModel(LightTrainDataloader, TestModelBase):
diff --git a/tests/models/test_amp.py b/tests/models/test_amp.py
index f4f1d9c20a..52fb90f135 100644
--- a/tests/models/test_amp.py
+++ b/tests/models/test_amp.py
@@ -6,7 +6,7 @@ import torch
 import tests.base.utils as tutils
 from pytorch_lightning import Trainer
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.base import LightningTestModel, EvalModelTemplate
+from tests.base import EvalModelTemplate
 
 
 @pytest.mark.spawn
@@ -15,7 +15,6 @@ from tests.base import LightningTestModel, EvalModelTemplate
 def test_amp_single_gpu(tmpdir, backend):
     """Make sure DP/DDP + AMP work."""
     tutils.reset_seed()
-
     trainer = Trainer(
         default_root_dir=tmpdir,
         max_epochs=1,
@@ -63,8 +62,7 @@ def test_amp_gpu_ddp_slurm_managed(tmpdir):
     tutils.set_random_master_port()
     os.environ['SLURM_LOCALID'] = str(0)
 
-    hparams = tutils.get_default_hparams()
-    model = LightningTestModel(hparams)
+    model = EvalModelTemplate(tutils.get_default_hparams())
 
     # exp file to get meta
     logger = tutils.get_default_logger(tmpdir)
diff --git a/tests/models/test_cpu.py b/tests/models/test_cpu.py
index 46d1ba6e44..13120c0175 100644
--- a/tests/models/test_cpu.py
+++ b/tests/models/test_cpu.py
@@ -7,16 +7,8 @@ from packaging.version import parse as version_parse
 
 import tests.base.utils as tutils
 from pytorch_lightning import Trainer
-from pytorch_lightning.callbacks import (
-    EarlyStopping,
-)
-from tests.base import (
-    TestModelBase,
-    LightTrainDataloader,
-    LightningTestModel,
-    LightTestMixin,
-    EvalModelTemplate,
-)
+from pytorch_lightning.callbacks import EarlyStopping
+from tests.base import EvalModelTemplate
 
 
 def test_early_stopping_cpu_model(tmpdir):
@@ -106,8 +98,7 @@ def test_default_logger_callbacks_cpu_model(tmpdir):
 
 def test_running_test_after_fitting(tmpdir):
     """Verify test() on fitted model."""
-    hparams = tutils.get_default_hparams()
-    model = LightningTestModel(hparams)
+    model = EvalModelTemplate(tutils.get_default_hparams())
 
     # logger file to get meta
     logger = tutils.get_default_logger(tmpdir)
@@ -138,11 +129,7 @@
 
 def test_running_test_no_val(tmpdir):
     """Verify `test()` works on a model with no `val_loader`."""
-    class CurrentTestModel(LightTrainDataloader, LightTestMixin, TestModelBase):
-        pass
-
-    hparams = tutils.get_default_hparams()
-    model = CurrentTestModel(hparams)
+    model = EvalModelTemplate(tutils.get_default_hparams())
 
     # logger file to get meta
     logger = tutils.get_default_logger(tmpdir)
@@ -220,8 +207,7 @@ def test_single_gpu_batch_parse():
 
 def test_simple_cpu(tmpdir):
     """Verify continue training session on CPU."""
-    hparams = tutils.get_default_hparams()
-    model = LightningTestModel(hparams)
+    model = EvalModelTemplate(tutils.get_default_hparams())
 
     # fit model
     trainer = Trainer(
@@ -285,7 +271,7 @@ def test_tbptt_cpu_model(tmpdir):
         def __len__(self):
            return 1
 
-    class BpttTestModel(LightTrainDataloader, TestModelBase):
+    class BpttTestModel(EvalModelTemplate):
         def __init__(self, hparams):
             super().__init__(hparams)
             self.test_hidden = None
diff --git a/tests/models/test_gpu.py b/tests/models/test_gpu.py
index dbaf4db8f8..5bdb603e14 100644
--- a/tests/models/test_gpu.py
+++ b/tests/models/test_gpu.py
@@ -9,7 +9,7 @@ from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.core import memory
 from pytorch_lightning.trainer.distrib_parts import parse_gpu_ids, determine_root_gpu_device
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.base import LightningTestModel, EvalModelTemplate
+from tests.base import EvalModelTemplate
 
 PRETEND_N_OF_GPUS = 16
 
@@ -65,7 +65,7 @@ def test_ddp_all_dataloaders_passed_to_fit(tmpdir):
 def test_cpu_slurm_save_load(tmpdir):
     """Verify model save/load/checkpoint on CPU."""
     hparams = tutils.get_default_hparams()
-    model = LightningTestModel(hparams)
+    model = EvalModelTemplate(hparams)
 
     # logger file to get meta
     logger = tutils.get_default_logger(tmpdir)
@@ -112,7 +112,7 @@ def test_cpu_slurm_save_load(tmpdir):
         logger=logger,
         checkpoint_callback=ModelCheckpoint(tmpdir),
     )
-    model = LightningTestModel(hparams)
+    model = EvalModelTemplate(hparams)
 
     # set the epoch start hook so we can predict before the model does the full training
     def assert_pred_same():
diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py
index 1d0e55df40..00147ef2bc 100644
--- a/tests/models/test_hooks.py
+++ b/tests/models/test_hooks.py
@@ -2,29 +2,19 @@ import pytest
 
 import tests.base.utils as tutils
 from pytorch_lightning import Trainer
-from tests.base import (
-    LightTrainDataloader,
-    LightValidationMixin,
-    TestModelBase,
-    LightTestMixin)
+from tests.base import EvalModelTemplate
 
 
 @pytest.mark.parametrize('max_steps', [1, 2, 3])
 def test_on_before_zero_grad_called(max_steps):
 
-    class CurrentTestModel(
-        LightTrainDataloader,
-        LightValidationMixin,
-        LightTestMixin,
-        TestModelBase,
-    ):
+    class CurrentTestModel(EvalModelTemplate):
         on_before_zero_grad_called = 0
 
         def on_before_zero_grad(self, optimizer):
             self.on_before_zero_grad_called += 1
 
-    hparams = tutils.get_default_hparams()
-    model = CurrentTestModel(hparams)
+    model = CurrentTestModel(tutils.get_default_hparams())
 
     trainer = Trainer(
         max_steps=max_steps,
diff --git a/tests/models/test_horovod.py b/tests/models/test_horovod.py
index 0f41dee6e4..14644aee66 100644
--- a/tests/models/test_horovod.py
+++ b/tests/models/test_horovod.py
@@ -11,7 +11,7 @@ import torch
 
 from pytorch_lightning import Trainer
 import tests.base.utils as tutils
-from tests.base import LightningTestModel
+from tests.base import EvalModelTemplate
 from tests.base.models import TestGAN
 
 try:
@@ -107,7 +107,8 @@ def test_horovod_multi_gpu(tmpdir):
 @pytest.mark.skipif(not _nccl_available(), reason="test requires Horovod with NCCL support")
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
 def test_horovod_transfer_batch_to_gpu(tmpdir):
-    class TestTrainingStepModel(LightningTestModel):
+
+    class TestTrainingStepModel(EvalModelTemplate):
         def training_step(self, batch, *args, **kwargs):
             x, y = batch
             assert str(x.device) != 'cpu'
diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py
index af0165d498..0a927a3a94 100644
--- a/tests/models/test_restore.py
+++ b/tests/models/test_restore.py
@@ -9,11 +9,7 @@ import tests.base.utils as tutils
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.base import (
-    LightningTestModel,
-    LightningTestModelWithoutHyperparametersArg,
-    LightningTestModelWithUnusedHyperparametersArg
-)
+from tests.base import EvalModelTemplate
 
 
 @pytest.mark.spawn
@@ -23,8 +19,7 @@ def test_running_test_pretrained_model_distrib(tmpdir, backend):
     """Verify `test()` on pretrained model."""
     tutils.set_random_master_port()
 
-    hparams = tutils.get_default_hparams()
-    model = LightningTestModel(hparams)
+    model = EvalModelTemplate(tutils.get_default_hparams())
 
     # exp file to get meta
     logger = tutils.get_default_logger(tmpdir)
@@ -53,7 +48,7 @@ def test_running_test_pretrained_model_distrib(tmpdir, backend):
     assert result == 1, 'training failed to complete'
     pretrained_model = tutils.load_model(logger,
                                          trainer.checkpoint_callback.dirpath,
-                                         module_class=LightningTestModel)
+                                         module_class=EvalModelTemplate)
 
     # run test set
     new_trainer = Trainer(**trainer_options)
@@ -72,8 +67,7 @@ def test_running_test_pretrained_model_distrib(tmpdir, backend):
 
 def test_running_test_pretrained_model_cpu(tmpdir):
     """Verify test() on pretrained model."""
-    hparams = tutils.get_default_hparams()
-    model = LightningTestModel(hparams)
+    model = EvalModelTemplate(tutils.get_default_hparams())
 
     # logger file to get meta
     logger = tutils.get_default_logger(tmpdir)
@@ -97,7 +91,7 @@ def test_running_test_pretrained_model_cpu(tmpdir):
     # correct result and ok accuracy
     assert result == 1, 'training failed to complete'
     pretrained_model = tutils.load_model(
-        logger, trainer.checkpoint_callback.dirpath, module_class=LightningTestModel
+        logger, trainer.checkpoint_callback.dirpath, module_class=EvalModelTemplate
     )
 
     new_trainer = Trainer(**trainer_options)
@@ -110,7 +104,7 @@
 def test_load_model_from_checkpoint(tmpdir):
     """Verify test() on pretrained model."""
     hparams = tutils.get_default_hparams()
-    model = LightningTestModel(hparams)
+    model = EvalModelTemplate(hparams)
 
     trainer_options = dict(
         progress_bar_refresh_rate=0,
@@ -131,7 +125,7 @@ def test_load_model_from_checkpoint(tmpdir):
 
     # load last checkpoint
     last_checkpoint = sorted(glob.glob(os.path.join(trainer.checkpoint_callback.dirpath, "*.ckpt")))[-1]
-    pretrained_model = LightningTestModel.load_from_checkpoint(last_checkpoint)
+    pretrained_model = EvalModelTemplate.load_from_checkpoint(last_checkpoint)
 
     # test that hparams loaded correctly
     for k, v in vars(hparams).items():
@@ -152,7 +146,13 @@
 def test_dp_resume(tmpdir):
     """Make sure DP continues training correctly."""
     hparams = tutils.get_default_hparams()
-    model = LightningTestModel(hparams)
+    model = EvalModelTemplate(hparams)
+
+    trainer_options = dict(
+        max_epochs=1,
+        gpus=2,
+        distributed_backend='dp',
+    )
 
     # get logger
     logger = tutils.get_default_logger(tmpdir)
@@ -161,13 +161,9 @@ def test_dp_resume(tmpdir):
 
     # logger file to get weights
     checkpoint = tutils.init_checkpoint_callback(logger)
-    trainer_options = dict(
-        max_epochs=1,
-        gpus=2,
-        distributed_backend='dp',
-        logger=logger,
-        checkpoint_callback=checkpoint,
-    )
+    # add these to the trainer options
+    trainer_options['logger'] = logger
+    trainer_options['checkpoint_callback'] = checkpoint
 
     # fit model
     trainer = Trainer(**trainer_options)
@@ -188,13 +184,11 @@ def test_dp_resume(tmpdir):
 
     # init new trainer
     new_logger = tutils.get_default_logger(tmpdir, version=logger.version)
-    trainer_options.update(
-        logger=new_logger,
-        checkpoint_callback=ModelCheckpoint(tmpdir),
-        train_percent_check=0.5,
-        val_percent_check=0.2,
-        max_epochs=1,
-    )
+    trainer_options['logger'] = new_logger
+    trainer_options['checkpoint_callback'] = ModelCheckpoint(tmpdir)
+    trainer_options['train_percent_check'] = 0.5
+    trainer_options['val_percent_check'] = 0.2
+    trainer_options['max_epochs'] = 1
     new_trainer = Trainer(**trainer_options)
 
     # set the epoch start hook so we can predict before the model does the full training
@@ -210,7 +204,7 @@ def test_dp_resume(tmpdir):
         tutils.run_prediction(dataloader, dp_model, dp=True)
 
     # new model
-    model = LightningTestModel(hparams)
+    model = EvalModelTemplate(hparams)
     model.on_train_start = assert_good_acc
 
     # fit new model which should load hpc weights
@@ -223,18 +217,19 @@
 
 def test_model_saving_loading(tmpdir):
     """Tests use case where trainer saves the model, and user loads it from tags independently."""
-    hparams = tutils.get_default_hparams()
-    model = LightningTestModel(hparams)
+    model = EvalModelTemplate(tutils.get_default_hparams())
 
     # logger file to get meta
     logger = tutils.get_default_logger(tmpdir)
 
-    # fit model
-    trainer = Trainer(
+    trainer_options = dict(
         max_epochs=1,
         logger=logger,
         checkpoint_callback=ModelCheckpoint(tmpdir)
     )
+
+    # fit model
+    trainer = Trainer(**trainer_options)
     result = trainer.fit(model)
 
     # traning complete
@@ -263,7 +258,7 @@ def test_model_saving_loading(tmpdir):
     # load new model
     tags_path = tutils.get_data_path(logger, path_dir=tmpdir)
     tags_path = os.path.join(tags_path, 'meta_tags.csv')
-    model_2 = LightningTestModel.load_from_checkpoint(
+    model_2 = EvalModelTemplate.load_from_checkpoint(
         checkpoint_path=new_weights_path,
         tags_csv=tags_path
     )
@@ -276,8 +271,7 @@ def test_model_saving_loading(tmpdir):
     model_2.eval()
 
 def test_load_model_with_missing_hparams(tmpdir):
-    # fit model
-    trainer = Trainer(
+    trainer_options = dict(
         progress_bar_refresh_rate=0,
         max_epochs=1,
         checkpoint_callback=ModelCheckpoint(tmpdir, save_top_k=-1),
@@ -285,22 +279,35 @@ def test_load_model_with_missing_hparams(tmpdir):
         default_root_dir=tmpdir,
     )
 
-    model = LightningTestModelWithoutHyperparametersArg()
+    # fit model
+    trainer = Trainer(**trainer_options)
+
+    class CurrentModelWithoutHparams(EvalModelTemplate):
+        def __init__(self):
+            hparams = tutils.get_default_hparams()
+            super().__init__(hparams)
+
+    class CurrentModelUnusedHparams(EvalModelTemplate):
+        def __init__(self, hparams):
+            hparams = tutils.get_default_hparams()
+            super().__init__(hparams)
+
+    model = CurrentModelWithoutHparams()
     trainer.fit(model)
     last_checkpoint = sorted(glob.glob(os.path.join(trainer.checkpoint_callback.dirpath, "*.ckpt")))[-1]
 
     # try to load a checkpoint that has hparams but model is missing hparams arg
     with pytest.raises(MisconfigurationException, match=r".*__init__ is missing the argument 'hparams'.*"):
-        LightningTestModelWithoutHyperparametersArg.load_from_checkpoint(last_checkpoint)
+        CurrentModelWithoutHparams.load_from_checkpoint(last_checkpoint)
 
     # create a checkpoint without hyperparameters
     # if the model does not take a hparams argument, it should not throw an error
     ckpt = torch.load(last_checkpoint)
     del(ckpt['hparams'])
     torch.save(ckpt, last_checkpoint)
-    LightningTestModelWithoutHyperparametersArg.load_from_checkpoint(last_checkpoint)
+    CurrentModelWithoutHparams.load_from_checkpoint(last_checkpoint)
 
     # load checkpoint without hparams again
     # warn if user's model has hparams argument
     with pytest.warns(UserWarning, match=r".*Will pass in an empty Namespace instead."):
-        LightningTestModelWithUnusedHyperparametersArg.load_from_checkpoint(last_checkpoint)
+        CurrentModelUnusedHparams.load_from_checkpoint(last_checkpoint)
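
Editor's note (illustration, not part of the patch): every test above follows the same refactor pattern, i.e. subclass the single EvalModelTemplate and override only the hook or step under test, instead of composing TestModelBase with the Light*Mixin classes. Below is a minimal sketch of a new-style test, modeled on the tests/models/test_hooks.py change; the subclass name, the fixed max_steps=3, and the final assert are illustrative assumptions, not taken from the patch.

    import tests.base.utils as tutils
    from pytorch_lightning import Trainer
    from tests.base import EvalModelTemplate


    class HookCountingModel(EvalModelTemplate):
        # Hypothetical subclass: EvalModelTemplate already supplies the
        # dataloaders and the training/validation/test steps, so the
        # subclass only adds the bookkeeping needed by this one test.
        def __init__(self, hparams):
            super().__init__(hparams)
            self.zero_grad_calls = 0

        def on_before_zero_grad(self, optimizer):
            self.zero_grad_calls += 1


    def test_on_before_zero_grad_is_called(tmpdir):
        model = HookCountingModel(tutils.get_default_hparams())
        trainer = Trainer(default_root_dir=tmpdir, max_steps=3)
        trainer.fit(model)
        # assumes one optimizer step (and hence one hook call) per
        # training step, as in the parametrized test in the patch
        assert model.zero_grad_calls == 3

This is why the patch can delete whole Current*Model mixin stacks: behavior that used to require picking the right combination of mixins now lives in one template model, and each test expresses only its delta from that template.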