ref: moved hooks around in eval loop (#3195)

* moved hooks around in eval loop

* moved hooks around in eval loop

* moved hooks around in eval loop

* moved hooks around in eval loop
This commit is contained in:
William Falcon 2020-08-26 08:45:15 -04:00 committed by GitHub
parent 888340d17e
commit bd35c869ee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 28 additions and 34 deletions

View File

@@ -249,13 +249,7 @@ class EvaluationLoop(object):
# track debug metrics
self.trainer.dev_debugger.track_eval_loss_history(self.testing, batch_idx, dataloader_idx, output)
def on_evaluation_epoch_end(self, eval_results, *args, **kwargs):
# log epoch level metrics
self.log_epoch_metrics(eval_results)
# Write predictions to disk if they're available
self.predictions.to_disk()
def on_evaluation_epoch_end(self, *args, **kwargs):
# call the callback hook
if self.testing:
self.trainer.call_hook('on_test_epoch_end', *args, **kwargs)

View File

@@ -298,8 +298,12 @@ class TrainerEvaluationLoopMixin(ABC):
# lightning module method
eval_results = self.evaluation_loop.evaluation_epoch_end(num_dataloaders=len(dataloaders))
# log epoch level metrics
self.evaluation_loop.log_epoch_metrics(eval_results)
self.evaluation_loop.predictions.to_disk()
# hook
self.evaluation_loop.on_evaluation_epoch_end(eval_results)
self.evaluation_loop.on_evaluation_epoch_end()
# enable train mode again
model.train()
@@ -308,44 +312,35 @@ class TrainerEvaluationLoopMixin(ABC):
return eval_results
def run_evaluation(self, test_mode: bool = False):
# set up the loop for val/test
# bookkeeping
self.evaluation_loop.testing = test_mode
# TODO: deprecate
model = self.get_model()
# select dataloaders
dataloaders, max_batches = self.evaluation_loop.get_evaluation_dataloaders()
# enable disabling validation step with limit_val_batches = 0
if self.evaluation_loop.should_skip_evaluation(dataloaders, max_batches):
return [], []
# enable eval mode + no grads
model = self.get_model()
model.zero_grad()
model.eval()
torch.set_grad_enabled(False)
# hook
self.evaluation_loop.on_evaluation_start()
# ------------------------------
# ------------------------------
# ------------------------------
# enable eval mode + no grads
model.zero_grad()
model.eval()
torch.set_grad_enabled(False)
# set up the eval loop
self.evaluation_loop.setup(model, max_batches, dataloaders)
# hook
self.evaluation_loop.on_evaluation_epoch_start()
# run validation/testing
for dataloader_idx, dataloader in enumerate(dataloaders):
# hook
self.evaluation_loop.on_evaluation_epoch_start()
# bookkeeping
dl_outputs = []
# certain accelerators need to process the dataloader
dataloader = self.accelerator_backend.process_dataloader(dataloader)
# each dataloader has a max num batches
dl_max_batches = self.evaluation_loop.max_batches[dataloader_idx]
for batch_idx, batch in enumerate(dataloader):
@@ -379,12 +374,13 @@ class TrainerEvaluationLoopMixin(ABC):
# lightning module method
eval_results = self.evaluation_loop.evaluation_epoch_end(num_dataloaders=len(dataloaders))
# hook
self.evaluation_loop.on_evaluation_epoch_end(eval_results)
# bookkeeping
self.evaluation_loop.log_epoch_metrics(eval_results)
self.evaluation_loop.predictions.to_disk()
# hook
self.evaluation_loop.on_evaluation_epoch_end()
# enable train mode again
model.train()
torch.set_grad_enabled(True)
# ------------------------------
# ------------------------------
# ------------------------------
@@ -396,7 +392,11 @@ class TrainerEvaluationLoopMixin(ABC):
if self.reload_dataloaders_every_epoch:
self.evaluation_loop.reload_evaluation_dataloaders()
# TODO: deprecate
# enable train mode again
model.train()
torch.set_grad_enabled(True)
# hook
self.evaluation_loop.on_evaluation_end()
return eval_loop_results, eval_results