[Fix] Ensure we set the eval/train flag correctly on accelerator model (#6877)
* Ensure we move the model to eval mode before running evaluation
* Ensure we set the flag appropriately across all stages
* Add test, move hooks logic
* Apply same fix to the validate loop
* Update pytorch_lightning/trainer/trainer.py
* Fix function name
* Fix order, add predict
* Shorten the name
* Fix input dm, drop duplicate on predict start hook call, as it's called in the setup function
* Use hook, remove double call
This commit is contained in:
parent 851fd7fae7
commit 742c48e994
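Why the change from self.eval() to self.trainer.model.eval(): torch's Module.train()/.eval() only recurse downward into child modules, so calling eval() on the LightningModule leaves any accelerator wrapper around it (e.g. the DistributedDataParallel instance held by trainer.model) in training mode. A minimal sketch of that behaviour, with a plain nn.Sequential standing in for the accelerator wrapper (illustrative, not Lightning code):

import torch.nn as nn

inner = nn.Linear(4, 2)          # stand-in for the LightningModule
wrapper = nn.Sequential(inner)   # stand-in for the wrapped trainer.model

inner.eval()                     # old behaviour: flips only the inner module
assert not inner.training
assert wrapper.training          # the wrapper is still in training mode

wrapper.eval()                   # fixed behaviour: .eval() recurses into children,
assert not wrapper.training      # so both flags end up False
assert not inner.training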
@@ -114,13 +114,13 @@ class ModelHooks:
         """
         Sets the model to eval during the val loop
         """
-        self.eval()
+        self.trainer.model.eval()
 
     def on_validation_model_train(self) -> None:
         """
         Sets the model to train during the val loop
         """
-        self.train()
+        self.trainer.model.train()
 
     def on_validation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
         """
@@ -172,19 +172,19 @@ class ModelHooks:
         """
         Sets the model to train during the test loop
         """
-        self.train()
+        self.trainer.model.train()
 
     def on_test_model_eval(self) -> None:
         """
         Sets the model to eval during the test loop
         """
-        self.eval()
+        self.trainer.model.eval()
 
     def on_predict_model_eval(self) -> None:
         """
         Sets the model to eval during the predict loop
         """
-        self.eval()
+        self.trainer.model.eval()
 
     def on_epoch_start(self) -> None:
         """
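These hooks exist precisely so the default mode switching can be customised per LightningModule. A hypothetical override (illustrative only, not part of this diff) that keeps dropout active during validation, e.g. for Monte Carlo dropout estimates:

import pytorch_lightning as pl

class MCDropoutModel(pl.LightningModule):  # hypothetical example class

    def on_validation_model_eval(self) -> None:
        # deliberately keep the whole model (wrapper included) in train
        # mode so dropout stays active during the val loop
        self.trainer.model.train()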
@@ -44,7 +44,6 @@ class PredictLoop(object):
         model_ref.on_predict_model_eval()
 
     def setup(self, model, max_batches, dataloaders):
-        self.trainer.call_hook("on_predict_start")
 
         # copy properties for forward overrides
         self.trainer.model_connector.copy_trainer_model_properties(model)
@@ -582,11 +582,11 @@ class Trainer(
         self.checkpoint_connector.has_trained = False
 
         # enable train mode
-        model = self.lightning_module
-        model.train()
+        self.model.train()
         torch.set_grad_enabled(True)
 
         # reload data when needed
+        model = self.lightning_module
         self.train_loop.reset_train_val_dataloaders(model)
 
         # hook
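Note that self.model here is the accelerator-wrapped module, while self.lightning_module is the bare LightningModule; calling .train() on the wrapper flips the training flag on it and, recursively, on the LightningModule inside, which is the point of the fix.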
@@ -772,8 +772,6 @@ class Trainer(
         return eval_loop_results
 
     def run_predict(self):
-        self.predict_loop.on_predict_start()
-
         # prepare dataloaders
         dataloaders, max_batches = self.predict_loop.get_predict_dataloaders()
 
@@ -789,6 +787,9 @@ class Trainer(
         model.zero_grad()
         torch.set_grad_enabled(False)
 
+        # call hook
+        self.predict_loop.on_predict_start()
+
         # set up the eval loop
         self.predict_loop.setup(model, max_batches, dataloaders)
 
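Taken together, the two trainer.py hunks above move the on_predict_start call out of the top of run_predict and re-issue it only after gradients are disabled, so the resulting flow is roughly (a simplified sketch inferred from the hunks, not a verbatim excerpt):

# inside Trainer.run_predict, after this commit (sketch)
dataloaders, max_batches = self.predict_loop.get_predict_dataloaders()
model.zero_grad()
torch.set_grad_enabled(False)
self.predict_loop.on_predict_start()  # fires on_predict_model_eval -> trainer.model.eval()
self.predict_loop.setup(model, max_batches, dataloaders)  # no longer re-fires the hook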
@@ -1438,7 +1438,9 @@ def test_trainer_setup_call(tmpdir, stage):
 )
 @patch("pytorch_lightning.loggers.tensorboard.TensorBoardLogger.log_metrics")
 def test_log_every_n_steps(log_metrics_mock, tmpdir, train_batches, max_steps, log_interval):
+
+    class TestModel(BoringModel):
+
+        def training_step(self, *args, **kwargs):
+            self.log("foo", -1)
+            return super().training_step(*args, **kwargs)
 
@@ -1888,3 +1890,33 @@ def test_exception_when_testing_or_validating_with_fast_dev_run(tmpdir):
         trainer.validate()
     with pytest.raises(MisconfigurationException, match=r"\.test\(\)` with `fast_dev_run=True"):
         trainer.test()
+
+
+class TrainerStagesModel(BoringModel):
+
+    def on_train_start(self) -> None:
+        assert self.trainer.model.training
+        assert self.training
+
+    def on_validation_start(self) -> None:
+        assert not self.trainer.model.training
+        assert not self.training
+
+    def on_test_start(self) -> None:
+        assert not self.trainer.model.training
+        assert not self.training
+
+    def on_predict_start(self) -> None:
+        assert not self.trainer.model.training
+        assert not self.training
+
+
+@pytest.mark.parametrize(['accelerator', 'num_processes'],
+                         [(None, 1), pytest.param('ddp', 2, marks=RunIf(skip_windows=True))])
+def test_model_in_correct_mode_during_stages(tmpdir, accelerator, num_processes):
+    model = TrainerStagesModel()
+    trainer = Trainer(default_root_dir=tmpdir, accelerator=accelerator, num_processes=num_processes, fast_dev_run=True)
+    trainer.fit(model)
+    trainer.validate(model)
+    trainer.test(model)
+    trainer.predict(model, model.val_dataloader())
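Assuming the standard Lightning test layout (tests/trainer/test_trainer.py), the new stage-mode test can be run on its own with something like:

pytest tests/trainer/test_trainer.py -k test_model_in_correct_mode_during_stages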