ref: clean up hooks in run_evaluation (#3156)

* clean up hooks in run_evaluation
William Falcon 2020-08-25 10:56:32 -04:00 committed by GitHub
parent 22b9642117
commit 50aed42d6b
2 changed files with 61 additions and 27 deletions

@@ -13,6 +13,54 @@ class EvaluationLoop(object):
         self.predictions = None
         self.max_batches = None

+    def get_evaluation_dataloaders(self):
+        # select dataloaders
+        model = self.trainer.get_model()
+
+        if self.testing:
+            self.trainer.reset_test_dataloader(model)
+            dataloaders = self.trainer.test_dataloaders
+            max_batches = self.trainer.num_test_batches
+        else:
+            if self.trainer.val_dataloaders is None:
+                self.trainer.reset_val_dataloader(model)
+            dataloaders = self.trainer.val_dataloaders
+            max_batches = self.trainer.num_val_batches
+
+        return dataloaders, max_batches
+
+    def should_skip_evaluation(self, dataloaders, max_batches):
+        # skip when dataloaders aren't defined
+        if dataloaders is None:
+            return True
+
+        # enable disabling validation step with limit_val_batches = 0
+        should_skip = sum(max_batches) == 0
+        if should_skip:
+            return True
+
+        return False
+
+    def on_evaluation_start(self, *args, **kwargs):
+        if self.testing:
+            self.trainer.call_hook('on_test_start', *args, **kwargs)
+        else:
+            self.trainer.call_hook('on_validation_start', *args, **kwargs)
+
+    def on_evaluation_end(self, *args, **kwargs):
+        if self.testing:
+            self.trainer.call_hook('on_test_end', *args, **kwargs)
+        else:
+            self.trainer.call_hook('on_validation_end', *args, **kwargs)
+
+    def reload_evaluation_dataloaders(self):
+        model = self.trainer.get_model()
+        if self.testing:
+            self.trainer.reset_test_dataloader(model)
+        else:
+            self.trainer.reset_val_dataloader(model)
+
+    def is_using_eval_results(self):
+        outputs = self.outputs
+        using_eval_result = len(outputs) > 0 and len(outputs[0]) > 0 and isinstance(outputs[0][0], EvalResult)
+        return using_eval_result
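
For reference, a minimal standalone sketch (not part of this commit) of the skip rule the new should_skip_evaluation helper encodes: missing dataloaders and limit_val_batches=0 (which leaves zero batches in every dataloader, so sum(max_batches) == 0) both cause evaluation to be skipped. The placeholder dataloader values below are illustrative only.

# Standalone illustration; mirrors the logic of
# EvaluationLoop.should_skip_evaluation shown above.
def should_skip_evaluation(dataloaders, max_batches):
    # no dataloaders were ever attached to the trainer
    if dataloaders is None:
        return True
    # limit_val_batches=0 translates into all-zero max_batches
    return sum(max_batches) == 0

assert should_skip_evaluation(None, []) is True
assert should_skip_evaluation(['loader'], [0, 0]) is True   # validation disabled
assert should_skip_evaluation(['loader'], [5, 2]) is False  # batches to evaluate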

@@ -245,8 +245,6 @@ class TrainerEvaluationLoopMixin(ABC):
                 entry is the number of batches to process in the corresponding dataloader.
             test_mode:
         """
-        # set up the loop for val/test
-        self.evaluation_loop.testing = test_mode
-
         # enable eval mode + no grads
         model.zero_grad()
@@ -310,7 +308,10 @@ class TrainerEvaluationLoopMixin(ABC):
         return eval_results

     def run_evaluation(self, test_mode: bool = False):
-        # hook
+        # set up the loop for val/test
+        self.evaluation_loop.testing = test_mode
+
+        # TODO: deprecate
         model = self.get_model()
         model.on_pre_performance_check()
@@ -328,20 +329,13 @@ class TrainerEvaluationLoopMixin(ABC):
             dataloaders = self.val_dataloaders
             max_batches = self.num_val_batches

-        if dataloaders is None:
-            return [], []
-
-        # Validation/Test begin callbacks
-        if test_mode:
-            self.on_test_start()
-        else:
-            self.on_validation_start()
-
-        # enable disabling validation step with limit_val_batches = 0
-        should_skip = sum(max_batches) == 0
-        if should_skip:
+        if self.evaluation_loop.should_skip_evaluation(dataloaders, max_batches):
             return [], []

+        # TODO: deprecate
+        self.evaluation_loop.on_evaluation_start()
+
         # run evaluation (val_step + val_step_end + val_epoch_end)
         eval_results = self._evaluate(self.model, dataloaders, max_batches, test_mode)
@@ -351,20 +345,12 @@ class TrainerEvaluationLoopMixin(ABC):
         # hook
         model.on_post_performance_check()

-        # eventual dataset reloading
-        if test_mode:
-            if self.reload_dataloaders_every_epoch:
-                self.reset_test_dataloader(model)
-        else:
-            # val
-            if self.reload_dataloaders_every_epoch:
-                self.reset_val_dataloader(model)
+        # user may want to reload every epoch
+        if self.reload_dataloaders_every_epoch:
+            self.evaluation_loop.reload_evaluation_dataloaders()

-        # Validation/Test end callbacks
-        if test_mode:
-            self.on_test_end()
-        else:
-            self.on_validation_end()
+        # TODO: deprecate
+        self.evaluation_loop.on_evaluation_end()

         return eval_loop_results, eval_results
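
For illustration only (none of this code is in the commit), a self-contained sketch of the dispatch pattern the new on_evaluation_start/on_evaluation_end entry points rely on: a single loop-level hook picks the test or validation hook name from the testing flag, instead of run_evaluation branching on test_mode at every call site. The _FakeTrainer and _EvalLoopSketch names are hypothetical stand-ins used only to make the example run.

# Hypothetical stand-in that records which hooks get called.
class _FakeTrainer:
    def __init__(self):
        self.calls = []

    def call_hook(self, name, *args, **kwargs):
        self.calls.append(name)


# Minimal sketch of the dispatch done by EvaluationLoop.on_evaluation_start.
class _EvalLoopSketch:
    def __init__(self, trainer, testing=False):
        self.trainer = trainer
        self.testing = testing

    def on_evaluation_start(self, *args, **kwargs):
        hook = 'on_test_start' if self.testing else 'on_validation_start'
        self.trainer.call_hook(hook, *args, **kwargs)


trainer = _FakeTrainer()
_EvalLoopSketch(trainer, testing=True).on_evaluation_start()
_EvalLoopSketch(trainer, testing=False).on_evaluation_start()
assert trainer.calls == ['on_test_start', 'on_validation_start']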