ref: clean up hooks in run_evaluation (#3156)
* clean up hooks in run_evaluation * clean up hooks in run_evaluation * clean up hooks in run_evaluation * clean up hooks in run_evaluation * clean up hooks in run_evaluation * clean up hooks in run_evaluation * clean up hooks in run_evaluation
This commit is contained in:
parent
22b9642117
commit
50aed42d6b
|
@ -13,6 +13,54 @@ class EvaluationLoop(object):
|
|||
self.predictions = None
|
||||
self.max_batches = None
|
||||
|
||||
def get_evaluation_dataloaders(self):
|
||||
# select dataloaders
|
||||
model = self.trainer.get_model()
|
||||
|
||||
if self.testing:
|
||||
self.trainer.reset_test_dataloader(model)
|
||||
dataloaders = self.trainer.test_dataloaders
|
||||
max_batches = self.trainer.num_test_batches
|
||||
else:
|
||||
if self.trainer.val_dataloaders is None:
|
||||
self.trainer.reset_val_dataloader(model)
|
||||
|
||||
dataloaders = self.trainer.val_dataloaders
|
||||
max_batches = self.trainer.num_val_batches
|
||||
|
||||
return dataloaders, max_batches
|
||||
|
||||
def should_skip_evaluation(self, dataloaders, max_batches):
|
||||
# skip when dataloaders aren't defined
|
||||
if dataloaders is None:
|
||||
return True
|
||||
|
||||
# enable disabling validation step with limit_val_batches = 0
|
||||
should_skip = sum(max_batches) == 0
|
||||
if should_skip:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def on_evaluation_start(self, *args, **kwargs):
|
||||
if self.testing:
|
||||
self.trainer.call_hook('on_test_start', *args, **kwargs)
|
||||
else:
|
||||
self.trainer.call_hook('on_validation_start', *args, **kwargs)
|
||||
|
||||
def on_evaluation_end(self, *args, **kwargs):
|
||||
if self.testing:
|
||||
self.trainer.call_hook('on_test_end', *args, **kwargs)
|
||||
else:
|
||||
self.trainer.call_hook('on_validation_end', *args, **kwargs)
|
||||
|
||||
def reload_evaluation_dataloaders(self):
|
||||
model = self.trainer.get_model()
|
||||
if self.testing:
|
||||
self.trainer.reset_test_dataloader(model)
|
||||
else:
|
||||
self.trainer.reset_val_dataloader(model)
|
||||
|
||||
def is_using_eval_results(self):
|
||||
outputs = self.outputs
|
||||
using_eval_result = len(outputs) > 0 and len(outputs[0]) > 0 and isinstance(outputs[0][0], EvalResult)
|
||||
|
|
|
@ -245,8 +245,6 @@ class TrainerEvaluationLoopMixin(ABC):
|
|||
entry is the number of batches to process in the corresponding dataloader.
|
||||
test_mode:
|
||||
"""
|
||||
# set up the loop for val/test
|
||||
self.evaluation_loop.testing = test_mode
|
||||
|
||||
# enable eval mode + no grads
|
||||
model.zero_grad()
|
||||
|
@ -310,7 +308,10 @@ class TrainerEvaluationLoopMixin(ABC):
|
|||
return eval_results
|
||||
|
||||
def run_evaluation(self, test_mode: bool = False):
|
||||
# hook
|
||||
# set up the loop for val/test
|
||||
self.evaluation_loop.testing = test_mode
|
||||
|
||||
# TODO: deprecate
|
||||
model = self.get_model()
|
||||
model.on_pre_performance_check()
|
||||
|
||||
|
@ -328,20 +329,13 @@ class TrainerEvaluationLoopMixin(ABC):
|
|||
dataloaders = self.val_dataloaders
|
||||
max_batches = self.num_val_batches
|
||||
|
||||
if dataloaders is None:
|
||||
return [], []
|
||||
|
||||
# Validation/Test begin callbacks
|
||||
if test_mode:
|
||||
self.on_test_start()
|
||||
else:
|
||||
self.on_validation_start()
|
||||
|
||||
# enable disabling validation step with limit_val_batches = 0
|
||||
should_skip = sum(max_batches) == 0
|
||||
if should_skip:
|
||||
if self.evaluation_loop.should_skip_evaluation(dataloaders, max_batches):
|
||||
return [], []
|
||||
|
||||
# TODO: deprecate
|
||||
self.evaluation_loop.on_evaluation_start()
|
||||
|
||||
# run evaluation (val_step + val_step_end + val_epoch_end)
|
||||
eval_results = self._evaluate(self.model, dataloaders, max_batches, test_mode)
|
||||
|
||||
|
@ -351,20 +345,12 @@ class TrainerEvaluationLoopMixin(ABC):
|
|||
# hook
|
||||
model.on_post_performance_check()
|
||||
|
||||
# eventual dataset reloading
|
||||
if test_mode:
|
||||
if self.reload_dataloaders_every_epoch:
|
||||
self.reset_test_dataloader(model)
|
||||
else:
|
||||
# val
|
||||
if self.reload_dataloaders_every_epoch:
|
||||
self.reset_val_dataloader(model)
|
||||
# user may want to reload every epoch
|
||||
if self.reload_dataloaders_every_epoch:
|
||||
self.evaluation_loop.reload_evaluation_dataloaders()
|
||||
|
||||
# Validation/Test end callbacks
|
||||
if test_mode:
|
||||
self.on_test_end()
|
||||
else:
|
||||
self.on_validation_end()
|
||||
# TODO: deprecate
|
||||
self.evaluation_loop.on_evaluation_end()
|
||||
|
||||
return eval_loop_results, eval_results
|
||||
|
||||
|
|
Loading…
Reference in New Issue