This commit is contained in:
William Falcon 2020-10-04 23:02:35 -04:00 committed by GitHub
parent 97e62b38cf
commit f58c760409
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 84 additions and 2 deletions

View File

@ -140,6 +140,22 @@ class ModelHooks:
"""
# do something when the batch ends
def on_validation_model_eval(
self
) -> None:
"""
Sets the model to eval during the val loop
"""
self.eval()
def on_validation_model_train(
self
) -> None:
"""
Sets the model to train during the val loop
"""
self.train()
def on_validation_batch_start(
self, batch: Any, batch_idx: int, dataloader_idx: int
) -> None:
@ -192,6 +208,22 @@ class ModelHooks:
"""
# do something when the batch ends
def on_test_model_eval(
self
) -> None:
"""
Sets the model to eval during the test loop
"""
self.eval()
def on_test_model_train(
self
) -> None:
"""
Sets the model to train during the test loop
"""
self.train()
def on_batch_start(self, batch: Any) -> None:
"""
Called in the training loop before anything happens for that batch.

View File

@ -88,6 +88,20 @@ class EvaluationLoop(object):
else:
self.trainer.call_hook('on_validation_start', *args, **kwargs)
def on_evaluation_model_eval(self, *args, **kwargs):
model_ref = self.trainer.get_model()
if self.testing:
model_ref.on_test_model_eval()
else:
model_ref.on_validation_model_eval()
def on_evaluation_model_train(self, *args, **kwargs):
model_ref = self.trainer.get_model()
if self.testing:
model_ref.on_test_model_train()
else:
model_ref.on_validation_model_train()
def on_evaluation_end(self, *args, **kwargs):
if self.testing:
self.trainer.call_hook('on_test_end', *args, **kwargs)

View File

@ -543,8 +543,9 @@ class Trainer(
# enable eval mode + no grads
model = self.get_model()
self.evaluation_loop.on_evaluation_model_eval()
model.zero_grad()
model.eval()
torch.set_grad_enabled(False)
# hook
@ -615,7 +616,7 @@ class Trainer(
self.evaluation_loop.on_evaluation_epoch_end()
# enable train mode again
model.train()
self.evaluation_loop.on_evaluation_model_train()
torch.set_grad_enabled(True)
# hook

View File

View File

@ -0,0 +1,35 @@
from tests.base.boring_model import BoringModel
from pytorch_lightning import Trainer
from unittest import mock
@mock.patch('pytorch_lightning.core.hooks.ModelHooks.on_validation_model_eval')
@mock.patch('pytorch_lightning.core.hooks.ModelHooks.on_validation_model_train')
@mock.patch('pytorch_lightning.core.hooks.ModelHooks.on_test_model_eval')
@mock.patch('pytorch_lightning.core.hooks.ModelHooks.on_test_model_train')
def test_eval_train_calls(test_train_mock, test_eval_mock, val_train_mock, val_eval_mock, tmpdir):
"""
Tests that only training_step can be used
"""
model = BoringModel()
model.validation_epoch_end = None
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=2,
limit_val_batches=2,
max_epochs=2,
row_log_interval=1,
weights_summary=None,
)
trainer.fit(model)
trainer.test()
# sanity + 2 epochs
assert val_eval_mock.call_count == 3
assert val_train_mock.call_count == 3
# test is called only once
assert test_eval_mock.call_count == 1
assert test_train_mock.call_count == 1