From a7705c8677b9e2b5105e26a38db9d1b650182576 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Tue, 25 Aug 2020 12:28:00 -0400 Subject: [PATCH] expand eval loop out (#3165) --- pytorch_lightning/trainer/evaluation_loop.py | 66 +++++++++++++++++++- 1 file changed, 64 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index dbb54614ec..f9bc5476d7 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -325,8 +325,70 @@ class TrainerEvaluationLoopMixin(ABC): # TODO: deprecate self.evaluation_loop.on_evaluation_start() - # run evaluation (val_step + val_step_end + val_epoch_end) - eval_results = self._evaluate(self.model, dataloaders, max_batches, test_mode) + # ------------------------------ + # ------------------------------ + # ------------------------------ + # enable eval mode + no grads + model.zero_grad() + model.eval() + torch.set_grad_enabled(False) + + # set up the eval loop + self.evaluation_loop.setup(model, max_batches, dataloaders) + + # hook + self.evaluation_loop.on_evaluation_epoch_start() + + # run validation/testing + for dataloader_idx, dataloader in enumerate(dataloaders): + dl_outputs = [] + + # certain accelerators need to process the dataloader + dataloader = self.accelerator_backend.process_dataloader(dataloader) + + # each dataloader has a max num batches + dl_max_batches = self.evaluation_loop.max_batches[dataloader_idx] + + for batch_idx, batch in enumerate(dataloader): + if batch is None: + continue + + # stop short when running on limited batches + if batch_idx >= dl_max_batches: + break + + # hook + self.evaluation_loop.on_evaluation_batch_start(batch, batch_idx, dataloader_idx) + + # lightning module methods + output = self.evaluation_loop.evaluation_step(test_mode, batch, batch_idx, dataloader_idx) + output = self.evaluation_loop.evaluation_step_end(output) + + # hook + self.evaluation_loop.on_evaluation_batch_end(batch, batch_idx, dataloader_idx) + + # clean up + self.evaluation_loop.evaluation_batch_end_cleanup(output, batch_idx, dataloader_idx) + self.evaluation_loop.log_step_metrics(output, batch_idx) + + # track epoch level metrics + if output is not None: + dl_outputs.append(output) + + self.evaluation_loop.outputs.append(dl_outputs) + + # lightning module method + eval_results = self.evaluation_loop.evaluation_epoch_end(num_dataloaders=len(dataloaders)) + + # hook + self.evaluation_loop.on_evaluation_epoch_end(eval_results) + + # enable train mode again + model.train() + torch.set_grad_enabled(True) + # ------------------------------ + # ------------------------------ + # ------------------------------ # log the final eval loop metrics eval_loop_results = self.__log_evaluation_epoch_metrics(eval_results, test_mode)