From f58c7604093fc37c765ac88e46aaf52b403332fe Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 4 Oct 2020 23:02:35 -0400 Subject: [PATCH] Fixes #2551 (#3858) --- pytorch_lightning/core/hooks.py | 32 +++++++++++++++++ pytorch_lightning/trainer/evaluation_loop.py | 14 ++++++++ pytorch_lightning/trainer/trainer.py | 5 +-- tests/trainer/model_hooks/__init__.py | 0 tests/trainer/model_hooks/test_model_hooks.py | 35 +++++++++++++++++++ 5 files changed, 84 insertions(+), 2 deletions(-) create mode 100644 tests/trainer/model_hooks/__init__.py create mode 100644 tests/trainer/model_hooks/test_model_hooks.py diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 5245d90606..9413c01954 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -140,6 +140,22 @@ class ModelHooks: """ # do something when the batch ends + def on_validation_model_eval( + self + ) -> None: + """ + Sets the model to eval during the val loop + """ + self.eval() + + def on_validation_model_train( + self + ) -> None: + """ + Sets the model to train during the val loop + """ + self.train() + def on_validation_batch_start( self, batch: Any, batch_idx: int, dataloader_idx: int ) -> None: @@ -192,6 +208,22 @@ class ModelHooks: """ # do something when the batch ends + def on_test_model_eval( + self + ) -> None: + """ + Sets the model to eval during the test loop + """ + self.eval() + + def on_test_model_train( + self + ) -> None: + """ + Sets the model to train during the test loop + """ + self.train() + def on_batch_start(self, batch: Any) -> None: """ Called in the training loop before anything happens for that batch. diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 6f955a316b..5e04173e6d 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -88,6 +88,20 @@ class EvaluationLoop(object): else: self.trainer.call_hook('on_validation_start', *args, **kwargs) + def on_evaluation_model_eval(self, *args, **kwargs): + model_ref = self.trainer.get_model() + if self.testing: + model_ref.on_test_model_eval() + else: + model_ref.on_validation_model_eval() + + def on_evaluation_model_train(self, *args, **kwargs): + model_ref = self.trainer.get_model() + if self.testing: + model_ref.on_test_model_train() + else: + model_ref.on_validation_model_train() + def on_evaluation_end(self, *args, **kwargs): if self.testing: self.trainer.call_hook('on_test_end', *args, **kwargs) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 87f85f9385..219df9f673 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -543,8 +543,9 @@ class Trainer( # enable eval mode + no grads model = self.get_model() + self.evaluation_loop.on_evaluation_model_eval() + model.zero_grad() - model.eval() torch.set_grad_enabled(False) # hook @@ -615,7 +616,7 @@ class Trainer( self.evaluation_loop.on_evaluation_epoch_end() # enable train mode again - model.train() + self.evaluation_loop.on_evaluation_model_train() torch.set_grad_enabled(True) # hook diff --git a/tests/trainer/model_hooks/__init__.py b/tests/trainer/model_hooks/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/trainer/model_hooks/test_model_hooks.py b/tests/trainer/model_hooks/test_model_hooks.py new file mode 100644 index 0000000000..67167f5129 --- /dev/null +++ b/tests/trainer/model_hooks/test_model_hooks.py @@ -0,0 +1,35 @@ +from tests.base.boring_model import BoringModel +from pytorch_lightning import Trainer +from unittest import mock + + +@mock.patch('pytorch_lightning.core.hooks.ModelHooks.on_validation_model_eval') +@mock.patch('pytorch_lightning.core.hooks.ModelHooks.on_validation_model_train') +@mock.patch('pytorch_lightning.core.hooks.ModelHooks.on_test_model_eval') +@mock.patch('pytorch_lightning.core.hooks.ModelHooks.on_test_model_train') +def test_eval_train_calls(test_train_mock, test_eval_mock, val_train_mock, val_eval_mock, tmpdir): + """ + Tests that only training_step can be used + """ + model = BoringModel() + model.validation_epoch_end = None + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=2, + row_log_interval=1, + weights_summary=None, + ) + + trainer.fit(model) + trainer.test() + + # sanity + 2 epochs + assert val_eval_mock.call_count == 3 + assert val_train_mock.call_count == 3 + + # test is called only once + assert test_eval_mock.call_count == 1 + assert test_train_mock.call_count == 1