diff --git a/pytorch_lightning/trainer/evaluate_loop.py b/pytorch_lightning/trainer/evaluate_loop.py index 4ed55b339a..83c95213b4 100644 --- a/pytorch_lightning/trainer/evaluate_loop.py +++ b/pytorch_lightning/trainer/evaluate_loop.py @@ -249,13 +249,7 @@ class EvaluationLoop(object): # track debug metrics self.trainer.dev_debugger.track_eval_loss_history(self.testing, batch_idx, dataloader_idx, output) - def on_evaluation_epoch_end(self, eval_results, *args, **kwargs): - # log epoch level metrics - self.log_epoch_metrics(eval_results) - - # Write predictions to disk if they're available - self.predictions.to_disk() - + def on_evaluation_epoch_end(self, *args, **kwargs): # call the callback hook if self.testing: self.trainer.call_hook('on_test_epoch_end', *args, **kwargs) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index cc42e72e12..efe5c549db 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -298,8 +298,12 @@ class TrainerEvaluationLoopMixin(ABC): # lightning module method eval_results = self.evaluation_loop.evaluation_epoch_end(num_dataloaders=len(dataloaders)) + # log epoch level metrics + self.evaluation_loop.log_epoch_metrics(eval_results) + self.evaluation_loop.predictions.to_disk() + # hook - self.evaluation_loop.on_evaluation_epoch_end(eval_results) + self.evaluation_loop.on_evaluation_epoch_end() # enable train mode again model.train() @@ -308,44 +312,35 @@ class TrainerEvaluationLoopMixin(ABC): return eval_results def run_evaluation(self, test_mode: bool = False): - # set up the loop for val/test + # bookkeeping self.evaluation_loop.testing = test_mode - - # TODO: deprecate - model = self.get_model() - - # select dataloaders dataloaders, max_batches = self.evaluation_loop.get_evaluation_dataloaders() - - # enable disabling validation step with limit_val_batches = 0 if self.evaluation_loop.should_skip_evaluation(dataloaders, max_batches): return [], [] + # enable eval mode + no grads + model = self.get_model() + model.zero_grad() + model.eval() + torch.set_grad_enabled(False) + # hook self.evaluation_loop.on_evaluation_start() # ------------------------------ # ------------------------------ # ------------------------------ - # enable eval mode + no grads - model.zero_grad() - model.eval() - torch.set_grad_enabled(False) - # set up the eval loop self.evaluation_loop.setup(model, max_batches, dataloaders) - # hook - self.evaluation_loop.on_evaluation_epoch_start() - # run validation/testing for dataloader_idx, dataloader in enumerate(dataloaders): + # hook + self.evaluation_loop.on_evaluation_epoch_start() + + # bookkeeping dl_outputs = [] - - # certain accelerators need to process the dataloader dataloader = self.accelerator_backend.process_dataloader(dataloader) - - # each dataloader has a max num batches dl_max_batches = self.evaluation_loop.max_batches[dataloader_idx] for batch_idx, batch in enumerate(dataloader): @@ -379,12 +374,13 @@ class TrainerEvaluationLoopMixin(ABC): # lightning module method eval_results = self.evaluation_loop.evaluation_epoch_end(num_dataloaders=len(dataloaders)) - # hook - self.evaluation_loop.on_evaluation_epoch_end(eval_results) + # bookkeeping + self.evaluation_loop.log_epoch_metrics(eval_results) + self.evaluation_loop.predictions.to_disk() + + # hook + self.evaluation_loop.on_evaluation_epoch_end() - # enable train mode again - model.train() - torch.set_grad_enabled(True) # ------------------------------ # ------------------------------ # ------------------------------ @@ -396,7 +392,11 @@ class TrainerEvaluationLoopMixin(ABC): if self.reload_dataloaders_every_epoch: self.evaluation_loop.reload_evaluation_dataloaders() - # TODO: deprecate + # enable train mode again + model.train() + torch.set_grad_enabled(True) + + # hook self.evaluation_loop.on_evaluation_end() return eval_loop_results, eval_results