ref: moved hooks around in eval loop (#3195)
* moved hooks around in eval loop * moved hooks around in eval loop * moved hooks around in eval loop * moved hooks around in eval loop
This commit is contained in:
parent
888340d17e
commit
bd35c869ee
|
@ -249,13 +249,7 @@ class EvaluationLoop(object):
|
|||
# track debug metrics
|
||||
self.trainer.dev_debugger.track_eval_loss_history(self.testing, batch_idx, dataloader_idx, output)
|
||||
|
||||
def on_evaluation_epoch_end(self, eval_results, *args, **kwargs):
|
||||
# log epoch level metrics
|
||||
self.log_epoch_metrics(eval_results)
|
||||
|
||||
# Write predictions to disk if they're available
|
||||
self.predictions.to_disk()
|
||||
|
||||
def on_evaluation_epoch_end(self, *args, **kwargs):
|
||||
# call the callback hook
|
||||
if self.testing:
|
||||
self.trainer.call_hook('on_test_epoch_end', *args, **kwargs)
|
||||
|
|
|
@ -298,8 +298,12 @@ class TrainerEvaluationLoopMixin(ABC):
|
|||
# lightning module method
|
||||
eval_results = self.evaluation_loop.evaluation_epoch_end(num_dataloaders=len(dataloaders))
|
||||
|
||||
# log epoch level metrics
|
||||
self.evaluation_loop.log_epoch_metrics(eval_results)
|
||||
self.evaluation_loop.predictions.to_disk()
|
||||
|
||||
# hook
|
||||
self.evaluation_loop.on_evaluation_epoch_end(eval_results)
|
||||
self.evaluation_loop.on_evaluation_epoch_end()
|
||||
|
||||
# enable train mode again
|
||||
model.train()
|
||||
|
@ -308,44 +312,35 @@ class TrainerEvaluationLoopMixin(ABC):
|
|||
return eval_results
|
||||
|
||||
def run_evaluation(self, test_mode: bool = False):
|
||||
# set up the loop for val/test
|
||||
# bookkeeping
|
||||
self.evaluation_loop.testing = test_mode
|
||||
|
||||
# TODO: deprecate
|
||||
model = self.get_model()
|
||||
|
||||
# select dataloaders
|
||||
dataloaders, max_batches = self.evaluation_loop.get_evaluation_dataloaders()
|
||||
|
||||
# enable disabling validation step with limit_val_batches = 0
|
||||
if self.evaluation_loop.should_skip_evaluation(dataloaders, max_batches):
|
||||
return [], []
|
||||
|
||||
# enable eval mode + no grads
|
||||
model = self.get_model()
|
||||
model.zero_grad()
|
||||
model.eval()
|
||||
torch.set_grad_enabled(False)
|
||||
|
||||
# hook
|
||||
self.evaluation_loop.on_evaluation_start()
|
||||
|
||||
# ------------------------------
|
||||
# ------------------------------
|
||||
# ------------------------------
|
||||
# enable eval mode + no grads
|
||||
model.zero_grad()
|
||||
model.eval()
|
||||
torch.set_grad_enabled(False)
|
||||
|
||||
# set up the eval loop
|
||||
self.evaluation_loop.setup(model, max_batches, dataloaders)
|
||||
|
||||
# hook
|
||||
self.evaluation_loop.on_evaluation_epoch_start()
|
||||
|
||||
# run validation/testing
|
||||
for dataloader_idx, dataloader in enumerate(dataloaders):
|
||||
# hook
|
||||
self.evaluation_loop.on_evaluation_epoch_start()
|
||||
|
||||
# bookkeeping
|
||||
dl_outputs = []
|
||||
|
||||
# certain accelerators need to process the dataloader
|
||||
dataloader = self.accelerator_backend.process_dataloader(dataloader)
|
||||
|
||||
# each dataloader has a max num batches
|
||||
dl_max_batches = self.evaluation_loop.max_batches[dataloader_idx]
|
||||
|
||||
for batch_idx, batch in enumerate(dataloader):
|
||||
|
@ -379,12 +374,13 @@ class TrainerEvaluationLoopMixin(ABC):
|
|||
# lightning module method
|
||||
eval_results = self.evaluation_loop.evaluation_epoch_end(num_dataloaders=len(dataloaders))
|
||||
|
||||
# hook
|
||||
self.evaluation_loop.on_evaluation_epoch_end(eval_results)
|
||||
# bookkeeping
|
||||
self.evaluation_loop.log_epoch_metrics(eval_results)
|
||||
self.evaluation_loop.predictions.to_disk()
|
||||
|
||||
# hook
|
||||
self.evaluation_loop.on_evaluation_epoch_end()
|
||||
|
||||
# enable train mode again
|
||||
model.train()
|
||||
torch.set_grad_enabled(True)
|
||||
# ------------------------------
|
||||
# ------------------------------
|
||||
# ------------------------------
|
||||
|
@ -396,7 +392,11 @@ class TrainerEvaluationLoopMixin(ABC):
|
|||
if self.reload_dataloaders_every_epoch:
|
||||
self.evaluation_loop.reload_evaluation_dataloaders()
|
||||
|
||||
# TODO: deprecate
|
||||
# enable train mode again
|
||||
model.train()
|
||||
torch.set_grad_enabled(True)
|
||||
|
||||
# hook
|
||||
self.evaluation_loop.on_evaluation_end()
|
||||
|
||||
return eval_loop_results, eval_results
|
||||
|
|
Loading…
Reference in New Issue