ref: refactor eval loop to use hooks; branch on test_mode so the val/test loops can be split later (#3129)

* moved eval hooks
William Falcon, 2020-08-24 15:48:14 -04:00, committed by GitHub
parent a0997bb7a6
commit c556ee6119
2 changed files with 34 additions and 54 deletions
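The refactor replaces the private `__call_eval_loop_hook_start/_end/_evt` helpers, which assembled hook names by string formatting, with a single `call_hook` dispatcher: profile the call, run the trainer-level hook if one exists, then run the LightningModule hook if the user overrode it, and return the module hook's result. A minimal, self-contained sketch of that dispatch pattern follows; `SketchTrainer`, its `profile` context manager, and the simplified `is_overridden` are illustrative stand-ins, not the real Trainer internals.

from contextlib import contextmanager


class SketchTrainer:
    def __init__(self, model):
        self._model = model

    @contextmanager
    def profile(self, name):
        # stand-in for self.profiler.profile(hook_name)
        yield

    def get_model(self):
        return self._model

    def is_overridden(self, hook_name):
        # simplified: treat any callable on the model as a user override
        return callable(getattr(self._model, hook_name, None))

    def on_validation_epoch_start(self):
        # trainer-level hook (dispatches to callbacks in the real Trainer)
        print('trainer hook ran')

    def call_hook(self, hook_name, *args, **kwargs):
        # always profile hooks
        with self.profile(hook_name):
            # first call the trainer-level hook, if one exists
            if hasattr(self, hook_name):
                getattr(self, hook_name)(*args, **kwargs)

            # next call the hook on the LightningModule, if overridden
            output = None
            if self.is_overridden(hook_name):
                model_ref = self.get_model()
                output = getattr(model_ref, hook_name)(*args, **kwargs)
            return output


class DemoModule:
    def on_validation_epoch_start(self):
        print('module hook ran')


SketchTrainer(DemoModule()).call_hook('on_validation_epoch_start')
# prints 'trainer hook ran', then 'module hook ran'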


@@ -225,27 +225,9 @@ class TrainerEvaluationLoopMixin(ABC):
     def reset_val_dataloader(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""
 
-    def __call_eval_loop_hook_start(self, test_mode):
-        """on_validation/test_epoch_start"""
-        self.__call_eval_loop_hook_evt(test_mode, 'start')
-
-    def __call_eval_loop_hook_end(self, test_mode):
-        """on_validation/test_epoch_end"""
-        self.__call_eval_loop_hook_evt(test_mode, 'end')
-
-    def __call_eval_loop_hook_evt(self, test_mode, epoch_event):
-        model = self.get_model()
-        # on_[train/validation]_epoch_start hook
-        hook_root_name = 'test' if test_mode else 'validation'
-        hook_name = f'on_{hook_root_name}_epoch_{epoch_event}'
-
-        with self.profiler.profile(hook_name):
-            # call hook
-            getattr(self, hook_name)()
-
-            # model hooks
-            if self.is_function_implemented(hook_name):
-                getattr(model, hook_name)()
+    @abstractmethod
+    def call_hook(self, hook_name, *args, **kwargs):
+        """Warning: this is just empty shell for code implemented in other class."""
 
     def _evaluate(
         self,
@@ -284,7 +266,10 @@ class TrainerEvaluationLoopMixin(ABC):
         # --------------------------
         # ON_EVAL_EPOCH_START hook
         # --------------------------
-        self.__call_eval_loop_hook_start(test_mode)
+        if test_mode:
+            self.call_hook('on_test_epoch_start')
+        else:
+            self.call_hook('on_validation_epoch_start')
 
         # run validation
         for dataloader_idx, dataloader in enumerate(dataloaders):
@@ -309,17 +294,10 @@ class TrainerEvaluationLoopMixin(ABC):
 
                 # callbacks
                 if test_mode:
-                    self.on_test_batch_start(batch, batch_idx, dataloader_idx)
-                    if self.is_overridden('on_test_batch_start'):
-                        model_ref = self.get_model()
-                        with self.profiler.profile('on_test_batch_start'):
-                            model_ref.on_test_batch_start(batch, batch_idx, dataloader_idx)
+                    self.call_hook('on_test_batch_start', batch, batch_idx, dataloader_idx)
                 else:
-                    self.on_validation_batch_start(batch, batch_idx, dataloader_idx)
-                    if self.is_overridden('on_validation_batch_start'):
-                        model_ref = self.get_model()
-                        with self.profiler.profile('on_validation_batch_start'):
-                            model_ref.on_validation_batch_start(batch, batch_idx, dataloader_idx)
+                    self.call_hook('on_validation_batch_start', batch, batch_idx, dataloader_idx)
+
                 # -----------------
                 # RUN EVALUATION STEP
                 # -----------------
@@ -353,12 +331,11 @@ class TrainerEvaluationLoopMixin(ABC):
                 # EVAL STEP END
                 # ------------------
                 # on dp / ddp2 might still want to do something with the batch parts
-                eval_step_end_hook_name = 'test_step_end' if test_mode else 'validation_step_end'
-                if self.is_overridden(eval_step_end_hook_name):
-                    model_ref = self.get_model()
-                    with self.profiler.profile(eval_step_end_hook_name):
-                        eval_step_end = getattr(model_ref, eval_step_end_hook_name)
-                        output = eval_step_end(output)
+                if self.is_overridden('test_step_end') or self.is_overridden('validation_step_end'):
+                    if test_mode:
+                        output = self.call_hook('test_step_end', output)
+                    else:
+                        output = self.call_hook('validation_step_end', output)
 
                 elif is_result_obj and (self.use_dp or self.use_ddp2):
                     # result auto reduce
@@ -366,17 +343,9 @@ class TrainerEvaluationLoopMixin(ABC):
 
                 # callbacks (on __batch_end)
                 if test_mode:
-                    self.on_test_batch_end(batch, batch_idx, dataloader_idx)
-                    if self.is_overridden('on_test_batch_end'):
-                        model_ref = self.get_model()
-                        with self.profiler.profile('on_test_batch_end'):
-                            model_ref.on_test_batch_end(batch, batch_idx, dataloader_idx)
+                    self.call_hook('on_test_batch_end', batch, batch_idx, dataloader_idx)
                 else:
-                    self.on_validation_batch_end(batch, batch_idx, dataloader_idx)
-                    if self.is_overridden('on_validation_batch_end'):
-                        model_ref = self.get_model()
-                        with self.profiler.profile('on_validation_batch_end'):
-                            model_ref.on_validation_batch_end(batch, batch_idx, dataloader_idx)
+                    self.call_hook('on_validation_batch_end', batch, batch_idx, dataloader_idx)
 
                 # track outputs for collation
                 if output is not None:
@@ -416,7 +385,10 @@ class TrainerEvaluationLoopMixin(ABC):
 
         # --------------------------
        # ON_EVAL_EPOCH_END hook
         # --------------------------
-        self.__call_eval_loop_hook_end(test_mode)
+        if test_mode:
+            self.call_hook('on_test_epoch_end')
+        else:
+            self.call_hook('on_validation_epoch_end')
 
         return eval_results
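Design note: every call site now branches explicitly on `test_mode` instead of assembling the hook name by string formatting as the removed `__call_eval_loop_hook_evt` did. Per the commit title, the duplication is deliberate: with the test and validation paths written out as separate `call_hook` lines, the eval loop can later be split into distinct validation and test loops without untangling shared string-built helpers.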


@@ -1457,14 +1457,22 @@ class Trainer(
         self._setup_amp_backend(amp_type)
 
     def call_hook(self, hook_name, *args, **kwargs):
-        output = None
-        if self.is_overridden(hook_name):
-            model_ref = self.get_model()
-            with self.profiler.profile(hook_name):
+        # always profile hooks
+        with self.profiler.profile(hook_name):
+            # first call trainer hook
+            if hasattr(self, hook_name):
+                trainer_hook = getattr(self, hook_name)
+                trainer_hook(*args, **kwargs)
+
+            # next call hook in lightningModule
+            output = None
+            if self.is_overridden(hook_name):
+                model_ref = self.get_model()
                 hook_fx = getattr(model_ref, hook_name)
                 output = hook_fx(*args, **kwargs)
-        return output
+
+            return output
 
 
 class _PatchDataLoader(object):
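One behavioral detail of the new `call_hook` worth noting: the trainer-level hook's return value is discarded, while the LightningModule hook's return value is handed back to the caller, which is what lets the loop write `output = self.call_hook('validation_step_end', output)`. Below is a self-contained sketch of that return-value semantics; the free function `call_hook` and `TinyModule` are hypothetical stand-ins, not the real Trainer API.

class TinyModule:
    def validation_step_end(self, output):
        # a user hook may transform and return the step output
        output['checked'] = True
        return output


def call_hook(trainer, model, hook_name, *args, **kwargs):
    # simplified free-function stand-in for Trainer.call_hook
    if hasattr(trainer, hook_name):
        getattr(trainer, hook_name)(*args, **kwargs)  # return value ignored
    output = None
    if hasattr(model, hook_name):
        output = getattr(model, hook_name)(*args, **kwargs)
    return output


out = call_hook(object(), TinyModule(), 'validation_step_end', {'loss': 0.1})
assert out == {'loss': 0.1, 'checked': True}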