parent
97e62b38cf
commit
f58c760409
|
@ -140,6 +140,22 @@ class ModelHooks:
|
|||
"""
|
||||
# do something when the batch ends
|
||||
|
||||
def on_validation_model_eval(
|
||||
self
|
||||
) -> None:
|
||||
"""
|
||||
Sets the model to eval during the val loop
|
||||
"""
|
||||
self.eval()
|
||||
|
||||
def on_validation_model_train(
|
||||
self
|
||||
) -> None:
|
||||
"""
|
||||
Sets the model to train during the val loop
|
||||
"""
|
||||
self.train()
|
||||
|
||||
def on_validation_batch_start(
|
||||
self, batch: Any, batch_idx: int, dataloader_idx: int
|
||||
) -> None:
|
||||
|
@ -192,6 +208,22 @@ class ModelHooks:
|
|||
"""
|
||||
# do something when the batch ends
|
||||
|
||||
def on_test_model_eval(
|
||||
self
|
||||
) -> None:
|
||||
"""
|
||||
Sets the model to eval during the test loop
|
||||
"""
|
||||
self.eval()
|
||||
|
||||
def on_test_model_train(
|
||||
self
|
||||
) -> None:
|
||||
"""
|
||||
Sets the model to train during the test loop
|
||||
"""
|
||||
self.train()
|
||||
|
||||
def on_batch_start(self, batch: Any) -> None:
|
||||
"""
|
||||
Called in the training loop before anything happens for that batch.
|
||||
|
|
|
@ -88,6 +88,20 @@ class EvaluationLoop(object):
|
|||
else:
|
||||
self.trainer.call_hook('on_validation_start', *args, **kwargs)
|
||||
|
||||
def on_evaluation_model_eval(self, *args, **kwargs):
|
||||
model_ref = self.trainer.get_model()
|
||||
if self.testing:
|
||||
model_ref.on_test_model_eval()
|
||||
else:
|
||||
model_ref.on_validation_model_eval()
|
||||
|
||||
def on_evaluation_model_train(self, *args, **kwargs):
|
||||
model_ref = self.trainer.get_model()
|
||||
if self.testing:
|
||||
model_ref.on_test_model_train()
|
||||
else:
|
||||
model_ref.on_validation_model_train()
|
||||
|
||||
def on_evaluation_end(self, *args, **kwargs):
|
||||
if self.testing:
|
||||
self.trainer.call_hook('on_test_end', *args, **kwargs)
|
||||
|
|
|
@ -543,8 +543,9 @@ class Trainer(
|
|||
|
||||
# enable eval mode + no grads
|
||||
model = self.get_model()
|
||||
self.evaluation_loop.on_evaluation_model_eval()
|
||||
|
||||
model.zero_grad()
|
||||
model.eval()
|
||||
torch.set_grad_enabled(False)
|
||||
|
||||
# hook
|
||||
|
@ -615,7 +616,7 @@ class Trainer(
|
|||
self.evaluation_loop.on_evaluation_epoch_end()
|
||||
|
||||
# enable train mode again
|
||||
model.train()
|
||||
self.evaluation_loop.on_evaluation_model_train()
|
||||
torch.set_grad_enabled(True)
|
||||
|
||||
# hook
|
||||
|
|
|
@ -0,0 +1,35 @@
|
|||
from tests.base.boring_model import BoringModel
|
||||
from pytorch_lightning import Trainer
|
||||
from unittest import mock
|
||||
|
||||
|
||||
@mock.patch('pytorch_lightning.core.hooks.ModelHooks.on_validation_model_eval')
|
||||
@mock.patch('pytorch_lightning.core.hooks.ModelHooks.on_validation_model_train')
|
||||
@mock.patch('pytorch_lightning.core.hooks.ModelHooks.on_test_model_eval')
|
||||
@mock.patch('pytorch_lightning.core.hooks.ModelHooks.on_test_model_train')
|
||||
def test_eval_train_calls(test_train_mock, test_eval_mock, val_train_mock, val_eval_mock, tmpdir):
|
||||
"""
|
||||
Tests that only training_step can be used
|
||||
"""
|
||||
model = BoringModel()
|
||||
model.validation_epoch_end = None
|
||||
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
limit_train_batches=2,
|
||||
limit_val_batches=2,
|
||||
max_epochs=2,
|
||||
row_log_interval=1,
|
||||
weights_summary=None,
|
||||
)
|
||||
|
||||
trainer.fit(model)
|
||||
trainer.test()
|
||||
|
||||
# sanity + 2 epochs
|
||||
assert val_eval_mock.call_count == 3
|
||||
assert val_train_mock.call_count == 3
|
||||
|
||||
# test is called only once
|
||||
assert test_eval_mock.call_count == 1
|
||||
assert test_train_mock.call_count == 1
|
Loading…
Reference in New Issue