[Fix] Ensure we set the eval/train flag correctly on accelerator model (#6877)

* Ensure we move the model to eval mode before running evaluation

* Ensure we set the flag appropriately across all stages

* Add test, move hooks logic

* Apply same fix to the validate loop

* Update pytorch_lightning/trainer/trainer.py

* Fix function name

* Fix order, add predict

* Shorten the name

* Fix input dm, drop duplicate on predict start hook call, as it's called in the setup function

* Use hook, remove double call
Sean Naren 2021-04-08 19:04:26 +01:00 committed by GitHub
parent 851fd7fae7
commit 742c48e994
4 changed files with 42 additions and 10 deletions
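
Why the fix works: torch.nn.Module.train() and .eval() flip the `training` flag on the module they are called on and recurse into its children, but they never touch a parent wrapper. Under an accelerator such as DDP, `trainer.model` is the wrapper and the LightningModule is one of its children, so `self.eval()` inside a hook switched the inner module while leaving the wrapper in train mode. A minimal sketch, with plain nn modules standing in for the wrapper and the LightningModule (illustrative, not Lightning's actual classes):

    from torch import nn

    inner = nn.Linear(4, 4)         # stands in for the LightningModule
    wrapper = nn.Sequential(inner)  # stands in for trainer.model (e.g. a DDP wrapper)

    inner.eval()                    # old behaviour: self.eval() inside the hook
    assert not inner.training       # the inner module was switched ...
    assert wrapper.training         # ... but the wrapper never saw the call

    wrapper.eval()                  # fixed behaviour: trainer.model.eval()
    assert not wrapper.training and not inner.training  # eval() recurses into children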

pytorch_lightning/core/hooks.py

@@ -114,13 +114,13 @@ class ModelHooks:
         """
         Sets the model to eval during the val loop
         """
-        self.eval()
+        self.trainer.model.eval()

     def on_validation_model_train(self) -> None:
         """
         Sets the model to train during the val loop
         """
-        self.train()
+        self.trainer.model.train()

     def on_validation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
         """
@@ -172,19 +172,19 @@ class ModelHooks:
         """
         Sets the model to train during the test loop
         """
-        self.train()
+        self.trainer.model.train()

     def on_test_model_eval(self) -> None:
         """
         Sets the model to eval during the test loop
         """
-        self.eval()
+        self.trainer.model.eval()

     def on_predict_model_eval(self) -> None:
         """
         Sets the model to eval during the predict loop
         """
-        self.eval()
+        self.trainer.model.eval()

     def on_epoch_start(self) -> None:
         """

pytorch_lightning/trainer/predict_loop.py

@@ -44,7 +44,6 @@ class PredictLoop(object):
         model_ref.on_predict_model_eval()

     def setup(self, model, max_batches, dataloaders):
-        self.trainer.call_hook("on_predict_start")

         # copy properties for forward overrides
         self.trainer.model_connector.copy_trainer_model_properties(model)

pytorch_lightning/trainer/trainer.py

@@ -582,11 +582,11 @@ class Trainer(
         self.checkpoint_connector.has_trained = False

         # enable train mode
-        model = self.lightning_module
-        model.train()
+        self.model.train()
         torch.set_grad_enabled(True)

         # reload data when needed
+        model = self.lightning_module
         self.train_loop.reset_train_val_dataloaders(model)

         # hook
@@ -772,8 +772,6 @@ class Trainer(
         return eval_loop_results

     def run_predict(self):
-        self.predict_loop.on_predict_start()
-
         # prepare dataloaders
         dataloaders, max_batches = self.predict_loop.get_predict_dataloaders()
@@ -789,6 +787,9 @@ class Trainer(
         model.zero_grad()
         torch.set_grad_enabled(False)

+        # call hook
+        self.predict_loop.on_predict_start()
+
         # set up the eval loop
         self.predict_loop.setup(model, max_batches, dataloaders)
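
Note that the two switches in this block are independent: `torch.set_grad_enabled(False)` only stops autograd from recording, while the eval-mode flip (performed by `on_predict_start` through the hook above) is what disables dropout and freezes batch-norm statistics; prediction needs both. A quick plain-PyTorch illustration:

    import torch
    from torch import nn

    model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))

    model.eval()                   # disables dropout, freezes batch-norm stats
    torch.set_grad_enabled(False)  # stops autograd from building a graph

    out = model(torch.randn(2, 4))
    assert not model.training      # flipped by eval()
    assert not out.requires_grad   # no graph was recorded

    torch.set_grad_enabled(True)   # restore the default afterwards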

tests/trainer/test_trainer.py

@@ -1438,7 +1438,9 @@ def test_trainer_setup_call(tmpdir, stage):
 )
 @patch("pytorch_lightning.loggers.tensorboard.TensorBoardLogger.log_metrics")
 def test_log_every_n_steps(log_metrics_mock, tmpdir, train_batches, max_steps, log_interval):
+
     class TestModel(BoringModel):
+
         def training_step(self, *args, **kwargs):
             self.log("foo", -1)
             return super().training_step(*args, **kwargs)
@@ -1888,3 +1890,33 @@ def test_exception_when_testing_or_validating_with_fast_dev_run(tmpdir):
         trainer.validate()
     with pytest.raises(MisconfigurationException, match=r"\.test\(\)` with `fast_dev_run=True"):
         trainer.test()
+
+
+class TrainerStagesModel(BoringModel):
+
+    def on_train_start(self) -> None:
+        assert self.trainer.model.training
+        assert self.training
+
+    def on_validation_start(self) -> None:
+        assert not self.trainer.model.training
+        assert not self.training
+
+    def on_test_start(self) -> None:
+        assert not self.trainer.model.training
+        assert not self.training
+
+    def on_predict_start(self) -> None:
+        assert not self.trainer.model.training
+        assert not self.training
+
+
+@pytest.mark.parametrize(['accelerator', 'num_processes'],
+                         [(None, 1), pytest.param('ddp', 2, marks=RunIf(skip_windows=True))])
+def test_model_in_correct_mode_during_stages(tmpdir, accelerator, num_processes):
+    model = TrainerStagesModel()
+    trainer = Trainer(default_root_dir=tmpdir, accelerator=accelerator, num_processes=num_processes, fast_dev_run=True)
+    trainer.fit(model)
+    trainer.validate(model)
+    trainer.test(model)
+    trainer.predict(model, model.val_dataloader())
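
The same invariant can be checked outside the test suite. A self-contained sketch against the public Lightning 1.x API of this era (the module and dataloader below are illustrative stand-ins, not from the commit):

    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset
    from pytorch_lightning import LightningModule, Trainer

    class ModeCheckModel(LightningModule):

        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(4, 2)

        def training_step(self, batch, batch_idx):
            # training must run with wrapper and inner flags both set
            assert self.trainer.model.training and self.training
            return self.layer(batch[0]).sum()

        def validation_step(self, batch, batch_idx):
            # validation must run with the (possibly wrapped) model in eval mode
            assert not self.trainer.model.training and not self.training

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)

    loader = DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=4)
    trainer = Trainer(fast_dev_run=True)
    trainer.fit(ModeCheckModel(), loader, loader)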