ref: .fit hook clean up (#3198)

* eval loop clean up

* eval loop clean up

* eval loop clean up

* eval loop clean up
This commit is contained in:
William Falcon 2020-08-26 13:53:23 -04:00 committed by GitHub
parent a1705441a9
commit e6bb26db1b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 17 deletions

View File

@ -245,14 +245,11 @@ class TrainerEvaluationLoopMixin(ABC):
# hook
self.evaluation_loop.on_evaluation_start()
# ------------------------------
# ------------------------------
# ------------------------------
# set up the eval loop
self.evaluation_loop.setup(model, max_batches, dataloaders)
# hook
# TODO: needs to move inside the loop but breaks early stopping
# TODO: should this be insider the dataloader loop?
self.evaluation_loop.on_evaluation_epoch_start()
# run validation/testing
@ -300,10 +297,6 @@ class TrainerEvaluationLoopMixin(ABC):
# hook
self.evaluation_loop.on_evaluation_epoch_end()
# ------------------------------
# ------------------------------
# ------------------------------
# log the final eval loop metrics
eval_loop_results = self.__log_evaluation_epoch_metrics(eval_results, test_mode)

View File

@ -995,10 +995,8 @@ class Trainer(
# check that model is configured correctly
self.config_validator.verify_loop_configurations(model)
# callbacks
self.on_fit_start(model)
if self.is_function_implemented('on_fit_start', model):
model.on_fit_start()
# hook
self.call_hook('on_fit_start', model)
# on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
# or in the case where each node needs to do its own manipulation in which case just local_rank=0
@ -1095,12 +1093,10 @@ class Trainer(
self.accelerator_backend.setup(model)
results = self.accelerator_backend.train(model)
# on fit end callback
self.on_fit_end()
if self.is_function_implemented('on_fit_end'):
model.on_fit_end()
# hook
self.call_hook('on_fit_end')
# teardown callback
# hook
self.teardown('fit')
if self.is_function_implemented('teardown'):
model.teardown('fit')