From e6bb26db1b917cae9d9d9c5af523b32a0e7bfa21 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Wed, 26 Aug 2020 13:53:23 -0400
Subject: [PATCH] ref: .fit hook clean up (#3198)

* eval loop clean up

* eval loop clean up

* eval loop clean up

* eval loop clean up
---
 pytorch_lightning/trainer/evaluation_loop.py |  9 +--------
 pytorch_lightning/trainer/trainer.py         | 14 +++++---------
 2 files changed, 6 insertions(+), 17 deletions(-)

diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py
index c757f2c771..8b9d498717 100644
--- a/pytorch_lightning/trainer/evaluation_loop.py
+++ b/pytorch_lightning/trainer/evaluation_loop.py
@@ -245,14 +245,11 @@ class TrainerEvaluationLoopMixin(ABC):
         # hook
         self.evaluation_loop.on_evaluation_start()

-        # ------------------------------
-        # ------------------------------
-        # ------------------------------
         # set up the eval loop
         self.evaluation_loop.setup(model, max_batches, dataloaders)

         # hook
-        # TODO: needs to move inside the loop but breaks early stopping
+        # TODO: should this be inside the dataloader loop?
         self.evaluation_loop.on_evaluation_epoch_start()

         # run validation/testing
@@ -300,10 +297,6 @@ class TrainerEvaluationLoopMixin(ABC):
         # hook
         self.evaluation_loop.on_evaluation_epoch_end()

-        # ------------------------------
-        # ------------------------------
-        # ------------------------------
-
         # log the final eval loop metrics
         eval_loop_results = self.__log_evaluation_epoch_metrics(eval_results, test_mode)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 90ab30e065..335521a179 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -995,10 +995,8 @@ class Trainer(
         # check that model is configured correctly
         self.config_validator.verify_loop_configurations(model)

-        # callbacks
-        self.on_fit_start(model)
-        if self.is_function_implemented('on_fit_start', model):
-            model.on_fit_start()
+        # hook
+        self.call_hook('on_fit_start', model)

         # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
         # or in the case where each node needs to do its own manipulation in which case just local_rank=0
@@ -1095,12 +1093,10 @@ class Trainer(
             self.accelerator_backend.setup(model)
             results = self.accelerator_backend.train(model)

-        # on fit end callback
-        self.on_fit_end()
-        if self.is_function_implemented('on_fit_end'):
-            model.on_fit_end()
+        # hook
+        self.call_hook('on_fit_end')

-        # teardown callback
+        # hook
         self.teardown('fit')
         if self.is_function_implemented('teardown'):
             model.teardown('fit')
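
Note on the trainer.py hunks: the patch replaces the repeated two-step
dispatch (run the Trainer's own hook, then run the LightningModule's hook
if the user implemented it) with a single call_hook entry point. The sketch
below is a minimal, hypothetical illustration of that dispatch pattern; it
is not the actual Trainer.call_hook implementation, which is not part of
this diff, and ToyTrainer / ToyModule are stand-in names.

    # Hypothetical sketch of the call_hook pattern this patch adopts.
    # Assumption: the dispatcher fires the trainer-level hook first,
    # then the matching model-level hook if the user overrode it.

    class ToyModule:
        def on_fit_start(self):
            print("module hook: on_fit_start")


    class ToyTrainer:
        def on_fit_start(self, model):
            print("trainer hook: on_fit_start (callbacks fire here)")

        def call_hook(self, hook_name, *args):
            # 1) run the trainer-level hook if the trainer defines it
            trainer_hook = getattr(self, hook_name, None)
            if callable(trainer_hook):
                trainer_hook(*args)

            # 2) run the model-level hook if the user implemented it
            model = args[0] if args else None
            model_hook = getattr(model, hook_name, None)
            if callable(model_hook):
                model_hook()


    ToyTrainer().call_hook('on_fit_start', ToyModule())
    # prints:
    #   trainer hook: on_fit_start (callbacks fire here)
    #   module hook: on_fit_start

When no model is passed, as with call_hook('on_fit_end') in the patch, the
real dispatcher presumably resolves the LightningModule internally; this
sketch simply skips the model hook in that case.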