diff --git a/pytorch_lightning/trainer/evaluate_loop.py b/pytorch_lightning/trainer/evaluate_loop.py index 4ce1bed0ba..d197ff8e54 100644 --- a/pytorch_lightning/trainer/evaluate_loop.py +++ b/pytorch_lightning/trainer/evaluate_loop.py @@ -13,6 +13,54 @@ class EvaluationLoop(object): self.predictions = None self.max_batches = None + def get_evaluation_dataloaders(self): + # select dataloaders + model = self.trainer.get_model() + + if self.testing: + self.trainer.reset_test_dataloader(model) + dataloaders = self.trainer.test_dataloaders + max_batches = self.trainer.num_test_batches + else: + if self.trainer.val_dataloaders is None: + self.trainer.reset_val_dataloader(model) + + dataloaders = self.trainer.val_dataloaders + max_batches = self.trainer.num_val_batches + + return dataloaders, max_batches + + def should_skip_evaluation(self, dataloaders, max_batches): + # skip when dataloaders aren't defined + if dataloaders is None: + return True + + # enable disabling validation step with limit_val_batches = 0 + should_skip = sum(max_batches) == 0 + if should_skip: + return True + + return False + + def on_evaluation_start(self, *args, **kwargs): + if self.testing: + self.trainer.call_hook('on_test_start', *args, **kwargs) + else: + self.trainer.call_hook('on_validation_start', *args, **kwargs) + + def on_evaluation_end(self, *args, **kwargs): + if self.testing: + self.trainer.call_hook('on_test_end', *args, **kwargs) + else: + self.trainer.call_hook('on_validation_end', *args, **kwargs) + + def reload_evaluation_dataloaders(self): + model = self.trainer.get_model() + if self.testing: + self.trainer.reset_test_dataloader(model) + else: + self.trainer.reset_val_dataloader(model) + def is_using_eval_results(self): outputs = self.outputs using_eval_result = len(outputs) > 0 and len(outputs[0]) > 0 and isinstance(outputs[0][0], EvalResult) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 2bd1cd9c41..cd136465d6 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -245,8 +245,6 @@ class TrainerEvaluationLoopMixin(ABC): entry is the number of batches to process in the corresponding dataloader. test_mode: """ - # set up the loop for val/test - self.evaluation_loop.testing = test_mode # enable eval mode + no grads model.zero_grad() @@ -310,7 +308,10 @@ class TrainerEvaluationLoopMixin(ABC): return eval_results def run_evaluation(self, test_mode: bool = False): - # hook + # set up the loop for val/test + self.evaluation_loop.testing = test_mode + + # TODO: deprecate model = self.get_model() model.on_pre_performance_check() @@ -328,20 +329,13 @@ class TrainerEvaluationLoopMixin(ABC): dataloaders = self.val_dataloaders max_batches = self.num_val_batches - if dataloaders is None: - return [], [] - - # Validation/Test begin callbacks - if test_mode: - self.on_test_start() - else: - self.on_validation_start() - # enable disabling validation step with limit_val_batches = 0 - should_skip = sum(max_batches) == 0 - if should_skip: + if self.evaluation_loop.should_skip_evaluation(dataloaders, max_batches): return [], [] + # TODO: deprecate + self.evaluation_loop.on_evaluation_start() + # run evaluation (val_step + val_step_end + val_epoch_end) eval_results = self._evaluate(self.model, dataloaders, max_batches, test_mode) @@ -351,20 +345,12 @@ class TrainerEvaluationLoopMixin(ABC): # hook model.on_post_performance_check() - # eventual dataset reloading - if test_mode: - if self.reload_dataloaders_every_epoch: - self.reset_test_dataloader(model) - else: - # val - if self.reload_dataloaders_every_epoch: - self.reset_val_dataloader(model) + # user may want to reload every epoch + if self.reload_dataloaders_every_epoch: + self.evaluation_loop.reload_evaluation_dataloaders() - # Validation/Test end callbacks - if test_mode: - self.on_test_end() - else: - self.on_validation_end() + # TODO: deprecate + self.evaluation_loop.on_evaluation_end() return eval_loop_results, eval_results