From c556ee61198e1c4ef48fdb0f2048143e76620e69 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Aug 2020 15:48:14 -0400 Subject: [PATCH] ref: refactor eval loop to use hooks. use test_mode for if so we can split later (#3129) * moved eval hooks * moved eval hooks * moved eval hooks * moved eval hooks --- pytorch_lightning/trainer/evaluation_loop.py | 70 ++++++-------------- pytorch_lightning/trainer/trainer.py | 18 +++-- 2 files changed, 34 insertions(+), 54 deletions(-) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 9f34eb4edd..f535779769 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -225,27 +225,9 @@ class TrainerEvaluationLoopMixin(ABC): def reset_val_dataloader(self, *args): """Warning: this is just empty shell for code implemented in other class.""" - def __call_eval_loop_hook_start(self, test_mode): - """on_validation/test_epoch_start""" - self.__call_eval_loop_hook_evt(test_mode, 'start') - - def __call_eval_loop_hook_end(self, test_mode): - """on_validation/test_epoch_end""" - self.__call_eval_loop_hook_evt(test_mode, 'end') - - def __call_eval_loop_hook_evt(self, test_mode, epoch_event): - model = self.get_model() - - # on_[train/validation]_epoch_start hook - hook_root_name = 'test' if test_mode else 'validation' - hook_name = f'on_{hook_root_name}_epoch_{epoch_event}' - with self.profiler.profile(hook_name): - # call hook - getattr(self, hook_name)() - - # model hooks - if self.is_function_implemented(hook_name): - getattr(model, hook_name)() + @abstractmethod + def call_hook(self, hook_name, *args, **kwargs): + """Warning: this is just empty shell for code implemented in other class.""" def _evaluate( self, @@ -284,7 +266,10 @@ class TrainerEvaluationLoopMixin(ABC): # -------------------------- # ON_EVAL_EPOCH_START hook # -------------------------- - self.__call_eval_loop_hook_start(test_mode) + if test_mode: + self.call_hook('on_test_epoch_start') + else: + self.call_hook('on_validation_epoch_start') # run validation for dataloader_idx, dataloader in enumerate(dataloaders): @@ -309,17 +294,10 @@ class TrainerEvaluationLoopMixin(ABC): # callbacks if test_mode: - self.on_test_batch_start(batch, batch_idx, dataloader_idx) - if self.is_overridden('on_test_batch_start'): - model_ref = self.get_model() - with self.profiler.profile('on_test_batch_start'): - model_ref.on_test_batch_start(batch, batch_idx, dataloader_idx) + self.call_hook('on_test_batch_start', batch, batch_idx, dataloader_idx) else: - self.on_validation_batch_start(batch, batch_idx, dataloader_idx) - if self.is_overridden('on_validation_batch_start'): - model_ref = self.get_model() - with self.profiler.profile('on_validation_batch_start'): - model_ref.on_validation_batch_start(batch, batch_idx, dataloader_idx) + self.call_hook('on_validation_batch_start', batch, batch_idx, dataloader_idx) + # ----------------- # RUN EVALUATION STEP # ----------------- @@ -353,12 +331,11 @@ class TrainerEvaluationLoopMixin(ABC): # EVAL STEP END # ------------------ # on dp / ddp2 might still want to do something with the batch parts - eval_step_end_hook_name = 'test_step_end' if test_mode else 'validation_step_end' - if self.is_overridden(eval_step_end_hook_name): - model_ref = self.get_model() - with self.profiler.profile(eval_step_end_hook_name): - eval_step_end = getattr(model_ref, eval_step_end_hook_name) - output = eval_step_end(output) + if self.is_overridden('test_step_end') or self.is_overridden('validation_step_end'): + if test_mode: + output = self.call_hook('test_step_end', output) + else: + output = self.call_hook('validation_step_end', output) elif is_result_obj and (self.use_dp or self.use_ddp2): # result auto reduce @@ -366,17 +343,9 @@ class TrainerEvaluationLoopMixin(ABC): # callbacks (on __batch_end) if test_mode: - self.on_test_batch_end(batch, batch_idx, dataloader_idx) - if self.is_overridden('on_test_batch_end'): - model_ref = self.get_model() - with self.profiler.profile('on_test_batch_end'): - model_ref.on_test_batch_end(batch, batch_idx, dataloader_idx) + self.call_hook('on_test_batch_end', batch, batch_idx, dataloader_idx) else: - self.on_validation_batch_end(batch, batch_idx, dataloader_idx) - if self.is_overridden('on_validation_batch_end'): - model_ref = self.get_model() - with self.profiler.profile('on_validation_batch_end'): - model_ref.on_validation_batch_end(batch, batch_idx, dataloader_idx) + self.call_hook('on_validation_batch_end', batch, batch_idx, dataloader_idx) # track outputs for collation if output is not None: @@ -416,7 +385,10 @@ class TrainerEvaluationLoopMixin(ABC): # -------------------------- # ON_EVAL_EPOCH_END hook # -------------------------- - self.__call_eval_loop_hook_end(test_mode) + if test_mode: + self.call_hook('on_test_epoch_end') + else: + self.call_hook('on_validation_epoch_end') return eval_results diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 62e03e52fa..c4996cdb61 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1457,14 +1457,22 @@ class Trainer( self._setup_amp_backend(amp_type) def call_hook(self, hook_name, *args, **kwargs): - output = None - if self.is_overridden(hook_name): - model_ref = self.get_model() - with self.profiler.profile(hook_name): + # always profile hooks + with self.profiler.profile(hook_name): + + # first call trainer hook + if hasattr(self, hook_name): + trainer_hook = getattr(self, hook_name) + trainer_hook(*args, **kwargs) + + # next call hook in lightningModule + output = None + if self.is_overridden(hook_name): + model_ref = self.get_model() hook_fx = getattr(model_ref, hook_name) output = hook_fx(*args, **kwargs) - return output + return output class _PatchDataLoader(object):